Spaces:
Sleeping
Sleeping
File size: 8,193 Bytes
48854b7 |
|
{
"cells": [
{
"cell_type": "markdown",
"id": "7b73f12d-1104-4eea-ac08-3716aa9af45b",
"metadata": {},
"source": [
"**Zero-shot classification**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "05a29daa-b70e-4c7c-ba03-9ab641f424cb",
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"import os\n",
"import requests\n",
"\n",
"# Load environment variables from the .env file, which is expected to contain the\n",
"# personal access token (HF_API_TOKEN=your_token).\n",
"load_dotenv()\n",
"\n",
"# Model endpoint used by query(); the commented-out lines are alternative models to try.\n",
"API_URL = \"https://api-inference.huggingface.co/models/facebook/bart-large-mnli\"\n",
"# API_URL = \"https://api-inference.huggingface.co/models/MoritzLaurer/mDeBERTa-v3-base-mnli-xnli\"\n",
"# API_URL = \"https://api-inference.huggingface.co/models/cross-encoder/nli-deberta-v3-base\"\n",
"# API_URL = \"https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3\"\n",
"\n",
"# Fail fast when the token is missing: os.getenv() would return None and the header\n",
"# would silently become the literal string \"Bearer None\".\n",
"HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n",
"if not HF_API_TOKEN:\n",
"    raise RuntimeError(\"HF_API_TOKEN is not set; add HF_API_TOKEN=your_token to your .env file\")\n",
"headers = {\"Authorization\": f\"Bearer {HF_API_TOKEN}\"}\n",
"\n",
"# Seconds to wait for the API before giving up; requests applies no timeout by default.\n",
"REQUEST_TIMEOUT_SECONDS = 30\n",
"\n",
"\n",
"def query(payload, timeout=REQUEST_TIMEOUT_SECONDS):\n",
"    \"\"\"POST `payload` as JSON to API_URL and return the decoded JSON response.\n",
"\n",
"    Args:\n",
"        payload: JSON-serialisable request body, e.g.\n",
"            {\"inputs\": text, \"parameters\": {\"candidate_labels\": [...]}}.\n",
"        timeout: Seconds to wait for the server before requests raises a timeout error.\n",
"\n",
"    Returns:\n",
"        The decoded JSON body, unchanged. The HTTP status is not checked, so an API\n",
"        error payload such as {'error': [...]} is returned to the caller, not raised.\n",
"    \"\"\"\n",
"    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)\n",
"    return response.json()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "21b4f8b6-e774-45ad-8054-bf5db2b7b07c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'sequence': 'I just bought a new laptop, and it works amazing!', 'labels': ['technology', 'health', 'sports', 'politics'], 'scores': [0.9709171652793884, 0.014999167993664742, 0.008272457867860794, 0.005811102222651243]}\n"
]
}
],
"source": [
"# Sentence whose topic the model should pick\n",
"input_text = \"I just bought a new laptop, and it works amazing!\"\n",
"\n",
"# Topics the model may choose from\n",
"candidate_labels = [\"technology\", \"sports\", \"politics\", \"health\"]\n",
"\n",
"# Send the text together with its candidate labels and show the raw response\n",
"output = query(\n",
"    dict(\n",
"        inputs=input_text,\n",
"        parameters=dict(candidate_labels=candidate_labels),\n",
"    )\n",
")\n",
"print(output)\n"
]
},
{
"cell_type": "markdown",
"id": "fb7e69c7-b590-4b40-8478-76d055583f2a",
"metadata": {},
"source": [
"**Try packing list labels in a single request**"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "c5f75916-aaf2-4ca7-8d1a-070579940952",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'error': ['Error in `parameters.candidate_labels`: ensure this value has at most 10 items']}\n"
]
}
],
"source": [
"# Input text to classify\n",
"input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
"\n",
"# Candidate labels: the complete packing vocabulary, each item listed once\n",
"# (the earlier list repeated \"Sunglasses\" and \"Tent\").\n",
"candidate_labels = [\n",
"    \"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Beach towel\", \"Sunglasses\",\n",
"    \"Waterproof phone case\", \"Hat\", \"Beach bag\", \"Snorkel gear\", \"Aloe vera gel\",\n",
"    \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Flashlight\", \"Hiking boots\",\n",
"    \"Water filter\", \"Compass\", \"First aid kit\", \"Bug spray\", \"Multi-tool\",\n",
"    \"Thermal clothing\", \"Ski jacket\", \"Ski goggles\", \"Snow boots\", \"Gloves\",\n",
"    \"Hand warmers\", \"Beanie\", \"Lip balm\", \"Snowboard\", \"Base layers\",\n",
"    \"Passport\", \"Visa documents\", \"Travel adapter\", \"Currency\", \"Language phrasebook\",\n",
"    \"SIM card\", \"Travel pillow\", \"Neck wallet\", \"Travel insurance documents\", \"Power bank\",\n",
"    \"Laptop\", \"Notebook\", \"Business attire\", \"Dress shoes\", \"Charging cables\",\n",
"    \"Presentation materials\", \"Work ID badge\", \"Pen\", \"Headphones\",\n",
"    \"Lightweight backpack\", \"Travel-sized toiletries\", \"Packable rain jacket\",\n",
"    \"Reusable water bottle\", \"Dry bag\", \"Trekking poles\", \"Hostel lock\", \"Quick-dry towel\",\n",
"    \"Travel journal\", \"Energy bars\", \"Car charger\", \"Snacks\", \"Map\",\n",
"    \"Cooler\", \"Blanket\", \"Emergency roadside kit\", \"Reusable coffee mug\",\n",
"    \"Playlist\", \"Reusable shopping bags\", \"Earplugs\", \"Fanny pack\", \"Portable charger\",\n",
"    \"Poncho\", \"Bandana\", \"Comfortable shoes\", \"Refillable water bottle\",\n",
"    \"Glow sticks\", \"Festival tickets\", \"Diapers\", \"Baby wipes\", \"Baby food\",\n",
"    \"Stroller\", \"Pacifier\", \"Baby clothes\", \"Baby blanket\", \"Travel crib\",\n",
"    \"Toys\", \"Nursing cover\"\n",
"]\n",
"\n",
"# Get the prediction.\n",
"# NOTE: this call is kept as a demonstration of the API limit. The recorded output of\n",
"# this cell is an error payload saying `parameters.candidate_labels` may hold at most\n",
"# 10 items, so sending the whole list in one request is rejected.\n",
"output = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": candidate_labels}})\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"id": "8a6318c1-fa5f-4d16-8507-eaebe6294ac0",
"metadata": {},
"source": [
"**Use batches of 10 labels and combine results**"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fe42a222-5ff4-4442-93f4-42fc22001af6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Flashlight', 'Gloves', 'Camping stove', 'Water filter', 'Sleeping bag'], 'scores': [0.30358555912971497, 0.12884855270385742, 0.10985139012336731, 0.10500500351190567, 0.10141848027706146, 0.08342219144105911, 0.0704946368932724, 0.05127469450235367, 0.024876652285456657, 0.021222807466983795]}\n",
"{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie', 'Ski goggles', 'Flip-flops', 'First aid kit', 'Sunscreen', 'Swimsuit', 'Lip balm'], 'scores': [0.20171622931957245, 0.1621972620487213, 0.12313881516456604, 0.10742709040641785, 0.09418268501758575, 0.08230196684598923, 0.07371978461742401, 0.06208840385079384, 0.05506424233317375, 0.038163457065820694]}\n",
"\n",
"Recommended packing list: ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie']\n"
]
}
],
"source": [
"input_text = \"I'm going on a 2-week hiking trip in the Alps during winter.\"\n",
"\n",
"# Possible packing items, pre-split into groups of 10 so each request stays within\n",
"# the API's limit on candidate labels.\n",
"candidate_labels = [\n",
"    [\"Hiking boots\", \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Backpack\",\n",
"     \"Water filter\", \"Flashlight\", \"Thermal clothing\", \"Gloves\", \"Map\"],\n",
"\n",
"    [\"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Ski jacket\", \"Ski goggles\",\n",
"     \"Snow boots\", \"Beanie\", \"Hand warmers\", \"Lip balm\", \"First aid kit\"]\n",
"]\n",
"\n",
"# Minimum score a label needs to make the packing list; adjust as needed.\n",
"# NOTE(review): the recorded scores sum to ~1 within each batch, so they appear to be\n",
"# relative to the other labels of the same batch, not comparable across batches - confirm.\n",
"SCORE_THRESHOLD = 0.1\n",
"\n",
"# Run classification batch by batch and keep every label scoring above the threshold\n",
"packing_list = []\n",
"for batch in candidate_labels:\n",
"    result = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": batch}})\n",
"    print(result)\n",
"    # A failed call comes back as a payload without \"labels\"/\"scores\" (e.g. {'error': ...});\n",
"    # stop with a readable message instead of a bare KeyError.\n",
"    if not isinstance(result, dict) or \"labels\" not in result or \"scores\" not in result:\n",
"        raise RuntimeError(f\"Zero-shot classification failed for batch {batch}: {result}\")\n",
"    packing_list.extend(\n",
"        label\n",
"        for label, score in zip(result[\"labels\"], result[\"scores\"])\n",
"        if score > SCORE_THRESHOLD\n",
"    )\n",
"\n",
"# Print the final packing list\n",
"print(\"\\nRecommended packing list:\", packing_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "953b244c-0611-4706-a941-eac5064c643f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (huggingface_env)",
"language": "python",
"name": "huggingface_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.20"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|