{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7b73f12d-1104-4eea-ac08-3716aa9af45b",
   "metadata": {},
   "source": [
    "# Zero-shot classification\n",
    "\n",
    "Classify free text against arbitrary candidate labels with the Hugging Face Inference API, then use the scores to suggest a packing list for a trip description."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05a29daa-b70e-4c7c-ba03-9ab641f424cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import requests\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()  # Load environment variables from .env file, contains personal access token (HF_API_TOKEN=your_token)\n",
    "\n",
    "API_URL = \"https://api-inference.huggingface.co/models/facebook/bart-large-mnli\"\n",
    "# Alternative zero-shot models to try:\n",
    "# API_URL = \"https://api-inference.huggingface.co/models/MoritzLaurer/mDeBERTa-v3-base-mnli-xnli\"\n",
    "# API_URL = \"https://api-inference.huggingface.co/models/cross-encoder/nli-deberta-v3-base\"\n",
    "# API_URL = \"https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3\"\n",
    "\n",
    "MAX_LABELS_PER_REQUEST = 10  # the API rejects requests with more than 10 candidate labels\n",
    "REQUEST_TIMEOUT_SECONDS = 60  # without a timeout, requests.post can block indefinitely\n",
    "\n",
    "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n",
    "if not HF_API_TOKEN:\n",
    "    # Fail fast: otherwise the header below would literally read 'Bearer None'\n",
    "    raise RuntimeError(\"HF_API_TOKEN is not set; add HF_API_TOKEN=your_token to the .env file\")\n",
    "headers = {\"Authorization\": f\"Bearer {HF_API_TOKEN}\"}\n",
    "\n",
    "\n",
    "def query(payload):\n",
    "    \"\"\"POST `payload` to the Inference API and return the decoded JSON response.\n",
    "\n",
    "    The response is returned as-is, including API error payloads such as\n",
    "    {'error': ...}; callers decide how to handle them.\n",
    "    \"\"\"\n",
    "    response = requests.post(API_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT_SECONDS)\n",
    "    return response.json()\n",
    "\n",
    "\n",
    "def classify_in_batches(text, labels, batch_size=MAX_LABELS_PER_REQUEST):\n",
    "    \"\"\"Score `text` against any number of candidate labels.\n",
    "\n",
    "    Labels are de-duplicated (order preserved) and sent in chunks of at most\n",
    "    `batch_size`. Each request sets `multi_label` so every label is scored\n",
    "    independently, which keeps scores comparable across chunks.\n",
    "\n",
    "    Returns a list of (label, score) tuples sorted by descending score.\n",
    "    Raises ValueError for a non-positive `batch_size` and RuntimeError if the\n",
    "    API answers with an error payload.\n",
    "    \"\"\"\n",
    "    if batch_size < 1:\n",
    "        raise ValueError(\"batch_size must be at least 1\")\n",
    "    unique_labels = list(dict.fromkeys(labels))\n",
    "    scored = []\n",
    "    for start in range(0, len(unique_labels), batch_size):\n",
    "        batch = unique_labels[start:start + batch_size]\n",
    "        result = query({\n",
    "            \"inputs\": text,\n",
    "            \"parameters\": {\"candidate_labels\": batch, \"multi_label\": True},\n",
    "        })\n",
    "        # Surface API errors explicitly instead of failing later with a KeyError\n",
    "        if not isinstance(result, dict) or \"error\" in result:\n",
    "            raise RuntimeError(f\"Inference API request failed for labels {batch}: {result}\")\n",
    "        scored.extend(zip(result[\"labels\"], result[\"scores\"]))\n",
    "    return sorted(scored, key=lambda pair: pair[1], reverse=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "21b4f8b6-e774-45ad-8054-bf5db2b7b07c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sequence': 'I just bought a new laptop, and it works amazing!', 'labels': ['technology', 'health', 'sports', 'politics'], 'scores': [0.9709171652793884, 0.014999167993664742, 0.008272457867860794, 0.005811102222651243]}\n"
     ]
    }
   ],
   "source": [
    "# Input text to classify\n",
    "input_text = \"I just bought a new laptop, and it works amazing!\"\n",
    "\n",
    "# Candidate labels\n",
    "candidate_labels = [\"technology\", \"sports\", \"politics\", \"health\"]\n",
    "\n",
    "# Get the prediction\n",
    "output = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": candidate_labels}})\n",
    "print(output)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb7e69c7-b590-4b40-8478-76d055583f2a",
   "metadata": {},
   "source": [
    "## Try packing list labels\n",
    "\n",
    "Sending the full list of packing items in a single request fails: the API accepts at most 10 candidate labels per call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c5f75916-aaf2-4ca7-8d1a-070579940952",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'error': ['Error in `parameters.candidate_labels`: ensure this value has at most 10 items']}\n"
     ]
    }
   ],
   "source": [
    "# Input text to classify\n",
    "input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
    "\n",
    "# Candidate labels\n",
    "candidate_labels = [\n",
    "    \"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Beach towel\", \"Sunglasses\",\n",
    "    \"Waterproof phone case\", \"Hat\", \"Beach bag\", \"Snorkel gear\", \"Aloe vera gel\",\n",
    "    \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Flashlight\", \"Hiking boots\",\n",
    "    \"Water filter\", \"Compass\", \"First aid kit\", \"Bug spray\", \"Multi-tool\",\n",
    "    \"Thermal clothing\", \"Ski jacket\", \"Ski goggles\", \"Snow boots\", \"Gloves\",\n",
    "    \"Hand warmers\", \"Beanie\", \"Lip balm\", \"Snowboard\", \"Base layers\",\n",
    "    \"Passport\", \"Visa documents\", \"Travel adapter\", \"Currency\", \"Language phrasebook\",\n",
    "    \"SIM card\", \"Travel pillow\", \"Neck wallet\", \"Travel insurance documents\", \"Power bank\",\n",
    "    \"Laptop\", \"Notebook\", \"Business attire\", \"Dress shoes\", \"Charging cables\",\n",
    "    \"Presentation materials\", \"Work ID badge\", \"Pen\", \"Headphones\",\n",
    "    \"Lightweight backpack\", \"Travel-sized toiletries\", \"Packable rain jacket\",\n",
    "    \"Reusable water bottle\", \"Dry bag\", \"Trekking poles\", \"Hostel lock\", \"Quick-dry towel\",\n",
    "    \"Travel journal\", \"Energy bars\", \"Car charger\", \"Snacks\", \"Map\",\n",
    "    \"Sunglasses\", \"Cooler\", \"Blanket\", \"Emergency roadside kit\", \"Reusable coffee mug\",\n",
    "    \"Playlist\", \"Reusable shopping bags\", \"Earplugs\", \"Fanny pack\", \"Portable charger\",\n",
    "    \"Poncho\", \"Bandana\", \"Comfortable shoes\", \"Tent\", \"Refillable water bottle\",\n",
    "    \"Glow sticks\", \"Festival tickets\", \"Diapers\", \"Baby wipes\", \"Baby food\",\n",
    "    \"Stroller\", \"Pacifier\", \"Baby clothes\", \"Baby blanket\", \"Travel crib\",\n",
    "    \"Toys\", \"Nursing cover\"\n",
    "]\n",
    "\n",
    "\n",
    "# Get the prediction\n",
    "output = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": candidate_labels}})\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a6318c1-fa5f-4d16-8507-eaebe6294ac0",
   "metadata": {},
   "source": [
    "## Use batches of 10 labels and combine results\n",
    "\n",
    "In the default single-label mode the scores of one request sum to 1, so they cannot be compared across batches. The batches below are therefore scored with `multi_label` enabled, which rates every label independently and lets all batches share one threshold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe42a222-5ff4-4442-93f4-42fc22001af6",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_text = \"I'm going on a 2-week hiking trip in the Alps during winter.\"\n",
    "\n",
    "# Full list of possible packing items; classify_in_batches splits it into API-sized requests\n",
    "candidate_labels = [\n",
    "    \"Hiking boots\", \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Backpack\",\n",
    "    \"Water filter\", \"Flashlight\", \"Thermal clothing\", \"Gloves\", \"Map\",\n",
    "    \"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Ski jacket\", \"Ski goggles\",\n",
    "    \"Snow boots\", \"Beanie\", \"Hand warmers\", \"Lip balm\", \"First aid kit\",\n",
    "]\n",
    "\n",
    "SCORE_THRESHOLD = 0.5  # Adjust threshold as needed\n",
    "\n",
    "# Run classification in batches and show every label with its independent score\n",
    "label_scores = classify_in_batches(input_text, candidate_labels)\n",
    "for label, score in label_scores:\n",
    "    print(f\"{score:.3f}  {label}\")\n",
    "\n",
    "packing_list = [label for label, score in label_scores if score > SCORE_THRESHOLD]\n",
    "\n",
    "# Print the final packing list\n",
    "print(\"\\nRecommended packing list:\", packing_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "953b244c-0611-4706-a941-eac5064c643f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (huggingface_env)",
   "language": "python",
   "name": "huggingface_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}