File size: 8,193 Bytes
48854b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7b73f12d-1104-4eea-ac08-3716aa9af45b",
   "metadata": {},
   "source": [
    "**Zero-shot classification**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "05a29daa-b70e-4c7c-ba03-9ab641f424cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "import requests\n",
    "\n",
    "load_dotenv()  # Load environment variables from .env file, contains personal access token (HF_API_TOKEN=your_token)\n",
    "\n",
    "API_URL = \"https://api-inference.huggingface.co/models/facebook/bart-large-mnli\"\n",
    "# Alternative model endpoints (uncomment one to switch):\n",
    "# API_URL = \"https://api-inference.huggingface.co/models/MoritzLaurer/mDeBERTa-v3-base-mnli-xnli\"\n",
    "# API_URL = \"https://api-inference.huggingface.co/models/cross-encoder/nli-deberta-v3-base\"\n",
    "# API_URL = \"https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3\"\n",
    "headers = {\"Authorization\": f\"Bearer {os.getenv('HF_API_TOKEN')}\"}\n",
    "\n",
    "def query(payload, timeout=30):\n",
    "    \"\"\"POST `payload` to the endpoint in API_URL and return the decoded JSON body.\n",
    "\n",
    "    Args:\n",
    "        payload: JSON-serialisable request body, e.g.\n",
    "            {\"inputs\": text, \"parameters\": {\"candidate_labels\": [...]}}.\n",
    "        timeout: Seconds to wait for the server. Without a timeout,\n",
    "            requests waits indefinitely and a stalled connection hangs the kernel.\n",
    "\n",
    "    Returns:\n",
    "        The parsed JSON response. The HTTP status is deliberately not checked,\n",
    "        so API errors are returned as a dict with an \"error\" key instead of\n",
    "        being raised as exceptions.\n",
    "    \"\"\"\n",
    "    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)\n",
    "    return response.json()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "21b4f8b6-e774-45ad-8054-bf5db2b7b07c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sequence': 'I just bought a new laptop, and it works amazing!', 'labels': ['technology', 'health', 'sports', 'politics'], 'scores': [0.9709171652793884, 0.014999167993664742, 0.008272457867860794, 0.005811102222651243]}\n"
     ]
    }
   ],
   "source": [
    "# Sentence the model should categorise\n",
    "input_text = \"I just bought a new laptop, and it works amazing!\"\n",
    "\n",
    "# Topics the model chooses between\n",
    "candidate_labels = [\"technology\", \"sports\", \"politics\", \"health\"]\n",
    "\n",
    "# Score every candidate label against the sentence and show the raw response\n",
    "output = query(\n",
    "    {\n",
    "        \"inputs\": input_text,\n",
    "        \"parameters\": {\"candidate_labels\": candidate_labels},\n",
    "    }\n",
    ")\n",
    "print(output)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb7e69c7-b590-4b40-8478-76d055583f2a",
   "metadata": {},
   "source": [
    "**Try packing-list labels (more than the 10 candidate labels the API accepts per request)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c5f75916-aaf2-4ca7-8d1a-070579940952",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'error': ['Error in `parameters.candidate_labels`: ensure this value has at most 10 items']}\n"
     ]
    }
   ],
   "source": [
    "# Input text to classify\n",
    "input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
    "\n",
    "# Candidate labels. The same item can appear in more than one row\n",
    "# (\"Sunglasses\" and \"Tent\" do), so dict.fromkeys drops the repeats while\n",
    "# keeping the first occurrence and the original order.\n",
    "candidate_labels = list(dict.fromkeys([\n",
    "    \"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Beach towel\", \"Sunglasses\",\n",
    "    \"Waterproof phone case\", \"Hat\", \"Beach bag\", \"Snorkel gear\", \"Aloe vera gel\",\n",
    "    \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Flashlight\", \"Hiking boots\",\n",
    "    \"Water filter\", \"Compass\", \"First aid kit\", \"Bug spray\", \"Multi-tool\",\n",
    "    \"Thermal clothing\", \"Ski jacket\", \"Ski goggles\", \"Snow boots\", \"Gloves\",\n",
    "    \"Hand warmers\", \"Beanie\", \"Lip balm\", \"Snowboard\", \"Base layers\",\n",
    "    \"Passport\", \"Visa documents\", \"Travel adapter\", \"Currency\", \"Language phrasebook\",\n",
    "    \"SIM card\", \"Travel pillow\", \"Neck wallet\", \"Travel insurance documents\", \"Power bank\",\n",
    "    \"Laptop\", \"Notebook\", \"Business attire\", \"Dress shoes\", \"Charging cables\",\n",
    "    \"Presentation materials\", \"Work ID badge\", \"Pen\", \"Headphones\",\n",
    "    \"Lightweight backpack\", \"Travel-sized toiletries\", \"Packable rain jacket\",\n",
    "    \"Reusable water bottle\", \"Dry bag\", \"Trekking poles\", \"Hostel lock\", \"Quick-dry towel\",\n",
    "    \"Travel journal\", \"Energy bars\", \"Car charger\", \"Snacks\", \"Map\",\n",
    "    \"Sunglasses\", \"Cooler\", \"Blanket\", \"Emergency roadside kit\", \"Reusable coffee mug\",\n",
    "    \"Playlist\", \"Reusable shopping bags\", \"Earplugs\", \"Fanny pack\", \"Portable charger\",\n",
    "    \"Poncho\", \"Bandana\", \"Comfortable shoes\", \"Tent\", \"Refillable water bottle\",\n",
    "    \"Glow sticks\", \"Festival tickets\", \"Diapers\", \"Baby wipes\", \"Baby food\",\n",
    "    \"Stroller\", \"Pacifier\", \"Baby clothes\", \"Baby blanket\", \"Travel crib\",\n",
    "    \"Toys\", \"Nursing cover\"\n",
    "]))\n",
    "\n",
    "# Get the prediction.\n",
    "# NOTE: this request is expected to fail. The API rejects more than 10\n",
    "# candidate labels per call, so the printed response is an error payload\n",
    "# rather than labels and scores.\n",
    "output = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": candidate_labels}})\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a6318c1-fa5f-4d16-8507-eaebe6294ac0",
   "metadata": {},
   "source": [
    "**Use batches of 10 labels and combine results**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fe42a222-5ff4-4442-93f4-42fc22001af6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Flashlight', 'Gloves', 'Camping stove', 'Water filter', 'Sleeping bag'], 'scores': [0.30358555912971497, 0.12884855270385742, 0.10985139012336731, 0.10500500351190567, 0.10141848027706146, 0.08342219144105911, 0.0704946368932724, 0.05127469450235367, 0.024876652285456657, 0.021222807466983795]}\n",
      "{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie', 'Ski goggles', 'Flip-flops', 'First aid kit', 'Sunscreen', 'Swimsuit', 'Lip balm'], 'scores': [0.20171622931957245, 0.1621972620487213, 0.12313881516456604, 0.10742709040641785, 0.09418268501758575, 0.08230196684598923, 0.07371978461742401, 0.06208840385079384, 0.05506424233317375, 0.038163457065820694]}\n",
      "\n",
      "Recommended packing list: ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie']\n"
     ]
    }
   ],
   "source": [
    "input_text = \"I'm going on a 2-week hiking trip in the Alps during winter.\"\n",
    "\n",
    "# Minimum score an item needs to be recommended.\n",
    "# NOTE(review): in the responses printed below, the scores of each batch sum to 1,\n",
    "# so with 10 labels per batch 0.1 is simply the batch average. Scores are relative\n",
    "# within a batch and are not directly comparable across batches.\n",
    "SCORE_THRESHOLD = 0.1  # Adjust threshold as needed\n",
    "\n",
    "# Define the full list of possible packing items (split into groups of 10,\n",
    "# the most candidate labels the API accepts in one request)\n",
    "candidate_labels = [\n",
    "    [\"Hiking boots\", \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Backpack\",\n",
    "     \"Water filter\", \"Flashlight\", \"Thermal clothing\", \"Gloves\", \"Map\"],\n",
    "\n",
    "    [\"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Ski jacket\", \"Ski goggles\",\n",
    "     \"Snow boots\", \"Beanie\", \"Hand warmers\", \"Lip balm\", \"First aid kit\"]\n",
    "]\n",
    "\n",
    "# Run classification in batches\n",
    "packing_list = []\n",
    "for batch in candidate_labels:\n",
    "    result = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": batch}})\n",
    "    print(result)\n",
    "    # A failed request returns a payload without \"labels\"/\"scores\" (for example a\n",
    "    # dict with an \"error\" key); stop with a clear message instead of a KeyError.\n",
    "    if not isinstance(result, dict) or \"labels\" not in result or \"scores\" not in result:\n",
    "        raise RuntimeError(f\"Classification failed for batch {batch}: {result}\")\n",
    "    packing_list.extend(\n",
    "        label\n",
    "        for label, score in zip(result[\"labels\"], result[\"scores\"])\n",
    "        if score > SCORE_THRESHOLD\n",
    "    )\n",
    "\n",
    "# Print the final packing list\n",
    "print(\"\\nRecommended packing list:\", packing_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "953b244c-0611-4706-a941-eac5064c643f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (huggingface_env)",
   "language": "python",
   "name": "huggingface_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}