Spaces:
Sleeping
Sleeping
File size: 8,193 Bytes
48854b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
{
"cells": [
{
"cell_type": "markdown",
"id": "7b73f12d-1104-4eea-ac08-3716aa9af45b",
"metadata": {},
"source": [
"**Zero-shot classification**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "05a29daa-b70e-4c7c-ba03-9ab641f424cb",
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"import os\n",
"import requests\n",
"\n",
"# Load environment variables from a local .env file; it is expected to contain\n",
"# the personal access token as HF_API_TOKEN=your_token.\n",
"load_dotenv()\n",
"\n",
"# Endpoint that query() posts to. Uncomment one of the alternatives to try another model.\n",
"API_URL = \"https://api-inference.huggingface.co/models/facebook/bart-large-mnli\"\n",
"# API_URL = \"https://api-inference.huggingface.co/models/MoritzLaurer/mDeBERTa-v3-base-mnli-xnli\"\n",
"# API_URL = \"https://api-inference.huggingface.co/models/cross-encoder/nli-deberta-v3-base\"\n",
"# API_URL = \"https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3\"\n",
"\n",
"# Fail fast when the token is missing: os.getenv would return None and the\n",
"# request header would silently become the literal string 'Bearer None'.\n",
"HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n",
"if not HF_API_TOKEN:\n",
"    raise RuntimeError(\"HF_API_TOKEN is not set; add HF_API_TOKEN=your_token to your .env file.\")\n",
"headers = {\"Authorization\": f\"Bearer {HF_API_TOKEN}\"}\n",
"\n",
"def query(payload, timeout=30):\n",
"    \"\"\"POST `payload` as JSON to API_URL and return the decoded JSON response.\n",
"\n",
"    Args:\n",
"        payload: JSON-serializable request body, e.g.\n",
"            {\"inputs\": text, \"parameters\": {\"candidate_labels\": [...]}}.\n",
"        timeout: Seconds to wait for the server before `requests` raises a\n",
"            timeout exception, so a stalled connection cannot hang the kernel.\n",
"\n",
"    Returns:\n",
"        The response body parsed from JSON. The HTTP status code is not checked\n",
"        here, so callers must inspect the returned value themselves (a rejected\n",
"        request can come back as a dict such as {'error': [...]}).\n",
"    \"\"\"\n",
"    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)\n",
"    return response.json()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "21b4f8b6-e774-45ad-8054-bf5db2b7b07c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'sequence': 'I just bought a new laptop, and it works amazing!', 'labels': ['technology', 'health', 'sports', 'politics'], 'scores': [0.9709171652793884, 0.014999167993664742, 0.008272457867860794, 0.005811102222651243]}\n"
]
}
],
"source": [
"# Text we want to classify, and the categories the model may choose from\n",
"input_text = \"I just bought a new laptop, and it works amazing!\"\n",
"candidate_labels = [\"technology\", \"sports\", \"politics\", \"health\"]\n",
"\n",
"# Assemble the request body first, then send it and show the raw response\n",
"payload = {\n",
"    \"inputs\": input_text,\n",
"    \"parameters\": {\"candidate_labels\": candidate_labels},\n",
"}\n",
"output = query(payload)\n",
"print(output)\n"
]
},
{
"cell_type": "markdown",
"id": "fb7e69c7-b590-4b40-8478-76d055583f2a",
"metadata": {},
"source": [
"**Try packing list labels (one request with the full list; the API rejects more than 10 labels)**"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "c5f75916-aaf2-4ca7-8d1a-070579940952",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'error': ['Error in `parameters.candidate_labels`: ensure this value has at most 10 items']}\n"
]
}
],
"source": [
"# Input text to classify\n",
"input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
"\n",
"# Full list of candidate packing items, sent in ONE request on purpose.\n",
"# The API rejects this call: the recorded response says candidate_labels may\n",
"# have at most 10 items. Each label appears only once in the list.\n",
"candidate_labels = [\n",
"    \"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Beach towel\", \"Sunglasses\",\n",
"    \"Waterproof phone case\", \"Hat\", \"Beach bag\", \"Snorkel gear\", \"Aloe vera gel\",\n",
"    \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Flashlight\", \"Hiking boots\",\n",
"    \"Water filter\", \"Compass\", \"First aid kit\", \"Bug spray\", \"Multi-tool\",\n",
"    \"Thermal clothing\", \"Ski jacket\", \"Ski goggles\", \"Snow boots\", \"Gloves\",\n",
"    \"Hand warmers\", \"Beanie\", \"Lip balm\", \"Snowboard\", \"Base layers\",\n",
"    \"Passport\", \"Visa documents\", \"Travel adapter\", \"Currency\", \"Language phrasebook\",\n",
"    \"SIM card\", \"Travel pillow\", \"Neck wallet\", \"Travel insurance documents\", \"Power bank\",\n",
"    \"Laptop\", \"Notebook\", \"Business attire\", \"Dress shoes\", \"Charging cables\",\n",
"    \"Presentation materials\", \"Work ID badge\", \"Pen\", \"Headphones\",\n",
"    \"Lightweight backpack\", \"Travel-sized toiletries\", \"Packable rain jacket\",\n",
"    \"Reusable water bottle\", \"Dry bag\", \"Trekking poles\", \"Hostel lock\", \"Quick-dry towel\",\n",
"    \"Travel journal\", \"Energy bars\", \"Car charger\", \"Snacks\", \"Map\",\n",
"    \"Cooler\", \"Blanket\", \"Emergency roadside kit\", \"Reusable coffee mug\",\n",
"    \"Playlist\", \"Reusable shopping bags\", \"Earplugs\", \"Fanny pack\", \"Portable charger\",\n",
"    \"Poncho\", \"Bandana\", \"Comfortable shoes\", \"Refillable water bottle\",\n",
"    \"Glow sticks\", \"Festival tickets\", \"Diapers\", \"Baby wipes\", \"Baby food\",\n",
"    \"Stroller\", \"Pacifier\", \"Baby clothes\", \"Baby blanket\", \"Travel crib\",\n",
"    \"Toys\", \"Nursing cover\"\n",
"]\n",
"\n",
"# Get the prediction (expected to come back as an error dict, see above)\n",
"output = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": candidate_labels}})\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"id": "8a6318c1-fa5f-4d16-8507-eaebe6294ac0",
"metadata": {},
"source": [
"**Use batches of 10 labels and combine results**"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fe42a222-5ff4-4442-93f4-42fc22001af6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Flashlight', 'Gloves', 'Camping stove', 'Water filter', 'Sleeping bag'], 'scores': [0.30358555912971497, 0.12884855270385742, 0.10985139012336731, 0.10500500351190567, 0.10141848027706146, 0.08342219144105911, 0.0704946368932724, 0.05127469450235367, 0.024876652285456657, 0.021222807466983795]}\n",
"{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie', 'Ski goggles', 'Flip-flops', 'First aid kit', 'Sunscreen', 'Swimsuit', 'Lip balm'], 'scores': [0.20171622931957245, 0.1621972620487213, 0.12313881516456604, 0.10742709040641785, 0.09418268501758575, 0.08230196684598923, 0.07371978461742401, 0.06208840385079384, 0.05506424233317375, 0.038163457065820694]}\n",
"\n",
"Recommended packing list: ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie']\n"
]
}
],
"source": [
"input_text = \"I'm going on a 2-week hiking trip in the Alps during winter.\"\n",
"\n",
"# A label is recommended when its score exceeds this value; adjust as needed.\n",
"# NOTE(review): in the recorded output the scores of each batch sum to ~1, so a\n",
"# score depends on which other labels share the batch - confirm against the API\n",
"# docs whether a `multi_label` option would give independent per-label scores.\n",
"SCORE_THRESHOLD = 0.1\n",
"\n",
"# Possible packing items, pre-split into groups of 10 labels per request\n",
"candidate_labels = [\n",
"    [\"Hiking boots\", \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Backpack\",\n",
"     \"Water filter\", \"Flashlight\", \"Thermal clothing\", \"Gloves\", \"Map\"],\n",
"\n",
"    [\"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Ski jacket\", \"Ski goggles\",\n",
"     \"Snow boots\", \"Beanie\", \"Hand warmers\", \"Lip balm\", \"First aid kit\"]\n",
"]\n",
"\n",
"# Run classification batch by batch and collect every label above the threshold\n",
"packing_list = []\n",
"for batch in candidate_labels:\n",
"    result = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": batch}})\n",
"    print(result)\n",
"    # A rejected request comes back without labels/scores (e.g. {'error': [...]});\n",
"    # stop with the API's own message instead of a bare KeyError below.\n",
"    if not isinstance(result, dict) or \"labels\" not in result or \"scores\" not in result:\n",
"        raise RuntimeError(f\"Classification failed for batch {batch}: {result}\")\n",
"    for label, score in zip(result[\"labels\"], result[\"scores\"]):\n",
"        if score > SCORE_THRESHOLD:\n",
"            packing_list.append(label)\n",
"\n",
"# Print the final packing list\n",
"print(\"\\nRecommended packing list:\", packing_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "953b244c-0611-4706-a941-eac5064c643f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (huggingface_env)",
"language": "python",
"name": "huggingface_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.20"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|