AnjaJuana commited on
Commit
b869c66
·
1 Parent(s): 1e431ad

Auto-update from GitHub

Browse files
labels.txt CHANGED
@@ -1,89 +1,93 @@
1
- Swimsuit
2
- Sunscreen
3
- Flip-flops
4
- Beach towel
5
- Sunglasses
6
- Waterproof phone case
7
- Hat
8
- Beach bag
9
- Snorkel gear
10
- Aloe vera gel
11
- Tent
12
- Sleeping bag
13
- Camping stove
14
- Flashlight
15
- Hiking boots
16
- Water filter
17
- Compass
18
- First aid kit
19
- Bug spray
20
- Multi-tool
21
- Thermal clothing
22
- Ski jacket
23
- Ski goggles
24
- Snow boots
25
- Gloves
26
- Hand warmers
27
- Beanie
28
- Lip balm
29
- Snowboard
30
- Base layers
31
- Passport
32
- Visa documents
33
- Travel adapter
34
- Currency
35
- Language phrasebook
36
- SIM card
37
- Travel pillow
38
- Neck wallet
39
- Travel insurance documents
40
- Power bank
41
- Laptop
42
- Notebook
43
- Business attire
44
- Dress shoes
45
- Charging cables
46
- Presentation materials
47
- Work ID badge
48
- Pen
49
- Headphones
50
- Lightweight backpack
51
- Travel-sized toiletries
52
- Packable rain jacket
53
- Reusable water bottle
54
- Dry bag
55
- Trekking poles
56
- Hostel lock
57
- Quick-dry towel
58
- Travel journal
59
- Energy bars
60
- Car charger
61
- Snacks
62
- Map
63
- Sunglasses
64
- Cooler
65
- Blanket
66
- Emergency roadside kit
67
- Reusable coffee mug
68
- Playlist
69
- Reusable shopping bags
70
- Earplugs
71
- Fanny pack
72
- Portable charger
73
- Poncho
74
- Bandana
75
- Comfortable shoes
76
- Tent
77
- Refillable water bottle
78
- Glow sticks
79
- Festival tickets
80
- Diapers
81
- Baby wipes
82
- Baby food
83
- Stroller
84
- Pacifier
85
- Baby clothes
86
- Baby blanket
87
- Travel crib
88
- Toys
89
- Nursing cover
 
 
 
 
 
1
+ sunscreen
2
+ sunglasses
3
+ sunhat
4
+ flip-flops
5
+ swimsuit
6
+ beach towel
7
+ waterproof phone case
8
+ beach bag
9
+ snorkel gear
10
+ tent
11
+ sleeping bag
12
+ camping stove
13
+ flashlight
14
+ hiking boots
15
+ water filter
16
+ compass
17
+ first aid kit
18
+ bug spray
19
+ multi-tool
20
+ thermal clothing
21
+ ski jacket
22
+ ski goggles
23
+ snow boots
24
+ gloves
25
+ hand warmers
26
+ beanie
27
+ lip balm
28
+ snowboard
29
+ base layers
30
+ passport
31
+ visa documents
32
+ travel adapter
33
+ currency
34
+ sim card
35
+ travel pillow
36
+ neck wallet
37
+ travel insurance documents
38
+ power bank
39
+ laptop
40
+ business attire
41
+ dress shoes
42
+ charging cables
43
+ pen
44
+ headphones
45
+ lightweight backpack
46
+ travel-sized toiletries
47
+ packable rain jacket
48
+ dry bag
49
+ trekking poles
50
+ hostel lock
51
+ quick-dry towel
52
+ travel journal
53
+ snacks
54
+ blanket
55
+ emergency roadside kit
56
+ reusable coffee mug
57
+ reusable shopping bags
58
+ earplugs
59
+ fanny pack
60
+ poncho
61
+ bandana
62
+ comfortable shoes
63
+ bathing suit
64
+ sandals
65
+ light jacket
66
+ entertainment for downtime (e.g. book/ebook, games, laptop, journal)
67
+ short pants/skirts
68
+ t-shirts/tops
69
+ thin scarf
70
+ pants
71
+ shirts
72
+ cardigan/sweater
73
+ gifts
74
+ winter shoes
75
+ long pants
76
+ short pants
77
+ malaria medication
78
+ mosquito repellant
79
+ local currency
80
+ wallet
81
+ tickets
82
+ phone and charger
83
+ painkiller
84
+ necessary medication
85
+ personal toiletries (e.g. toothbrush, toothpaste)
86
+ underwear
87
+ socks
88
+ sleep wear
89
+ snacks for the journey
90
+ refillable water bottle
91
+ day pack
92
+ big backpack/suitcase
93
+
packing_list_api.ipynb CHANGED
@@ -124,7 +124,58 @@
124
  },
125
  {
126
  "cell_type": "code",
127
- "execution_count": 5,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  "id": "d0d8f7c0-c2d9-4fbe-b1a7-699a5b99466c",
129
  "metadata": {},
130
  "outputs": [
@@ -140,82 +191,89 @@
140
  }
141
  ],
142
  "source": [
143
- "from transformers import pipeline\n",
144
- "\n",
145
- "# Load the model and create a pipeline for zero-shot classification\n",
146
- "classifier = pipeline(\"zero-shot-classification\", model=\"facebook/bart-base\")"
147
  ]
148
  },
149
  {
150
  "cell_type": "code",
151
- "execution_count": 6,
152
- "id": "4682d620-c9a6-40ad-ab4c-268ee0ef7212",
153
  "metadata": {},
154
  "outputs": [
155
  {
156
  "name": "stderr",
157
  "output_type": "stream",
158
  "text": [
159
- "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
160
  ]
161
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  {
163
- "name": "stdout",
164
  "output_type": "stream",
165
  "text": [
166
- "{'sequence': 'I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.', 'labels': ['Swimsuit', 'Travel crib', 'Business attire', 'Toys', 'Notebook', 'Travel adapter', 'Compass', 'Travel pillow', 'Headphones', 'Travel journal', 'Playlist', 'Flip-flops', 'Hiking boots', 'Reusable coffee mug', 'Comfortable shoes', 'Nursing cover', 'Gloves', 'Tent', 'Tent', 'Sunglasses', 'Sunglasses', 'Charging cables', 'Travel-sized toiletries', 'Refillable water bottle', 'Energy bars', 'Dress shoes', 'Festival tickets', 'Lightweight backpack', 'Packable rain jacket', 'Flashlight', 'Hostel lock', 'Presentation materials', 'Thermal clothing', 'Snowboard', 'Camping stove', 'Reusable shopping bags', 'Reusable water bottle', 'Blanket', 'Diapers', 'Snorkel gear', 'Snacks', 'Emergency roadside kit', 'Beach towel', 'Sunscreen', 'Car charger', 'Bug spray', 'Passport', 'Currency', 'Beach bag', 'Ski jacket', 'First aid kit', 'Cooler', 'Quick-dry towel', 'Laptop', 'Aloe vera gel', 'Earplugs', 'Baby wipes', 'Ski goggles', 'Travel insurance documents', 'Portable charger', 'Beanie', 'Bandana', 'Multi-tool', 'Pacifier', 'Stroller', 'Language phrasebook', 'Waterproof phone case', 'Dry bag', 'Map', 'Lip balm', 'Fanny pack', 'Trekking poles', 'Power bank', 'Baby clothes', 'Baby food', 'Poncho', 'Sleeping bag', 'Work ID badge', 'Visa documents', 'SIM card', 'Water filter', 'Snow boots', 'Hand warmers', 'Baby blanket', 'Base layers', 'Pen', 'Hat', 'Neck wallet', 'Glow sticks'], 'scores': [0.012542711570858955, 0.012216676957905293, 0.012068654410541058, 0.011977529153227806, 0.011932261288166046, 0.011920000426471233, 0.011883101426064968, 0.011842883192002773, 0.011819617822766304, 0.011810989119112492, 0.011761271394789219, 0.011756575666368008, 0.011726364493370056, 0.011664840392768383, 0.011632450856268406, 0.01163020171225071, 0.01158054918050766, 0.011572858318686485, 0.011572858318686485, 0.011541635729372501, 0.011541635729372501, 0.011517350561916828, 0.011510960757732391, 0.011489875614643097, 0.011469963937997818, 0.011466587893664837, 0.011442759074270725, 0.011438597925007343, 0.011437375098466873, 0.011433145962655544, 0.011407203041017056, 0.011401104740798473, 0.01135423593223095, 0.011333385482430458, 0.011328010819852352, 0.011325137689709663, 0.01131997536867857, 0.011306566186249256, 0.011299673467874527, 0.011281789280474186, 0.011264320462942123, 0.011257764883339405, 0.011256475001573563, 0.011253912933170795, 0.011252702213823795, 0.011248898692429066, 0.011247594840824604, 0.011239985004067421, 0.01121864840388298, 0.011208567768335342, 0.011174682527780533, 0.011166973039507866, 0.011159253306686878, 0.011151333339512348, 0.011140624061226845, 0.011139076203107834, 0.01113345380872488, 0.011126152239739895, 0.011093570850789547, 0.011078842915594578, 0.011067545972764492, 0.011044573038816452, 0.01101986039429903, 0.011016158387064934, 0.011015082709491253, 0.011007890105247498, 0.010997296310961246, 0.010962157510221004, 0.01095755398273468, 0.010940180160105228, 0.01088095735758543, 0.010869039222598076, 0.010858545079827309, 0.010820968076586723, 0.01080892514437437, 0.010798529721796513, 0.01077410951256752, 0.010764310136437416, 0.010748079977929592, 0.010681436397135258, 0.010675576515495777, 0.010557047091424465, 0.010552684776484966, 0.010509641841053963, 0.010396942496299744, 0.01037551462650299, 0.01033466774970293, 0.010237698443233967, 0.009954877197742462]}\n"
167
  ]
168
  }
169
  ],
170
  "source": [
171
- "input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
172
- "\n",
173
- "# Candidate labels\n",
174
- "candidate_labels = [\n",
175
- " \"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Beach towel\", \"Sunglasses\", \n",
176
- " \"Waterproof phone case\", \"Hat\", \"Beach bag\", \"Snorkel gear\", \"Aloe vera gel\",\n",
177
- " \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Flashlight\", \"Hiking boots\",\n",
178
- " \"Water filter\", \"Compass\", \"First aid kit\", \"Bug spray\", \"Multi-tool\",\n",
179
- " \"Thermal clothing\", \"Ski jacket\", \"Ski goggles\", \"Snow boots\", \"Gloves\",\n",
180
- " \"Hand warmers\", \"Beanie\", \"Lip balm\", \"Snowboard\", \"Base layers\",\n",
181
- " \"Passport\", \"Visa documents\", \"Travel adapter\", \"Currency\", \"Language phrasebook\",\n",
182
- " \"SIM card\", \"Travel pillow\", \"Neck wallet\", \"Travel insurance documents\", \"Power bank\",\n",
183
- " \"Laptop\", \"Notebook\", \"Business attire\", \"Dress shoes\", \"Charging cables\",\n",
184
- " \"Presentation materials\", \"Work ID badge\", \"Pen\", \"Headphones\", \n",
185
- " \"Lightweight backpack\", \"Travel-sized toiletries\", \"Packable rain jacket\",\n",
186
- " \"Reusable water bottle\", \"Dry bag\", \"Trekking poles\", \"Hostel lock\", \"Quick-dry towel\",\n",
187
- " \"Travel journal\", \"Energy bars\", \"Car charger\", \"Snacks\", \"Map\",\n",
188
- " \"Sunglasses\", \"Cooler\", \"Blanket\", \"Emergency roadside kit\", \"Reusable coffee mug\",\n",
189
- " \"Playlist\", \"Reusable shopping bags\", \"Earplugs\", \"Fanny pack\", \"Portable charger\",\n",
190
- " \"Poncho\", \"Bandana\", \"Comfortable shoes\", \"Tent\", \"Refillable water bottle\",\n",
191
- " \"Glow sticks\", \"Festival tickets\", \"Diapers\", \"Baby wipes\", \"Baby food\",\n",
192
- " \"Stroller\", \"Pacifier\", \"Baby clothes\", \"Baby blanket\", \"Travel crib\",\n",
193
- " \"Toys\", \"Nursing cover\"\n",
194
- "]\n",
195
- "\n",
196
- "\n",
197
- "# Run the classification\n",
198
- "result = classifier(input_text, candidate_labels)\n",
199
- "\n",
200
- "# Print the result\n",
201
- "print(result)"
202
  ]
203
  },
204
  {
205
  "cell_type": "code",
206
  "execution_count": null,
207
- "id": "a344a80f-7645-4c2c-b960-580aa0b345f6",
208
  "metadata": {},
209
  "outputs": [],
210
- "source": []
 
 
 
211
  },
212
  {
213
  "cell_type": "code",
214
  "execution_count": null,
215
- "id": "5eb705d6-c31c-406c-9739-ff45b66c7ca4",
216
  "metadata": {},
217
  "outputs": [],
218
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  },
220
  {
221
  "cell_type": "code",
@@ -224,21 +282,17 @@
224
  "metadata": {},
225
  "outputs": [],
226
  "source": [
227
- "# Example text to classify\n",
228
- "text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
229
- "\n",
230
  "# No prompt\n",
231
- "no_prompt = text\n",
232
- "no_result = classifier(no_prompt, candidate_labels)\n",
233
- "\n",
234
  "\n",
235
  "# Simple prompt\n",
236
- "simple_prompt = \"Classify the following text: \" + text\n",
237
- "simple_result = classifier(simple_prompt, candidate_labels)\n",
238
  "\n",
239
  "# Primed prompt\n",
240
- "primed_prompt = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july. What are the most important things to pack for the trip?\"\n",
241
- "primed_result = classifier(primed_prompt, candidate_labels)"
242
  ]
243
  },
244
  {
@@ -436,9 +490,6 @@
436
  }
437
  ],
438
  "source": [
439
- "from tabulate import tabulate\n",
440
- "\n",
441
- "\n",
442
  "# Creating a table\n",
443
  "table = zip(no_result[\"labels\"], no_result[\"scores\"], \n",
444
  " simple_result[\"labels\"], simple_result[\"scores\"], \n",
@@ -447,14 +498,6 @@
447
  "\n",
448
  "print(tabulate(table, headers=headers, tablefmt=\"grid\"))\n"
449
  ]
450
- },
451
- {
452
- "cell_type": "code",
453
- "execution_count": null,
454
- "id": "5ed9bda0-41f2-4c7c-b055-27c1998c1d4e",
455
- "metadata": {},
456
- "outputs": [],
457
- "source": []
458
  }
459
  ],
460
  "metadata": {
 
124
  },
125
  {
126
  "cell_type": "code",
127
+ "execution_count": 35,
128
+ "id": "1d01a363-572b-450c-8fce-0721234f9a1a",
129
+ "metadata": {},
130
+ "outputs": [
131
+ {
132
+ "name": "stdout",
133
+ "output_type": "stream",
134
+ "text": [
135
+ "First trip: 7-Day Island Beach Holiday in Greece (Summer). I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands. \n",
136
+ "\n",
137
+ "Packing list: ['bathing suit', 'beach towel', 'beach bag', 'sandals', 'comfortable walking shoes', 'light jacket', 'sunscreen', 'sunglasses', 'sunhat', 'entertainment for downtime (e.g. book/ebook, games, laptop, journal)', 'short pants/skirts', 't-shirts/tops']\n"
138
+ ]
139
+ }
140
+ ],
141
+ "source": [
142
+ "# Prerequisites\n",
143
+ "from tabulate import tabulate\n",
144
+ "from transformers import pipeline\n",
145
+ "import json\n",
146
+ "\n",
147
+ "# input text\n",
148
+ "input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
149
+ "\n",
150
+ "# Load labels from a txt file\n",
151
+ "with open(\"labels.txt\", \"r\", encoding=\"utf-8\") as f:\n",
152
+ " class_labels = [line.strip() for line in f if line.strip()]\n",
153
+ "\n",
154
+ "# Load test data (in dictionary)\n",
155
+ "with open(\"test_data.json\", \"r\") as file:\n",
156
+ " packing_data = json.load(file)\n",
157
+ "# Get a list of trip descriptions (keys)\n",
158
+ "trips = list(packing_data.keys())\n",
159
+ "# Access the first trip description\n",
160
+ "first_trip = trips[0]\n",
161
+ "# Get the packing list for the second trip\n",
162
+ "first_trip_items = packing_data[first_trip]\n",
163
+ "\n",
164
+ "print(f\"First trip: {first_trip} \\n\")\n",
165
+ "print(f\"Packing list: {first_trip_items}\")"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "id": "88aa1d7e-8a32-4530-9ddd-60fa38e4a342",
171
+ "metadata": {},
172
+ "source": [
173
+ "Load classifiers"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": 36,
179
  "id": "d0d8f7c0-c2d9-4fbe-b1a7-699a5b99466c",
180
  "metadata": {},
181
  "outputs": [
 
191
  }
192
  ],
193
  "source": [
194
+ "# Load smaller the model and create a pipeline for zero-shot classification (1min loading + classifying with 89 labels)\n",
195
+ "classifier_bart_base = pipeline(\"zero-shot-classification\", model=\"facebook/bart-base\")"
 
 
196
  ]
197
  },
198
  {
199
  "cell_type": "code",
200
+ "execution_count": 37,
201
+ "id": "a971ca1c-d478-489f-9592-bc243d587eb4",
202
  "metadata": {},
203
  "outputs": [
204
  {
205
  "name": "stderr",
206
  "output_type": "stream",
207
  "text": [
208
+ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
209
  ]
210
+ }
211
+ ],
212
+ "source": [
213
+ "# Load larger the model and create a pipeline for zero-shot classification (5min loading model + classifying with 89 labels)\n",
214
+ "classifier_bart_large_mnli = pipeline(\"zero-shot-classification\", model=\"facebook/bart-large-mnli\")"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "markdown",
219
+ "id": "38805499-9919-40fe-9d42-de6869ba01dc",
220
+ "metadata": {},
221
+ "source": [
222
+ "Try classifiers"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": 38,
228
+ "id": "abb13524-71c6-448d-948d-fb22a0e0ceeb",
229
+ "metadata": {},
230
+ "outputs": [
231
  {
232
+ "name": "stderr",
233
  "output_type": "stream",
234
  "text": [
235
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
236
  ]
237
  }
238
  ],
239
  "source": [
240
+ "# Run the classification (ca 30 seconds classifying)\n",
241
+ "result_bart_base = classifier_bart_base(first_trip, class_labels)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  ]
243
  },
244
  {
245
  "cell_type": "code",
246
  "execution_count": null,
247
+ "id": "116c7ee3-2b59-4623-a416-162c487aab70",
248
  "metadata": {},
249
  "outputs": [],
250
+ "source": [
251
+ "# Run the classification (ca 1 minute classifying)\n",
252
+ "result_bart_large_mnli = classifier_bart_large_mnli(first_trip, class_labels)"
253
+ ]
254
  },
255
  {
256
  "cell_type": "code",
257
  "execution_count": null,
258
+ "id": "8591425b-ce55-4a36-a4b6-70974e8d4e59",
259
  "metadata": {},
260
  "outputs": [],
261
+ "source": [
262
+ "# Creating a table\n",
263
+ "table = zip(result_bart_base[\"labels\"], \n",
264
+ " result_bart_large_mnli[\"labels\"])\n",
265
+ "headers = [\"bart_base\", \"bart_large_mnli\"]\n",
266
+ "\n",
267
+ "print(tabulate(table, headers=headers, tablefmt=\"grid\"))\n"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "markdown",
272
+ "id": "21a35d0c-9451-433a-b14c-87e8dac21d68",
273
+ "metadata": {},
274
+ "source": [
275
+ "**Try simple prompt engineering**"
276
+ ]
277
  },
278
  {
279
  "cell_type": "code",
 
282
  "metadata": {},
283
  "outputs": [],
284
  "source": [
 
 
 
285
  "# No prompt\n",
286
+ "no_prompt = input_text\n",
287
+ "no_result = classifier(no_prompt, class_labels)\n",
 
288
  "\n",
289
  "# Simple prompt\n",
290
+ "simple_prompt = \"Classify the following text: \" + input_text\n",
291
+ "simple_result = classifier(simple_prompt, class_labels)\n",
292
  "\n",
293
  "# Primed prompt\n",
294
+ "primed_prompt = input_text + \"What are the most important things to pack for the trip?\"\n",
295
+ "primed_result = classifier(primed_prompt, class_labels)"
296
  ]
297
  },
298
  {
 
490
  }
491
  ],
492
  "source": [
 
 
 
493
  "# Creating a table\n",
494
  "table = zip(no_result[\"labels\"], no_result[\"scores\"], \n",
495
  " simple_result[\"labels\"], simple_result[\"scores\"], \n",
 
498
  "\n",
499
  "print(tabulate(table, headers=headers, tablefmt=\"grid\"))\n"
500
  ]
 
 
 
 
 
 
 
 
501
  }
502
  ],
503
  "metadata": {
space/.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -1,25 +1,25 @@
1
- import gradio as gr
2
  from transformers import pipeline
 
3
 
4
- # Initialize the zero-shot classification pipeline
5
  classifier = pipeline("zero-shot-classification", model="facebook/bart-base")
6
 
7
- # Define the classification function
8
- def classify_text(text, labels):
9
- labels = labels.split(",") # Convert the comma-separated string into a list
10
- result = classifier(text, candidate_labels=labels)
11
- return result
12
 
13
- # Set up the Gradio interface
14
- with gr.Blocks() as demo:
15
- gr.Markdown("# Zero-Shot Classification")
16
- text_input = gr.Textbox(label="Input Text")
17
- label_input = gr.Textbox(label="Comma-separated Labels")
18
- output = gr.JSON(label="Result")
19
- classify_button = gr.Button("Classify")
20
 
21
- # Link the button to the classification function
22
- classify_button.click(classify_text, inputs=[text_input, label_input], outputs=output)
 
 
 
 
 
23
 
24
- # Launch the Gradio interface
25
- demo.launch()
 
 
 
1
  from transformers import pipeline
2
+ import gradio as gr
3
 
4
+ # Load the model and create a pipeline for zero-shot classification
5
  classifier = pipeline("zero-shot-classification", model="facebook/bart-base")
6
 
7
+ # Load labels from a txt file
8
+ with open("labels.txt", "r", encoding="utf-8") as f:
9
+ class_labels = [line.strip() for line in f if line.strip()]
 
 
10
 
11
+ # Define the Gradio interface
12
+ def classify(text):
13
+ return classifier(text, class_labels)
 
 
 
 
14
 
15
+ demo = gr.Interface(
16
+ fn=classify,
17
+ inputs="text",
18
+ outputs="json",
19
+ title="Zero-Shot Classification",
20
+ description="Enter a text describing your trip",
21
+ )
22
 
23
+ # Launch the Gradio app
24
+ if __name__ == "__main__":
25
+ demo.launch()
space/.ipynb_checkpoints/gradio_tryout-checkpoint.ipynb CHANGED
@@ -1,6 +1,186 @@
1
  {
2
- "cells": [],
3
- "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  "nbformat": 4,
5
  "nbformat_minor": 5
6
  }
 
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "e25090fa-f990-4f1a-84f3-b12159eedae8",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Try out gradio"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "3bbee2e4-55c8-4b06-9929-72026edf7932",
14
+ "metadata": {},
15
+ "source": [
16
+ "Try model"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 1,
22
+ "id": "fa0d8126-e346-4412-9197-7d51baf868da",
23
+ "metadata": {
24
+ "scrolled": true
25
+ },
26
+ "outputs": [
27
+ {
28
+ "name": "stderr",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']\n",
32
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
33
+ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n",
34
+ "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.\n",
35
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
36
+ ]
37
+ },
38
+ {
39
+ "name": "stdout",
40
+ "output_type": "stream",
41
+ "text": [
42
+ "{'sequence': 'I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.', 'labels': ['Map', 'Compass', 'Laptop', 'Car charger', 'Toys', 'Travel crib', 'Hat', 'Playlist', 'Stroller', 'Currency', 'Travel adapter', 'Hostel lock', 'Pen', 'Charging cables', 'Flip-flops', 'Pacifier', 'Camping stove', 'Multi-tool', 'Passport', 'Poncho', 'Hiking boots', 'Portable charger', 'Power bank', 'Trekking poles', 'Snowboard', 'Base layers', 'Bandana', 'Aloe vera gel', 'Gloves', 'Baby blanket', 'Tent', 'Tent', 'Snorkel gear', 'Water filter', 'Diapers', 'Presentation materials', 'Nursing cover', 'Headphones', 'Sunscreen', 'Beach towel', 'Snacks', 'Ski jacket', 'Earplugs', 'Ski goggles', 'Flashlight', 'Neck wallet', 'Swimsuit', 'Notebook', 'Thermal clothing', 'Blanket', 'Snow boots', 'Sleeping bag', 'Lightweight backpack', 'Refillable water bottle', 'Quick-dry towel', 'Comfortable shoes', 'Reusable shopping bags', 'Travel journal', 'Travel pillow', 'Beach bag', 'Reusable coffee mug', 'Reusable water bottle', 'Festival tickets', 'Waterproof phone case', 'Business attire', 'Sunglasses', 'Sunglasses', 'Cooler', 'Baby clothes', 'Fanny pack', 'Beanie', 'First aid kit', 'Emergency roadside kit', 'Dry bag', 'SIM card', 'Energy bars', 'Baby food', 'Work ID badge', 'Packable rain jacket', 'Hand warmers', 'Visa documents', 'Glow sticks', 'Bug spray', 'Travel-sized toiletries', 'Dress shoes', 'Language phrasebook', 'Baby wipes', 'Lip balm', 'Travel insurance documents'], 'scores': [0.013028442859649658, 0.012909057550132275, 0.0124660674482584, 0.012431488372385502, 0.012379261665046215, 0.012377972714602947, 0.012329353019595146, 0.012096051126718521, 0.012086767703294754, 0.011947661638259888, 0.011939236894249916, 0.011935302056372166, 0.011887168511748314, 0.011814153753221035, 0.011788924224674702, 0.011783207766711712, 0.01177265401929617, 0.011771135963499546, 0.011747810058295727, 0.011738969013094902, 0.01169698778539896, 0.01166312862187624, 0.011658026836812496, 0.011596457101404667, 0.01158847101032734, 0.011561167426407337, 0.011526867747306824, 0.01149983424693346, 0.011472185142338276, 0.011455104686319828, 0.011445573531091213, 0.011445573531091213, 0.011444379575550556, 0.011416648514568806, 0.01136692427098751, 0.011363024823367596, 0.011361461132764816, 0.011328471824526787, 0.011299548670649529, 0.011291779577732086, 0.011282541789114475, 0.01127372495830059, 0.011270811781287193, 0.011263585649430752, 0.011179029010236263, 0.011149592697620392, 0.01113132108002901, 0.011122703552246094, 0.011105425655841827, 0.011101326905190945, 0.011090466752648354, 0.011066330596804619, 0.011058374308049679, 0.011055233888328075, 0.01103114802390337, 0.011022195219993591, 0.011012199334800243, 0.01100123766809702, 0.010985593311488628, 0.010961917228996754, 0.010958753526210785, 0.010938071645796299, 0.010903625749051571, 0.010879918932914734, 0.010863620787858963, 0.010824359022080898, 0.010824359022080898, 0.010805793106555939, 0.010763236321508884, 0.010710005648434162, 0.010690474882721901, 0.010647830553352833, 0.010583569295704365, 0.010571518912911415, 0.010570857673883438, 0.010552200488746166, 0.0105352271348238, 0.010523369535803795, 0.010514546185731888, 0.010479346849024296, 0.010450395755469799, 0.010436479933559895, 0.01043587177991867, 0.010400519706308842, 0.010214710608124733, 0.010052643716335297, 0.010041419416666031, 0.010003888048231602, 0.009946384467184544]}\n"
43
+ ]
44
+ }
45
+ ],
46
+ "source": [
47
+ "from transformers import pipeline\n",
48
+ "import gradio as gr\n",
49
+ "\n",
50
+ "# Load the model and create a pipeline for zero-shot classification\n",
51
+ "classifier = pipeline(\"zero-shot-classification\", model=\"facebook/bart-base\")\n",
52
+ "\n",
53
+ "# Load labels from a txt file\n",
54
+ "with open(\"labels.txt\", \"r\", encoding=\"utf-8\") as f:\n",
55
+ " class_labels = [line.strip() for line in f if line.strip()]\n",
56
+ "\n",
57
+ "# Example text to classify\n",
58
+ "input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
59
+ "\n",
60
+ "# Perform classification\n",
61
+ "result = classifier(input_text, class_labels)\n",
62
+ "\n",
63
+ "print(result)"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "id": "8e856a9c-a66c-4c4b-b7cf-8c52abbbc6fa",
69
+ "metadata": {},
70
+ "source": [
71
+ "Use model with gradio"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": 2,
77
+ "id": "521d9118-b59d-4cc6-b637-20202eaf8f33",
78
+ "metadata": {
79
+ "scrolled": true
80
+ },
81
+ "outputs": [
82
+ {
83
+ "name": "stdout",
84
+ "output_type": "stream",
85
+ "text": [
86
+ "Running on local URL: http://127.0.0.1:7860\n",
87
+ "\n",
88
+ "To create a public link, set `share=True` in `launch()`.\n"
89
+ ]
90
+ },
91
+ {
92
+ "data": {
93
+ "text/html": [
94
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
95
+ ],
96
+ "text/plain": [
97
+ "<IPython.core.display.HTML object>"
98
+ ]
99
+ },
100
+ "metadata": {},
101
+ "output_type": "display_data"
102
+ }
103
+ ],
104
+ "source": [
105
+ "# Define the Gradio interface\n",
106
+ "def classify(text):\n",
107
+ " return classifier(text, class_labels)\n",
108
+ "\n",
109
+ "demo = gr.Interface(\n",
110
+ " fn=classify,\n",
111
+ " inputs=\"text\",\n",
112
+ " outputs=\"json\",\n",
113
+ " title=\"Zero-Shot Classification\",\n",
114
+ " description=\"Enter a text describing your trip\",\n",
115
+ ")\n",
116
+ "\n",
117
+ "# Launch the Gradio app\n",
118
+ "if __name__ == \"__main__\":\n",
119
+ " demo.launch()"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "markdown",
124
+ "id": "d6526d18-6ba6-4a66-8310-21337b832d84",
125
+ "metadata": {},
126
+ "source": [
127
+ "Simple app"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "id": "5496ded9-7294-4da4-af05-00e5846cdd04",
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": [
137
+ "import gradio as gr\n",
138
+ "from transformers import pipeline\n",
139
+ "\n",
140
+ "# Initialize the zero-shot classification pipeline\n",
141
+ "classifier = pipeline(\"zero-shot-classification\", model=\"facebook/bart-base\")\n",
142
+ "\n",
143
+ "# Define the classification function\n",
144
+ "def classify_text(text, labels):\n",
145
+ " labels = labels.split(\",\") # Convert the comma-separated string into a list\n",
146
+ " result = classifier(text, candidate_labels=labels)\n",
147
+ " return result\n",
148
+ "\n",
149
+ "# Set up the Gradio interface\n",
150
+ "with gr.Blocks() as demo:\n",
151
+ " gr.Markdown(\"# Zero-Shot Classification\")\n",
152
+ " text_input = gr.Textbox(label=\"Input Text\")\n",
153
+ " label_input = gr.Textbox(label=\"Comma-separated Labels\")\n",
154
+ " output = gr.JSON(label=\"Result\")\n",
155
+ " classify_button = gr.Button(\"Classify\")\n",
156
+ "\n",
157
+ " # Link the button to the classification function\n",
158
+ " classify_button.click(classify_text, inputs=[text_input, label_input], outputs=output)\n",
159
+ "\n",
160
+ "# Launch the Gradio interface\n",
161
+ "demo.launch()"
162
+ ]
163
+ }
164
+ ],
165
+ "metadata": {
166
+ "kernelspec": {
167
+ "display_name": "Python (huggingface_env)",
168
+ "language": "python",
169
+ "name": "huggingface_env"
170
+ },
171
+ "language_info": {
172
+ "codemirror_mode": {
173
+ "name": "ipython",
174
+ "version": 3
175
+ },
176
+ "file_extension": ".py",
177
+ "mimetype": "text/x-python",
178
+ "name": "python",
179
+ "nbconvert_exporter": "python",
180
+ "pygments_lexer": "ipython3",
181
+ "version": "3.8.20"
182
+ }
183
+ },
184
  "nbformat": 4,
185
  "nbformat_minor": 5
186
  }
space/app.py CHANGED
@@ -1,25 +1,25 @@
1
- import gradio as gr
2
  from transformers import pipeline
 
3
 
4
- # Initialize the zero-shot classification pipeline
5
  classifier = pipeline("zero-shot-classification", model="facebook/bart-base")
6
 
7
- # Define the classification function
8
- def classify_text(text, labels):
9
- labels = labels.split(",") # Convert the comma-separated string into a list
10
- result = classifier(text, candidate_labels=labels)
11
- return result
12
 
13
- # Set up the Gradio interface
14
- with gr.Blocks() as demo:
15
- gr.Markdown("# Zero-Shot Classification")
16
- text_input = gr.Textbox(label="Input Text")
17
- label_input = gr.Textbox(label="Comma-separated Labels")
18
- output = gr.JSON(label="Result")
19
- classify_button = gr.Button("Classify")
20
 
21
- # Link the button to the classification function
22
- classify_button.click(classify_text, inputs=[text_input, label_input], outputs=output)
 
 
 
 
 
23
 
24
- # Launch the Gradio interface
25
- demo.launch()
 
 
 
1
  from transformers import pipeline
2
+ import gradio as gr
3
 
4
+ # Load the model and create a pipeline for zero-shot classification
5
  classifier = pipeline("zero-shot-classification", model="facebook/bart-base")
6
 
7
+ # Load labels from a txt file
8
+ with open("labels.txt", "r", encoding="utf-8") as f:
9
+ class_labels = [line.strip() for line in f if line.strip()]
 
 
10
 
11
+ # Define the Gradio interface
12
+ def classify(text):
13
+ return classifier(text, class_labels)
 
 
 
 
14
 
15
+ demo = gr.Interface(
16
+ fn=classify,
17
+ inputs="text",
18
+ outputs="json",
19
+ title="Zero-Shot Classification",
20
+ description="Enter a text describing your trip",
21
+ )
22
 
23
+ # Launch the Gradio app
24
+ if __name__ == "__main__":
25
+ demo.launch()
space/gradio_tryout.ipynb CHANGED
@@ -18,7 +18,7 @@
18
  },
19
  {
20
  "cell_type": "code",
21
- "execution_count": 6,
22
  "id": "fa0d8126-e346-4412-9197-7d51baf868da",
23
  "metadata": {
24
  "scrolled": true
@@ -39,7 +39,7 @@
39
  "name": "stdout",
40
  "output_type": "stream",
41
  "text": [
42
- "{'sequence': 'I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.', 'labels': ['Compass', 'Travel insurance documents', 'Cooler', 'Poncho', 'Comfortable shoes', 'Thermal clothing', 'Business attire', 'Stroller', 'Refillable water bottle', 'Sunscreen', 'Hiking boots', 'Trekking poles', 'Tent', 'Tent', 'Swimsuit', 'Lightweight backpack', 'Diapers', 'Pen', 'Lip balm', 'Bandana', 'Presentation materials', 'Snorkel gear', 'Sunglasses', 'Sunglasses', 'Snowboard', 'Baby wipes', 'Emergency roadside kit', 'Blanket', 'Passport', 'Aloe vera gel', 'Currency', 'Beanie', 'Hand warmers', 'Reusable shopping bags', 'Hat', 'Travel-sized toiletries', 'Waterproof phone case', 'Energy bars', 'Baby food', 'Reusable water bottle', 'Flashlight', 'Gloves', 'Baby clothes', 'Hostel lock', 'Visa documents', 'Camping stove', 'Bug spray', 'Packable rain jacket', 'Travel pillow', 'Power bank', 'Earplugs', 'Quick-dry towel', 'Reusable coffee mug', 'Travel journal', 'Fanny pack', 'Headphones', 'Notebook', 'Dress shoes', 'Nursing cover', 'Playlist', 'Base layers', 'Work ID badge', 'Festival tickets', 'Sleeping bag', 'Laptop', 'Baby blanket', 'Charging cables', 'Snow boots', 'First aid kit', 'Snacks', 'Flip-flops', 'Toys', 'Car charger', 'Ski jacket', 'Dry bag', 'Pacifier', 'Map', 'Portable charger', 'Travel crib', 'Multi-tool', 'Beach bag', 'Ski goggles', 'SIM card', 'Glow sticks', 'Beach towel', 'Travel adapter', 'Neck wallet', 'Language phrasebook', 'Water filter'], 'scores': [0.011984821408987045, 0.011970506981015205, 0.011933253146708012, 0.011915490962564945, 0.011904211714863777, 0.011892491020262241, 0.01188766211271286, 0.011866495944559574, 0.011842762120068073, 0.011789090000092983, 0.011770269833505154, 0.011769718490540981, 0.011746660806238651, 0.011746660806238651, 0.011718676425516605, 0.01164235919713974, 0.011551206931471825, 0.011529732495546341, 0.011518468149006367, 0.011516833677887917, 0.011508049443364143, 0.011507270857691765, 0.01149584911763668, 0.01149584911763668, 0.011495097540318966, 0.01149324607104063, 0.011486946605145931, 0.01148668210953474, 0.011478666216135025, 0.011473646387457848, 0.011412998661398888, 0.011398673988878727, 0.011378799565136433, 0.01135518029332161, 0.011335738934576511, 0.011330211535096169, 0.011329339817166328, 0.011324702762067318, 0.01131915021687746, 0.01131164189428091, 0.011294065974652767, 0.011273612268269062, 0.011272135190665722, 0.011252084746956825, 0.01122584380209446, 0.011216048151254654, 0.011204490438103676, 0.011203117668628693, 0.01117485947906971, 0.01117344293743372, 0.011145292781293392, 0.011137993074953556, 0.011128612793982029, 0.011123239994049072, 0.011122280731797218, 0.011065744794905186, 0.011053262278437614, 0.011045967228710651, 0.011041177436709404, 0.011033336631953716, 0.01102971937507391, 0.0110141197219491, 0.01100961398333311, 0.011002525687217712, 0.010937424376606941, 0.01093329582363367, 0.010918675921857357, 0.010917853564023972, 0.010890142060816288, 0.01088369358330965, 0.010871977545320988, 0.010870742611587048, 0.010863195173442364, 0.010844682343304157, 0.01084016915410757, 0.010835953988134861, 0.010834810324013233, 0.010826902464032173, 0.010796850547194481, 0.010746038518846035, 0.010692491196095943, 0.010686952620744705, 0.010679351165890694, 0.010655333288013935, 0.010604050010442734, 0.010574583895504475, 0.010439733043313026, 0.010402928106486797, 0.010294477455317974]}\n"
43
  ]
44
  }
45
  ],
@@ -73,15 +73,17 @@
73
  },
74
  {
75
  "cell_type": "code",
76
- "execution_count": 13,
77
  "id": "521d9118-b59d-4cc6-b637-20202eaf8f33",
78
- "metadata": {},
 
 
79
  "outputs": [
80
  {
81
  "name": "stdout",
82
  "output_type": "stream",
83
  "text": [
84
- "Running on local URL: http://127.0.0.1:7866\n",
85
  "\n",
86
  "To create a public link, set `share=True` in `launch()`.\n"
87
  ]
@@ -89,7 +91,7 @@
89
  {
90
  "data": {
91
  "text/html": [
92
- "<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
93
  ],
94
  "text/plain": [
95
  "<IPython.core.display.HTML object>"
@@ -116,6 +118,48 @@
116
  "if __name__ == \"__main__\":\n",
117
  " demo.launch()"
118
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  }
120
  ],
121
  "metadata": {
 
18
  },
19
  {
20
  "cell_type": "code",
21
+ "execution_count": 1,
22
  "id": "fa0d8126-e346-4412-9197-7d51baf868da",
23
  "metadata": {
24
  "scrolled": true
 
39
  "name": "stdout",
40
  "output_type": "stream",
41
  "text": [
42
+ "{'sequence': 'I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.', 'labels': ['Map', 'Compass', 'Laptop', 'Car charger', 'Toys', 'Travel crib', 'Hat', 'Playlist', 'Stroller', 'Currency', 'Travel adapter', 'Hostel lock', 'Pen', 'Charging cables', 'Flip-flops', 'Pacifier', 'Camping stove', 'Multi-tool', 'Passport', 'Poncho', 'Hiking boots', 'Portable charger', 'Power bank', 'Trekking poles', 'Snowboard', 'Base layers', 'Bandana', 'Aloe vera gel', 'Gloves', 'Baby blanket', 'Tent', 'Tent', 'Snorkel gear', 'Water filter', 'Diapers', 'Presentation materials', 'Nursing cover', 'Headphones', 'Sunscreen', 'Beach towel', 'Snacks', 'Ski jacket', 'Earplugs', 'Ski goggles', 'Flashlight', 'Neck wallet', 'Swimsuit', 'Notebook', 'Thermal clothing', 'Blanket', 'Snow boots', 'Sleeping bag', 'Lightweight backpack', 'Refillable water bottle', 'Quick-dry towel', 'Comfortable shoes', 'Reusable shopping bags', 'Travel journal', 'Travel pillow', 'Beach bag', 'Reusable coffee mug', 'Reusable water bottle', 'Festival tickets', 'Waterproof phone case', 'Business attire', 'Sunglasses', 'Sunglasses', 'Cooler', 'Baby clothes', 'Fanny pack', 'Beanie', 'First aid kit', 'Emergency roadside kit', 'Dry bag', 'SIM card', 'Energy bars', 'Baby food', 'Work ID badge', 'Packable rain jacket', 'Hand warmers', 'Visa documents', 'Glow sticks', 'Bug spray', 'Travel-sized toiletries', 'Dress shoes', 'Language phrasebook', 'Baby wipes', 'Lip balm', 'Travel insurance documents'], 'scores': [0.013028442859649658, 0.012909057550132275, 0.0124660674482584, 0.012431488372385502, 0.012379261665046215, 0.012377972714602947, 0.012329353019595146, 0.012096051126718521, 0.012086767703294754, 0.011947661638259888, 0.011939236894249916, 0.011935302056372166, 0.011887168511748314, 0.011814153753221035, 0.011788924224674702, 0.011783207766711712, 0.01177265401929617, 0.011771135963499546, 0.011747810058295727, 0.011738969013094902, 0.01169698778539896, 0.01166312862187624, 0.011658026836812496, 0.011596457101404667, 0.01158847101032734, 0.011561167426407337, 0.011526867747306824, 0.01149983424693346, 0.011472185142338276, 0.011455104686319828, 0.011445573531091213, 0.011445573531091213, 0.011444379575550556, 0.011416648514568806, 0.01136692427098751, 0.011363024823367596, 0.011361461132764816, 0.011328471824526787, 0.011299548670649529, 0.011291779577732086, 0.011282541789114475, 0.01127372495830059, 0.011270811781287193, 0.011263585649430752, 0.011179029010236263, 0.011149592697620392, 0.01113132108002901, 0.011122703552246094, 0.011105425655841827, 0.011101326905190945, 0.011090466752648354, 0.011066330596804619, 0.011058374308049679, 0.011055233888328075, 0.01103114802390337, 0.011022195219993591, 0.011012199334800243, 0.01100123766809702, 0.010985593311488628, 0.010961917228996754, 0.010958753526210785, 0.010938071645796299, 0.010903625749051571, 0.010879918932914734, 0.010863620787858963, 0.010824359022080898, 0.010824359022080898, 0.010805793106555939, 0.010763236321508884, 0.010710005648434162, 0.010690474882721901, 0.010647830553352833, 0.010583569295704365, 0.010571518912911415, 0.010570857673883438, 0.010552200488746166, 0.0105352271348238, 0.010523369535803795, 0.010514546185731888, 0.010479346849024296, 0.010450395755469799, 0.010436479933559895, 0.01043587177991867, 0.010400519706308842, 0.010214710608124733, 0.010052643716335297, 0.010041419416666031, 0.010003888048231602, 0.009946384467184544]}\n"
43
  ]
44
  }
45
  ],
 
73
  },
74
  {
75
  "cell_type": "code",
76
+ "execution_count": 2,
77
  "id": "521d9118-b59d-4cc6-b637-20202eaf8f33",
78
+ "metadata": {
79
+ "scrolled": true
80
+ },
81
  "outputs": [
82
  {
83
  "name": "stdout",
84
  "output_type": "stream",
85
  "text": [
86
+ "Running on local URL: http://127.0.0.1:7860\n",
87
  "\n",
88
  "To create a public link, set `share=True` in `launch()`.\n"
89
  ]
 
91
  {
92
  "data": {
93
  "text/html": [
94
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
95
  ],
96
  "text/plain": [
97
  "<IPython.core.display.HTML object>"
 
118
  "if __name__ == \"__main__\":\n",
119
  " demo.launch()"
120
  ]
121
+ },
122
+ {
123
+ "cell_type": "markdown",
124
+ "id": "d6526d18-6ba6-4a66-8310-21337b832d84",
125
+ "metadata": {},
126
+ "source": [
127
+ "Simple app"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "id": "5496ded9-7294-4da4-af05-00e5846cdd04",
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": [
137
+ "import gradio as gr\n",
138
+ "from transformers import pipeline\n",
139
+ "\n",
140
+ "# Initialize the zero-shot classification pipeline\n",
141
+ "classifier = pipeline(\"zero-shot-classification\", model=\"facebook/bart-base\")\n",
142
+ "\n",
143
+ "# Define the classification function\n",
144
+ "def classify_text(text, labels):\n",
145
+ " labels = labels.split(\",\") # Convert the comma-separated string into a list\n",
146
+ " result = classifier(text, candidate_labels=labels)\n",
147
+ " return result\n",
148
+ "\n",
149
+ "# Set up the Gradio interface\n",
150
+ "with gr.Blocks() as demo:\n",
151
+ " gr.Markdown(\"# Zero-Shot Classification\")\n",
152
+ " text_input = gr.Textbox(label=\"Input Text\")\n",
153
+ " label_input = gr.Textbox(label=\"Comma-separated Labels\")\n",
154
+ " output = gr.JSON(label=\"Result\")\n",
155
+ " classify_button = gr.Button(\"Classify\")\n",
156
+ "\n",
157
+ " # Link the button to the classification function\n",
158
+ " classify_button.click(classify_text, inputs=[text_input, label_input], outputs=output)\n",
159
+ "\n",
160
+ "# Launch the Gradio interface\n",
161
+ "demo.launch()"
162
+ ]
163
  }
164
  ],
165
  "metadata": {
space/packing_list_api.ipynb CHANGED
@@ -10,7 +10,7 @@
10
  },
11
  {
12
  "cell_type": "code",
13
- "execution_count": 12,
14
  "id": "05a29daa-b70e-4c7c-ba03-9ab641f424cb",
15
  "metadata": {},
16
  "outputs": [],
@@ -35,7 +35,7 @@
35
  },
36
  {
37
  "cell_type": "code",
38
- "execution_count": 13,
39
  "id": "21b4f8b6-e774-45ad-8054-bf5db2b7b07c",
40
  "metadata": {},
41
  "outputs": [
@@ -69,7 +69,7 @@
69
  },
70
  {
71
  "cell_type": "code",
72
- "execution_count": 14,
73
  "id": "c5f75916-aaf2-4ca7-8d1a-070579940952",
74
  "metadata": {},
75
  "outputs": [
@@ -114,66 +114,6 @@
114
  "print(output)"
115
  ]
116
  },
117
- {
118
- "cell_type": "markdown",
119
- "id": "8a6318c1-fa5f-4d16-8507-eaebe6294ac0",
120
- "metadata": {},
121
- "source": [
122
- "**Use batches of 10 labels and combine results**"
123
- ]
124
- },
125
- {
126
- "cell_type": "code",
127
- "execution_count": 16,
128
- "id": "fe42a222-5ff4-4442-93f4-42fc22001af6",
129
- "metadata": {},
130
- "outputs": [
131
- {
132
- "name": "stdout",
133
- "output_type": "stream",
134
- "text": [
135
- "{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Flashlight', 'Gloves', 'Camping stove', 'Water filter', 'Sleeping bag'], 'scores': [0.30358555912971497, 0.12884855270385742, 0.10985139012336731, 0.10500500351190567, 0.10141848027706146, 0.08342219144105911, 0.0704946368932724, 0.05127469450235367, 0.024876652285456657, 0.021222807466983795]}\n",
136
- "{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie', 'Ski goggles', 'Flip-flops', 'First aid kit', 'Sunscreen', 'Swimsuit', 'Lip balm'], 'scores': [0.20171622931957245, 0.1621972620487213, 0.12313881516456604, 0.10742709040641785, 0.09418268501758575, 0.08230196684598923, 0.07371978461742401, 0.06208840385079384, 0.05506424233317375, 0.038163457065820694]}\n",
137
- "\n",
138
- "Recommended packing list: ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie']\n"
139
- ]
140
- }
141
- ],
142
- "source": [
143
- "\n",
144
- "input_text = \"I'm going on a 2-week hiking trip in the Alps during winter.\"\n",
145
- "\n",
146
- "\n",
147
- "# Define the full list of possible packing items (split into groups of 10)\n",
148
- "candidate_labels = [\n",
149
- " [\"Hiking boots\", \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Backpack\",\n",
150
- " \"Water filter\", \"Flashlight\", \"Thermal clothing\", \"Gloves\", \"Map\"],\n",
151
- " \n",
152
- " [\"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Ski jacket\", \"Ski goggles\",\n",
153
- " \"Snow boots\", \"Beanie\", \"Hand warmers\", \"Lip balm\", \"First aid kit\"]\n",
154
- "]\n",
155
- "\n",
156
- "# Run classification in batches\n",
157
- "packing_list = []\n",
158
- "for batch in candidate_labels:\n",
159
- " result = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": batch}})\n",
160
- " print(result)\n",
161
- " for label, score in zip(result[\"labels\"], result[\"scores\"]):\n",
162
- " if score > 0.1: # Adjust threshold as needed\n",
163
- " packing_list.append(label)\n",
164
- "\n",
165
- "# Print the final packing list\n",
166
- "print(\"\\nRecommended packing list:\", packing_list)"
167
- ]
168
- },
169
- {
170
- "cell_type": "code",
171
- "execution_count": null,
172
- "id": "660072ea-b72f-4bee-a9ed-81019775ae85",
173
- "metadata": {},
174
- "outputs": [],
175
- "source": []
176
- },
177
  {
178
  "cell_type": "markdown",
179
  "id": "edf44387-d166-4e0f-a8ad-621230aee115",
@@ -184,92 +124,16 @@
184
  },
185
  {
186
  "cell_type": "code",
187
- "execution_count": 1,
188
  "id": "d0d8f7c0-c2d9-4fbe-b1a7-699a5b99466c",
189
  "metadata": {},
190
  "outputs": [
191
- {
192
- "data": {
193
- "application/vnd.jupyter.widget-view+json": {
194
- "model_id": "5e371dee58d64e7b8bf6635e0e88f8db",
195
- "version_major": 2,
196
- "version_minor": 0
197
- },
198
- "text/plain": [
199
- "config.json: 0%| | 0.00/1.72k [00:00<?, ?B/s]"
200
- ]
201
- },
202
- "metadata": {},
203
- "output_type": "display_data"
204
- },
205
- {
206
- "data": {
207
- "application/vnd.jupyter.widget-view+json": {
208
- "model_id": "d479e18a65314ad5927ea2bf7453db7c",
209
- "version_major": 2,
210
- "version_minor": 0
211
- },
212
- "text/plain": [
213
- "model.safetensors: 0%| | 0.00/558M [00:00<?, ?B/s]"
214
- ]
215
- },
216
- "metadata": {},
217
- "output_type": "display_data"
218
- },
219
  {
220
  "name": "stderr",
221
  "output_type": "stream",
222
  "text": [
223
  "Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']\n",
224
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
225
- ]
226
- },
227
- {
228
- "data": {
229
- "application/vnd.jupyter.widget-view+json": {
230
- "model_id": "8dc911686edb4b15baa880ae657c163d",
231
- "version_major": 2,
232
- "version_minor": 0
233
- },
234
- "text/plain": [
235
- "vocab.json: 0%| | 0.00/899k [00:00<?, ?B/s]"
236
- ]
237
- },
238
- "metadata": {},
239
- "output_type": "display_data"
240
- },
241
- {
242
- "data": {
243
- "application/vnd.jupyter.widget-view+json": {
244
- "model_id": "e60a6df28292441bb5317ef80c9de795",
245
- "version_major": 2,
246
- "version_minor": 0
247
- },
248
- "text/plain": [
249
- "merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]"
250
- ]
251
- },
252
- "metadata": {},
253
- "output_type": "display_data"
254
- },
255
- {
256
- "data": {
257
- "application/vnd.jupyter.widget-view+json": {
258
- "model_id": "c7eaab50789b42a796d0deb3008f247e",
259
- "version_major": 2,
260
- "version_minor": 0
261
- },
262
- "text/plain": [
263
- "tokenizer.json: 0%| | 0.00/1.36M [00:00<?, ?B/s]"
264
- ]
265
- },
266
- "metadata": {},
267
- "output_type": "display_data"
268
- },
269
- {
270
- "name": "stderr",
271
- "output_type": "stream",
272
- "text": [
273
  "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n",
274
  "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.\n"
275
  ]
@@ -284,7 +148,7 @@
284
  },
285
  {
286
  "cell_type": "code",
287
- "execution_count": 2,
288
  "id": "4682d620-c9a6-40ad-ab4c-268ee0ef7212",
289
  "metadata": {},
290
  "outputs": [
@@ -299,7 +163,7 @@
299
  "name": "stdout",
300
  "output_type": "stream",
301
  "text": [
302
- "{'sequence': 'I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.', 'labels': ['Travel-sized toiletries', 'Refillable water bottle', 'Aloe vera gel', 'Snorkel gear', 'Waterproof phone case', 'Packable rain jacket', 'Reusable shopping bags', 'Reusable coffee mug', 'Reusable water bottle', 'First aid kit', 'Travel insurance documents', 'Work ID badge', 'Lightweight backpack', 'Presentation materials', 'Flip-flops', 'Charging cables', 'Hiking boots', 'Comfortable shoes', 'Fanny pack', 'Trekking poles', 'Visa documents', 'Baby wipes', 'Quick-dry towel', 'Baby blanket', 'Hostel lock', 'Blanket', 'Business attire', 'Laptop', 'Beanie', 'Bug spray', 'Travel pillow', 'Baby clothes', 'Passport', 'Earplugs', 'Camping stove', 'Travel journal', 'Emergency roadside kit', 'Baby food', 'Pen', 'Bandana', 'Dress shoes', 'Snacks', 'Travel crib', 'Sunscreen', 'Ski goggles', 'Sunglasses', 'Sunglasses', 'Stroller', 'Lip balm', 'Notebook', 'Glow sticks', 'Cooler', 'Snowboard', 'Map', 'Thermal clothing', 'Neck wallet', 'Water filter', 'Travel adapter', 'Currency', 'Nursing cover', 'Snow boots', 'Pacifier', 'Sleeping bag', 'Car charger', 'Diapers', 'Flashlight', 'Ski jacket', 'Portable charger', 'Playlist', 'Swimsuit', 'Tent', 'Tent', 'SIM card', 'Compass', 'Multi-tool', 'Hat', 'Base layers', 'Energy bars', 'Toys', 'Power bank', 'Dry bag', 'Beach towel', 'Beach bag', 'Poncho', 'Headphones', 'Gloves', 'Festival tickets', 'Hand warmers', 'Language phrasebook'], 'scores': [0.014162097126245499, 0.013634984381496906, 0.013528786599636078, 0.013522890396416187, 0.013521893881261349, 0.013390542939305305, 0.013313423842191696, 0.01292099617421627, 0.01269496325403452, 0.01249685138463974, 0.012418625876307487, 0.012351310811936855, 0.012286719866096973, 0.012170663103461266, 0.01216645073145628, 0.012136084027588367, 0.012111806310713291, 0.01203493494540453, 0.011913969181478024, 0.011860690079629421, 0.01184084452688694, 0.011729727499186993, 0.0116303451359272, 0.011585962027311325, 0.011557267978787422, 0.011486714705824852, 0.011480122804641724, 0.011266479268670082, 0.011243777349591255, 0.011239712126553059, 0.011195540428161621, 0.011194570921361446, 0.01118150819092989, 0.011168110184371471, 0.011141857132315636, 0.01114004384726286, 0.011128030717372894, 0.0110848443582654, 0.01107991486787796, 0.01107126846909523, 0.011069754138588905, 0.011015287600457668, 0.01101327408105135, 0.010999458841979504, 0.010981021448969841, 0.010975920595228672, 0.010975920595228672, 0.010966054163873196, 0.010964509099721909, 0.01093060988932848, 0.010892837308347225, 0.010852692648768425, 0.010844447650015354, 0.010827522724866867, 0.010805405676364899, 0.010789167135953903, 0.010784591548144817, 0.010779209434986115, 0.010761956684291363, 0.010743752121925354, 0.010727204382419586, 0.010722712613642216, 0.010696588084101677, 0.01069594919681549, 0.010669016279280186, 0.010664715431630611, 0.010641842149198055, 0.01063066441565752, 0.010608346201479435, 0.010583184659481049, 0.010549037717282772, 0.010549037717282772, 0.010522513650357723, 0.010509520769119263, 0.010469724424183369, 0.010431424714624882, 0.010407780297100544, 0.010376540012657642, 0.01036670058965683, 0.010329049080610275, 0.010298855602741241, 0.01027328334748745, 0.010225902311503887, 0.010063442401587963, 0.01005304791033268, 0.010049044154584408, 0.009841262362897396, 0.009678435511887074, 0.009306504391133785]}\n"
303
  ]
304
  }
305
  ],
 
10
  },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 2,
14
  "id": "05a29daa-b70e-4c7c-ba03-9ab641f424cb",
15
  "metadata": {},
16
  "outputs": [],
 
35
  },
36
  {
37
  "cell_type": "code",
38
+ "execution_count": 3,
39
  "id": "21b4f8b6-e774-45ad-8054-bf5db2b7b07c",
40
  "metadata": {},
41
  "outputs": [
 
69
  },
70
  {
71
  "cell_type": "code",
72
+ "execution_count": 4,
73
  "id": "c5f75916-aaf2-4ca7-8d1a-070579940952",
74
  "metadata": {},
75
  "outputs": [
 
114
  "print(output)"
115
  ]
116
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  {
118
  "cell_type": "markdown",
119
  "id": "edf44387-d166-4e0f-a8ad-621230aee115",
 
124
  },
125
  {
126
  "cell_type": "code",
127
+ "execution_count": 5,
128
  "id": "d0d8f7c0-c2d9-4fbe-b1a7-699a5b99466c",
129
  "metadata": {},
130
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  {
132
  "name": "stderr",
133
  "output_type": "stream",
134
  "text": [
135
  "Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']\n",
136
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n",
138
  "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.\n"
139
  ]
 
148
  },
149
  {
150
  "cell_type": "code",
151
+ "execution_count": 6,
152
  "id": "4682d620-c9a6-40ad-ab4c-268ee0ef7212",
153
  "metadata": {},
154
  "outputs": [
 
163
  "name": "stdout",
164
  "output_type": "stream",
165
  "text": [
166
+ "{'sequence': 'I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.', 'labels': ['Swimsuit', 'Travel crib', 'Business attire', 'Toys', 'Notebook', 'Travel adapter', 'Compass', 'Travel pillow', 'Headphones', 'Travel journal', 'Playlist', 'Flip-flops', 'Hiking boots', 'Reusable coffee mug', 'Comfortable shoes', 'Nursing cover', 'Gloves', 'Tent', 'Tent', 'Sunglasses', 'Sunglasses', 'Charging cables', 'Travel-sized toiletries', 'Refillable water bottle', 'Energy bars', 'Dress shoes', 'Festival tickets', 'Lightweight backpack', 'Packable rain jacket', 'Flashlight', 'Hostel lock', 'Presentation materials', 'Thermal clothing', 'Snowboard', 'Camping stove', 'Reusable shopping bags', 'Reusable water bottle', 'Blanket', 'Diapers', 'Snorkel gear', 'Snacks', 'Emergency roadside kit', 'Beach towel', 'Sunscreen', 'Car charger', 'Bug spray', 'Passport', 'Currency', 'Beach bag', 'Ski jacket', 'First aid kit', 'Cooler', 'Quick-dry towel', 'Laptop', 'Aloe vera gel', 'Earplugs', 'Baby wipes', 'Ski goggles', 'Travel insurance documents', 'Portable charger', 'Beanie', 'Bandana', 'Multi-tool', 'Pacifier', 'Stroller', 'Language phrasebook', 'Waterproof phone case', 'Dry bag', 'Map', 'Lip balm', 'Fanny pack', 'Trekking poles', 'Power bank', 'Baby clothes', 'Baby food', 'Poncho', 'Sleeping bag', 'Work ID badge', 'Visa documents', 'SIM card', 'Water filter', 'Snow boots', 'Hand warmers', 'Baby blanket', 'Base layers', 'Pen', 'Hat', 'Neck wallet', 'Glow sticks'], 'scores': [0.012542711570858955, 0.012216676957905293, 0.012068654410541058, 0.011977529153227806, 0.011932261288166046, 0.011920000426471233, 0.011883101426064968, 0.011842883192002773, 0.011819617822766304, 0.011810989119112492, 0.011761271394789219, 0.011756575666368008, 0.011726364493370056, 0.011664840392768383, 0.011632450856268406, 0.01163020171225071, 0.01158054918050766, 0.011572858318686485, 0.011572858318686485, 0.011541635729372501, 0.011541635729372501, 0.011517350561916828, 0.011510960757732391, 0.011489875614643097, 0.011469963937997818, 0.011466587893664837, 0.011442759074270725, 0.011438597925007343, 0.011437375098466873, 0.011433145962655544, 0.011407203041017056, 0.011401104740798473, 0.01135423593223095, 0.011333385482430458, 0.011328010819852352, 0.011325137689709663, 0.01131997536867857, 0.011306566186249256, 0.011299673467874527, 0.011281789280474186, 0.011264320462942123, 0.011257764883339405, 0.011256475001573563, 0.011253912933170795, 0.011252702213823795, 0.011248898692429066, 0.011247594840824604, 0.011239985004067421, 0.01121864840388298, 0.011208567768335342, 0.011174682527780533, 0.011166973039507866, 0.011159253306686878, 0.011151333339512348, 0.011140624061226845, 0.011139076203107834, 0.01113345380872488, 0.011126152239739895, 0.011093570850789547, 0.011078842915594578, 0.011067545972764492, 0.011044573038816452, 0.01101986039429903, 0.011016158387064934, 0.011015082709491253, 0.011007890105247498, 0.010997296310961246, 0.010962157510221004, 0.01095755398273468, 0.010940180160105228, 0.01088095735758543, 0.010869039222598076, 0.010858545079827309, 0.010820968076586723, 0.01080892514437437, 0.010798529721796513, 0.01077410951256752, 0.010764310136437416, 0.010748079977929592, 0.010681436397135258, 0.010675576515495777, 0.010557047091424465, 0.010552684776484966, 0.010509641841053963, 0.010396942496299744, 0.01037551462650299, 0.01033466774970293, 0.010237698443233967, 0.009954877197742462]}\n"
167
  ]
168
  }
169
  ],
space/space/space/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ gradio
3
+ torch
space/space/space/space/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
space/space/space/space/.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ # Initialize the zero-shot classification pipeline
5
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-base")
6
+
7
+ # Define the classification function
8
+ def classify_text(text, labels):
9
+ labels = labels.split(",") # Convert the comma-separated string into a list
10
+ result = classifier(text, candidate_labels=labels)
11
+ return result
12
+
13
+ # Set up the Gradio interface
14
+ with gr.Blocks() as demo:
15
+ gr.Markdown("# Zero-Shot Classification")
16
+ text_input = gr.Textbox(label="Input Text")
17
+ label_input = gr.Textbox(label="Comma-separated Labels")
18
+ output = gr.JSON(label="Result")
19
+ classify_button = gr.Button("Classify")
20
+
21
+ # Link the button to the classification function
22
+ classify_button.click(classify_text, inputs=[text_input, label_input], outputs=output)
23
+
24
+ # Launch the Gradio interface
25
+ demo.launch()
space/space/space/space/.ipynb_checkpoints/gradio_tryout-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
space/space/space/space/.ipynb_checkpoints/packing_list_api-checkpoint.ipynb ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "7b73f12d-1104-4eea-ac08-3716aa9af45b",
6
+ "metadata": {},
7
+ "source": [
8
+ "**Zero shot classification**"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "05a29daa-b70e-4c7c-ba03-9ab641f424cb",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from dotenv import load_dotenv\n",
19
+ "import os\n",
20
+ "import requests\n",
21
+ "\n",
22
+ "load_dotenv() # Load environment variables from .env file, contains personal access token (HF_API_TOKEN=your_token)\n",
23
+ "\n",
24
+ "API_URL = \"https://api-inference.huggingface.co/models/facebook/bart-large-mnli\"\n",
25
+ "# API_URL = \"https://api-inference.huggingface.co/models/MoritzLaurer/mDeBERTa-v3-base-mnli-xnli\"\n",
26
+ "# API_URL = \"https://api-inference.huggingface.co/models/cross-encoder/nli-deberta-v3-base\"\n",
27
+ "# API_URL = \"https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3\"\n",
28
+ "headers = {\"Authorization\": f\"Bearer {os.getenv('HF_API_TOKEN')}\"}\n",
29
+ "\n",
30
+ "def query(payload):\n",
31
+ " response = requests.post(API_URL, headers=headers, json=payload)\n",
32
+ " return response.json()\n",
33
+ "\n"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 2,
39
+ "id": "21b4f8b6-e774-45ad-8054-bf5db2b7b07c",
40
+ "metadata": {},
41
+ "outputs": [
42
+ {
43
+ "name": "stdout",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "{'sequence': 'I just bought a new laptop, and it works amazing!', 'labels': ['technology', 'health', 'sports', 'politics'], 'scores': [0.9709171652793884, 0.014999167993664742, 0.008272457867860794, 0.005811102222651243]}\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "# Input text to classify\n",
52
+ "input_text = \"I just bought a new laptop, and it works amazing!\"\n",
53
+ "\n",
54
+ "# Candidate labels\n",
55
+ "candidate_labels = [\"technology\", \"sports\", \"politics\", \"health\"]\n",
56
+ "\n",
57
+ "# Get the prediction\n",
58
+ "output = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": candidate_labels}})\n",
59
+ "print(output)\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "id": "fb7e69c7-b590-4b40-8478-76d055583f2a",
65
+ "metadata": {},
66
+ "source": [
67
+ "**Try packing list labels**"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 25,
73
+ "id": "c5f75916-aaf2-4ca7-8d1a-070579940952",
74
+ "metadata": {},
75
+ "outputs": [
76
+ {
77
+ "name": "stdout",
78
+ "output_type": "stream",
79
+ "text": [
80
+ "{'error': ['Error in `parameters.candidate_labels`: ensure this value has at most 10 items']}\n"
81
+ ]
82
+ }
83
+ ],
84
+ "source": [
85
+ "# Input text to classify\n",
86
+ "input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
87
+ "\n",
88
+ "# Candidate labels\n",
89
+ "candidate_labels = [\n",
90
+ " \"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Beach towel\", \"Sunglasses\", \n",
91
+ " \"Waterproof phone case\", \"Hat\", \"Beach bag\", \"Snorkel gear\", \"Aloe vera gel\",\n",
92
+ " \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Flashlight\", \"Hiking boots\",\n",
93
+ " \"Water filter\", \"Compass\", \"First aid kit\", \"Bug spray\", \"Multi-tool\",\n",
94
+ " \"Thermal clothing\", \"Ski jacket\", \"Ski goggles\", \"Snow boots\", \"Gloves\",\n",
95
+ " \"Hand warmers\", \"Beanie\", \"Lip balm\", \"Snowboard\", \"Base layers\",\n",
96
+ " \"Passport\", \"Visa documents\", \"Travel adapter\", \"Currency\", \"Language phrasebook\",\n",
97
+ " \"SIM card\", \"Travel pillow\", \"Neck wallet\", \"Travel insurance documents\", \"Power bank\",\n",
98
+ " \"Laptop\", \"Notebook\", \"Business attire\", \"Dress shoes\", \"Charging cables\",\n",
99
+ " \"Presentation materials\", \"Work ID badge\", \"Pen\", \"Headphones\", \n",
100
+ " \"Lightweight backpack\", \"Travel-sized toiletries\", \"Packable rain jacket\",\n",
101
+ " \"Reusable water bottle\", \"Dry bag\", \"Trekking poles\", \"Hostel lock\", \"Quick-dry towel\",\n",
102
+ " \"Travel journal\", \"Energy bars\", \"Car charger\", \"Snacks\", \"Map\",\n",
103
+ " \"Sunglasses\", \"Cooler\", \"Blanket\", \"Emergency roadside kit\", \"Reusable coffee mug\",\n",
104
+ " \"Playlist\", \"Reusable shopping bags\", \"Earplugs\", \"Fanny pack\", \"Portable charger\",\n",
105
+ " \"Poncho\", \"Bandana\", \"Comfortable shoes\", \"Tent\", \"Refillable water bottle\",\n",
106
+ " \"Glow sticks\", \"Festival tickets\", \"Diapers\", \"Baby wipes\", \"Baby food\",\n",
107
+ " \"Stroller\", \"Pacifier\", \"Baby clothes\", \"Baby blanket\", \"Travel crib\",\n",
108
+ " \"Toys\", \"Nursing cover\"\n",
109
+ "]\n",
110
+ "\n",
111
+ "\n",
112
+ "# Get the prediction\n",
113
+ "output = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": candidate_labels}})\n",
114
+ "print(output)"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "markdown",
119
+ "id": "8a6318c1-fa5f-4d16-8507-eaebe6294ac0",
120
+ "metadata": {},
121
+ "source": [
122
+ "**Use batches of 10 labels and combine results**"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": 11,
128
+ "id": "fe42a222-5ff4-4442-93f4-42fc22001af6",
129
+ "metadata": {},
130
+ "outputs": [
131
+ {
132
+ "name": "stdout",
133
+ "output_type": "stream",
134
+ "text": [
135
+ "{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Flashlight', 'Gloves', 'Camping stove', 'Water filter', 'Sleeping bag'], 'scores': [0.30358555912971497, 0.12884855270385742, 0.10985139012336731, 0.10500500351190567, 0.10141848027706146, 0.08342219144105911, 0.0704946368932724, 0.05127469450235367, 0.024876652285456657, 0.021222807466983795]}\n",
136
+ "{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie', 'Ski goggles', 'Flip-flops', 'First aid kit', 'Sunscreen', 'Swimsuit', 'Lip balm'], 'scores': [0.20171622931957245, 0.1621972620487213, 0.12313881516456604, 0.10742709040641785, 0.09418268501758575, 0.08230196684598923, 0.07371978461742401, 0.06208840385079384, 0.05506424233317375, 0.038163457065820694]}\n",
137
+ "\n",
138
+ "Recommended packing list: ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie']\n"
139
+ ]
140
+ }
141
+ ],
142
+ "source": [
143
+ "\n",
144
+ "input_text = \"I'm going on a 2-week hiking trip in the Alps during winter.\"\n",
145
+ "\n",
146
+ "\n",
147
+ "# Define the full list of possible packing items (split into groups of 10)\n",
148
+ "candidate_labels = [\n",
149
+ " [\"Hiking boots\", \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Backpack\",\n",
150
+ " \"Water filter\", \"Flashlight\", \"Thermal clothing\", \"Gloves\", \"Map\"],\n",
151
+ " \n",
152
+ " [\"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Ski jacket\", \"Ski goggles\",\n",
153
+ " \"Snow boots\", \"Beanie\", \"Hand warmers\", \"Lip balm\", \"First aid kit\"]\n",
154
+ "]\n",
155
+ "\n",
156
+ "# Run classification in batches\n",
157
+ "packing_list = []\n",
158
+ "for batch in candidate_labels:\n",
159
+ " result = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": batch}})\n",
160
+ " print(result)\n",
161
+ " for label, score in zip(result[\"labels\"], result[\"scores\"]):\n",
162
+ " if score > 0.1: # Adjust threshold as needed\n",
163
+ " packing_list.append(label)\n",
164
+ "\n",
165
+ "# Print the final packing list\n",
166
+ "print(\"\\nRecommended packing list:\", packing_list)"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "id": "953b244c-0611-4706-a941-eac5064c643f",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": []
176
+ }
177
+ ],
178
+ "metadata": {
179
+ "kernelspec": {
180
+ "display_name": "Python (huggingface_env)",
181
+ "language": "python",
182
+ "name": "huggingface_env"
183
+ },
184
+ "language_info": {
185
+ "codemirror_mode": {
186
+ "name": "ipython",
187
+ "version": 3
188
+ },
189
+ "file_extension": ".py",
190
+ "mimetype": "text/x-python",
191
+ "name": "python",
192
+ "nbconvert_exporter": "python",
193
+ "pygments_lexer": "ipython3",
194
+ "version": "3.8.20"
195
+ }
196
+ },
197
+ "nbformat": 4,
198
+ "nbformat_minor": 5
199
+ }
space/space/space/space/app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ # Initialize the zero-shot classification pipeline
5
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-base")
6
+
7
+ # Define the classification function
8
+ def classify_text(text, labels):
9
+ labels = labels.split(",") # Convert the comma-separated string into a list
10
+ result = classifier(text, candidate_labels=labels)
11
+ return result
12
+
13
+ # Set up the Gradio interface
14
+ with gr.Blocks() as demo:
15
+ gr.Markdown("# Zero-Shot Classification")
16
+ text_input = gr.Textbox(label="Input Text")
17
+ label_input = gr.Textbox(label="Comma-separated Labels")
18
+ output = gr.JSON(label="Result")
19
+ classify_button = gr.Button("Classify")
20
+
21
+ # Link the button to the classification function
22
+ classify_button.click(classify_text, inputs=[text_input, label_input], outputs=output)
23
+
24
+ # Launch the Gradio interface
25
+ demo.launch()
space/space/space/space/gradio_tryout.ipynb ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "e25090fa-f990-4f1a-84f3-b12159eedae8",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Try out gradio"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "3bbee2e4-55c8-4b06-9929-72026edf7932",
14
+ "metadata": {},
15
+ "source": [
16
+ "Try model"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 6,
22
+ "id": "fa0d8126-e346-4412-9197-7d51baf868da",
23
+ "metadata": {
24
+ "scrolled": true
25
+ },
26
+ "outputs": [
27
+ {
28
+ "name": "stderr",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']\n",
32
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
33
+ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n",
34
+ "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.\n",
35
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
36
+ ]
37
+ },
38
+ {
39
+ "name": "stdout",
40
+ "output_type": "stream",
41
+ "text": [
42
+ "{'sequence': 'I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.', 'labels': ['Compass', 'Travel insurance documents', 'Cooler', 'Poncho', 'Comfortable shoes', 'Thermal clothing', 'Business attire', 'Stroller', 'Refillable water bottle', 'Sunscreen', 'Hiking boots', 'Trekking poles', 'Tent', 'Tent', 'Swimsuit', 'Lightweight backpack', 'Diapers', 'Pen', 'Lip balm', 'Bandana', 'Presentation materials', 'Snorkel gear', 'Sunglasses', 'Sunglasses', 'Snowboard', 'Baby wipes', 'Emergency roadside kit', 'Blanket', 'Passport', 'Aloe vera gel', 'Currency', 'Beanie', 'Hand warmers', 'Reusable shopping bags', 'Hat', 'Travel-sized toiletries', 'Waterproof phone case', 'Energy bars', 'Baby food', 'Reusable water bottle', 'Flashlight', 'Gloves', 'Baby clothes', 'Hostel lock', 'Visa documents', 'Camping stove', 'Bug spray', 'Packable rain jacket', 'Travel pillow', 'Power bank', 'Earplugs', 'Quick-dry towel', 'Reusable coffee mug', 'Travel journal', 'Fanny pack', 'Headphones', 'Notebook', 'Dress shoes', 'Nursing cover', 'Playlist', 'Base layers', 'Work ID badge', 'Festival tickets', 'Sleeping bag', 'Laptop', 'Baby blanket', 'Charging cables', 'Snow boots', 'First aid kit', 'Snacks', 'Flip-flops', 'Toys', 'Car charger', 'Ski jacket', 'Dry bag', 'Pacifier', 'Map', 'Portable charger', 'Travel crib', 'Multi-tool', 'Beach bag', 'Ski goggles', 'SIM card', 'Glow sticks', 'Beach towel', 'Travel adapter', 'Neck wallet', 'Language phrasebook', 'Water filter'], 'scores': [0.011984821408987045, 0.011970506981015205, 0.011933253146708012, 0.011915490962564945, 0.011904211714863777, 0.011892491020262241, 0.01188766211271286, 0.011866495944559574, 0.011842762120068073, 0.011789090000092983, 0.011770269833505154, 0.011769718490540981, 0.011746660806238651, 0.011746660806238651, 0.011718676425516605, 0.01164235919713974, 0.011551206931471825, 0.011529732495546341, 0.011518468149006367, 0.011516833677887917, 0.011508049443364143, 0.011507270857691765, 0.01149584911763668, 0.01149584911763668, 0.011495097540318966, 0.01149324607104063, 0.011486946605145931, 0.01148668210953474, 0.011478666216135025, 0.011473646387457848, 0.011412998661398888, 0.011398673988878727, 0.011378799565136433, 0.01135518029332161, 0.011335738934576511, 0.011330211535096169, 0.011329339817166328, 0.011324702762067318, 0.01131915021687746, 0.01131164189428091, 0.011294065974652767, 0.011273612268269062, 0.011272135190665722, 0.011252084746956825, 0.01122584380209446, 0.011216048151254654, 0.011204490438103676, 0.011203117668628693, 0.01117485947906971, 0.01117344293743372, 0.011145292781293392, 0.011137993074953556, 0.011128612793982029, 0.011123239994049072, 0.011122280731797218, 0.011065744794905186, 0.011053262278437614, 0.011045967228710651, 0.011041177436709404, 0.011033336631953716, 0.01102971937507391, 0.0110141197219491, 0.01100961398333311, 0.011002525687217712, 0.010937424376606941, 0.01093329582363367, 0.010918675921857357, 0.010917853564023972, 0.010890142060816288, 0.01088369358330965, 0.010871977545320988, 0.010870742611587048, 0.010863195173442364, 0.010844682343304157, 0.01084016915410757, 0.010835953988134861, 0.010834810324013233, 0.010826902464032173, 0.010796850547194481, 0.010746038518846035, 0.010692491196095943, 0.010686952620744705, 0.010679351165890694, 0.010655333288013935, 0.010604050010442734, 0.010574583895504475, 0.010439733043313026, 0.010402928106486797, 0.010294477455317974]}\n"
43
+ ]
44
+ }
45
+ ],
46
+ "source": [
47
+ "from transformers import pipeline\n",
48
+ "import gradio as gr\n",
49
+ "\n",
50
+ "# Load the model and create a pipeline for zero-shot classification\n",
51
+ "classifier = pipeline(\"zero-shot-classification\", model=\"facebook/bart-base\")\n",
52
+ "\n",
53
+ "# Load labels from a txt file\n",
54
+ "with open(\"labels.txt\", \"r\", encoding=\"utf-8\") as f:\n",
55
+ " class_labels = [line.strip() for line in f if line.strip()]\n",
56
+ "\n",
57
+ "# Example text to classify\n",
58
+ "input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
59
+ "\n",
60
+ "# Perform classification\n",
61
+ "result = classifier(input_text, class_labels)\n",
62
+ "\n",
63
+ "print(result)"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "id": "8e856a9c-a66c-4c4b-b7cf-8c52abbbc6fa",
69
+ "metadata": {},
70
+ "source": [
71
+ "Use model with gradio"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": 13,
77
+ "id": "521d9118-b59d-4cc6-b637-20202eaf8f33",
78
+ "metadata": {},
79
+ "outputs": [
80
+ {
81
+ "name": "stdout",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "Running on local URL: http://127.0.0.1:7866\n",
85
+ "\n",
86
+ "To create a public link, set `share=True` in `launch()`.\n"
87
+ ]
88
+ },
89
+ {
90
+ "data": {
91
+ "text/html": [
92
+ "<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
93
+ ],
94
+ "text/plain": [
95
+ "<IPython.core.display.HTML object>"
96
+ ]
97
+ },
98
+ "metadata": {},
99
+ "output_type": "display_data"
100
+ }
101
+ ],
102
+ "source": [
103
+ "# Define the Gradio interface\n",
104
+ "def classify(text):\n",
105
+ " return classifier(text, class_labels)\n",
106
+ "\n",
107
+ "demo = gr.Interface(\n",
108
+ " fn=classify,\n",
109
+ " inputs=\"text\",\n",
110
+ " outputs=\"json\",\n",
111
+ " title=\"Zero-Shot Classification\",\n",
112
+ " description=\"Enter a text describing your trip\",\n",
113
+ ")\n",
114
+ "\n",
115
+ "# Launch the Gradio app\n",
116
+ "if __name__ == \"__main__\":\n",
117
+ " demo.launch()"
118
+ ]
119
+ }
120
+ ],
121
+ "metadata": {
122
+ "kernelspec": {
123
+ "display_name": "Python (huggingface_env)",
124
+ "language": "python",
125
+ "name": "huggingface_env"
126
+ },
127
+ "language_info": {
128
+ "codemirror_mode": {
129
+ "name": "ipython",
130
+ "version": 3
131
+ },
132
+ "file_extension": ".py",
133
+ "mimetype": "text/x-python",
134
+ "name": "python",
135
+ "nbconvert_exporter": "python",
136
+ "pygments_lexer": "ipython3",
137
+ "version": "3.8.20"
138
+ }
139
+ },
140
+ "nbformat": 4,
141
+ "nbformat_minor": 5
142
+ }
space/space/space/space/labels.txt ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Swimsuit
2
+ Sunscreen
3
+ Flip-flops
4
+ Beach towel
5
+ Sunglasses
6
+ Waterproof phone case
7
+ Hat
8
+ Beach bag
9
+ Snorkel gear
10
+ Aloe vera gel
11
+ Tent
12
+ Sleeping bag
13
+ Camping stove
14
+ Flashlight
15
+ Hiking boots
16
+ Water filter
17
+ Compass
18
+ First aid kit
19
+ Bug spray
20
+ Multi-tool
21
+ Thermal clothing
22
+ Ski jacket
23
+ Ski goggles
24
+ Snow boots
25
+ Gloves
26
+ Hand warmers
27
+ Beanie
28
+ Lip balm
29
+ Snowboard
30
+ Base layers
31
+ Passport
32
+ Visa documents
33
+ Travel adapter
34
+ Currency
35
+ Language phrasebook
36
+ SIM card
37
+ Travel pillow
38
+ Neck wallet
39
+ Travel insurance documents
40
+ Power bank
41
+ Laptop
42
+ Notebook
43
+ Business attire
44
+ Dress shoes
45
+ Charging cables
46
+ Presentation materials
47
+ Work ID badge
48
+ Pen
49
+ Headphones
50
+ Lightweight backpack
51
+ Travel-sized toiletries
52
+ Packable rain jacket
53
+ Reusable water bottle
54
+ Dry bag
55
+ Trekking poles
56
+ Hostel lock
57
+ Quick-dry towel
58
+ Travel journal
59
+ Energy bars
60
+ Car charger
61
+ Snacks
62
+ Map
63
+ Sunglasses
64
+ Cooler
65
+ Blanket
66
+ Emergency roadside kit
67
+ Reusable coffee mug
68
+ Playlist
69
+ Reusable shopping bags
70
+ Earplugs
71
+ Fanny pack
72
+ Portable charger
73
+ Poncho
74
+ Bandana
75
+ Comfortable shoes
76
+ Tent
77
+ Refillable water bottle
78
+ Glow sticks
79
+ Festival tickets
80
+ Diapers
81
+ Baby wipes
82
+ Baby food
83
+ Stroller
84
+ Pacifier
85
+ Baby clothes
86
+ Baby blanket
87
+ Travel crib
88
+ Toys
89
+ Nursing cover
space/space/space/space/packing_list_api.ipynb ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "7b73f12d-1104-4eea-ac08-3716aa9af45b",
6
+ "metadata": {},
7
+ "source": [
8
+ "**Zero shot classification**"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 12,
14
+ "id": "05a29daa-b70e-4c7c-ba03-9ab641f424cb",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from dotenv import load_dotenv\n",
19
+ "import os\n",
20
+ "import requests\n",
21
+ "\n",
22
+ "load_dotenv() # Load environment variables from .env file, contains personal access token (HF_API_TOKEN=your_token)\n",
23
+ "\n",
24
+ "API_URL = \"https://api-inference.huggingface.co/models/facebook/bart-large-mnli\"\n",
25
+ "# API_URL = \"https://api-inference.huggingface.co/models/MoritzLaurer/mDeBERTa-v3-base-mnli-xnli\"\n",
26
+ "# API_URL = \"https://api-inference.huggingface.co/models/cross-encoder/nli-deberta-v3-base\"\n",
27
+ "# API_URL = \"https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3\"\n",
28
+ "headers = {\"Authorization\": f\"Bearer {os.getenv('HF_API_TOKEN')}\"}\n",
29
+ "\n",
30
+ "def query(payload):\n",
31
+ " response = requests.post(API_URL, headers=headers, json=payload)\n",
32
+ " return response.json()\n",
33
+ "\n"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 13,
39
+ "id": "21b4f8b6-e774-45ad-8054-bf5db2b7b07c",
40
+ "metadata": {},
41
+ "outputs": [
42
+ {
43
+ "name": "stdout",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "{'sequence': 'I just bought a new laptop, and it works amazing!', 'labels': ['technology', 'health', 'sports', 'politics'], 'scores': [0.9709171652793884, 0.014999167993664742, 0.008272457867860794, 0.005811102222651243]}\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "# Input text to classify\n",
52
+ "input_text = \"I just bought a new laptop, and it works amazing!\"\n",
53
+ "\n",
54
+ "# Candidate labels\n",
55
+ "candidate_labels = [\"technology\", \"sports\", \"politics\", \"health\"]\n",
56
+ "\n",
57
+ "# Get the prediction\n",
58
+ "output = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": candidate_labels}})\n",
59
+ "print(output)\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "id": "fb7e69c7-b590-4b40-8478-76d055583f2a",
65
+ "metadata": {},
66
+ "source": [
67
+ "**Try packing list labels**"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 14,
73
+ "id": "c5f75916-aaf2-4ca7-8d1a-070579940952",
74
+ "metadata": {},
75
+ "outputs": [
76
+ {
77
+ "name": "stdout",
78
+ "output_type": "stream",
79
+ "text": [
80
+ "{'error': ['Error in `parameters.candidate_labels`: ensure this value has at most 10 items']}\n"
81
+ ]
82
+ }
83
+ ],
84
+ "source": [
85
+ "# Input text to classify\n",
86
+ "input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
87
+ "\n",
88
+ "# Candidate labels\n",
89
+ "candidate_labels = [\n",
90
+ " \"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Beach towel\", \"Sunglasses\", \n",
91
+ " \"Waterproof phone case\", \"Hat\", \"Beach bag\", \"Snorkel gear\", \"Aloe vera gel\",\n",
92
+ " \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Flashlight\", \"Hiking boots\",\n",
93
+ " \"Water filter\", \"Compass\", \"First aid kit\", \"Bug spray\", \"Multi-tool\",\n",
94
+ " \"Thermal clothing\", \"Ski jacket\", \"Ski goggles\", \"Snow boots\", \"Gloves\",\n",
95
+ " \"Hand warmers\", \"Beanie\", \"Lip balm\", \"Snowboard\", \"Base layers\",\n",
96
+ " \"Passport\", \"Visa documents\", \"Travel adapter\", \"Currency\", \"Language phrasebook\",\n",
97
+ " \"SIM card\", \"Travel pillow\", \"Neck wallet\", \"Travel insurance documents\", \"Power bank\",\n",
98
+ " \"Laptop\", \"Notebook\", \"Business attire\", \"Dress shoes\", \"Charging cables\",\n",
99
+ " \"Presentation materials\", \"Work ID badge\", \"Pen\", \"Headphones\", \n",
100
+ " \"Lightweight backpack\", \"Travel-sized toiletries\", \"Packable rain jacket\",\n",
101
+ " \"Reusable water bottle\", \"Dry bag\", \"Trekking poles\", \"Hostel lock\", \"Quick-dry towel\",\n",
102
+ " \"Travel journal\", \"Energy bars\", \"Car charger\", \"Snacks\", \"Map\",\n",
103
+ " \"Sunglasses\", \"Cooler\", \"Blanket\", \"Emergency roadside kit\", \"Reusable coffee mug\",\n",
104
+ " \"Playlist\", \"Reusable shopping bags\", \"Earplugs\", \"Fanny pack\", \"Portable charger\",\n",
105
+ " \"Poncho\", \"Bandana\", \"Comfortable shoes\", \"Tent\", \"Refillable water bottle\",\n",
106
+ " \"Glow sticks\", \"Festival tickets\", \"Diapers\", \"Baby wipes\", \"Baby food\",\n",
107
+ " \"Stroller\", \"Pacifier\", \"Baby clothes\", \"Baby blanket\", \"Travel crib\",\n",
108
+ " \"Toys\", \"Nursing cover\"\n",
109
+ "]\n",
110
+ "\n",
111
+ "\n",
112
+ "# Get the prediction\n",
113
+ "output = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": candidate_labels}})\n",
114
+ "print(output)"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "markdown",
119
+ "id": "8a6318c1-fa5f-4d16-8507-eaebe6294ac0",
120
+ "metadata": {},
121
+ "source": [
122
+ "**Use batches of 10 labels and combine results**"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": 16,
128
+ "id": "fe42a222-5ff4-4442-93f4-42fc22001af6",
129
+ "metadata": {},
130
+ "outputs": [
131
+ {
132
+ "name": "stdout",
133
+ "output_type": "stream",
134
+ "text": [
135
+ "{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Flashlight', 'Gloves', 'Camping stove', 'Water filter', 'Sleeping bag'], 'scores': [0.30358555912971497, 0.12884855270385742, 0.10985139012336731, 0.10500500351190567, 0.10141848027706146, 0.08342219144105911, 0.0704946368932724, 0.05127469450235367, 0.024876652285456657, 0.021222807466983795]}\n",
136
+ "{'sequence': \"I'm going on a 2-week hiking trip in the Alps during winter.\", 'labels': ['Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie', 'Ski goggles', 'Flip-flops', 'First aid kit', 'Sunscreen', 'Swimsuit', 'Lip balm'], 'scores': [0.20171622931957245, 0.1621972620487213, 0.12313881516456604, 0.10742709040641785, 0.09418268501758575, 0.08230196684598923, 0.07371978461742401, 0.06208840385079384, 0.05506424233317375, 0.038163457065820694]}\n",
137
+ "\n",
138
+ "Recommended packing list: ['Map', 'Backpack', 'Tent', 'Thermal clothing', 'Hiking boots', 'Ski jacket', 'Snow boots', 'Hand warmers', 'Beanie']\n"
139
+ ]
140
+ }
141
+ ],
142
+ "source": [
143
+ "\n",
144
+ "input_text = \"I'm going on a 2-week hiking trip in the Alps during winter.\"\n",
145
+ "\n",
146
+ "\n",
147
+ "# Define the full list of possible packing items (split into groups of 10)\n",
148
+ "candidate_labels = [\n",
149
+ " [\"Hiking boots\", \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Backpack\",\n",
150
+ " \"Water filter\", \"Flashlight\", \"Thermal clothing\", \"Gloves\", \"Map\"],\n",
151
+ " \n",
152
+ " [\"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Ski jacket\", \"Ski goggles\",\n",
153
+ " \"Snow boots\", \"Beanie\", \"Hand warmers\", \"Lip balm\", \"First aid kit\"]\n",
154
+ "]\n",
155
+ "\n",
156
+ "# Run classification in batches\n",
157
+ "packing_list = []\n",
158
+ "for batch in candidate_labels:\n",
159
+ " result = query({\"inputs\": input_text, \"parameters\": {\"candidate_labels\": batch}})\n",
160
+ " print(result)\n",
161
+ " for label, score in zip(result[\"labels\"], result[\"scores\"]):\n",
162
+ " if score > 0.1: # Adjust threshold as needed\n",
163
+ " packing_list.append(label)\n",
164
+ "\n",
165
+ "# Print the final packing list\n",
166
+ "print(\"\\nRecommended packing list:\", packing_list)"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "id": "660072ea-b72f-4bee-a9ed-81019775ae85",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": []
176
+ },
177
+ {
178
+ "cell_type": "markdown",
179
+ "id": "edf44387-d166-4e0f-a8ad-621230aee115",
180
+ "metadata": {},
181
+ "source": [
182
+ "**Try to run a model locally**"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": 1,
188
+ "id": "d0d8f7c0-c2d9-4fbe-b1a7-699a5b99466c",
189
+ "metadata": {},
190
+ "outputs": [
191
+ {
192
+ "data": {
193
+ "application/vnd.jupyter.widget-view+json": {
194
+ "model_id": "5e371dee58d64e7b8bf6635e0e88f8db",
195
+ "version_major": 2,
196
+ "version_minor": 0
197
+ },
198
+ "text/plain": [
199
+ "config.json: 0%| | 0.00/1.72k [00:00<?, ?B/s]"
200
+ ]
201
+ },
202
+ "metadata": {},
203
+ "output_type": "display_data"
204
+ },
205
+ {
206
+ "data": {
207
+ "application/vnd.jupyter.widget-view+json": {
208
+ "model_id": "d479e18a65314ad5927ea2bf7453db7c",
209
+ "version_major": 2,
210
+ "version_minor": 0
211
+ },
212
+ "text/plain": [
213
+ "model.safetensors: 0%| | 0.00/558M [00:00<?, ?B/s]"
214
+ ]
215
+ },
216
+ "metadata": {},
217
+ "output_type": "display_data"
218
+ },
219
+ {
220
+ "name": "stderr",
221
+ "output_type": "stream",
222
+ "text": [
223
+ "Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']\n",
224
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
225
+ ]
226
+ },
227
+ {
228
+ "data": {
229
+ "application/vnd.jupyter.widget-view+json": {
230
+ "model_id": "8dc911686edb4b15baa880ae657c163d",
231
+ "version_major": 2,
232
+ "version_minor": 0
233
+ },
234
+ "text/plain": [
235
+ "vocab.json: 0%| | 0.00/899k [00:00<?, ?B/s]"
236
+ ]
237
+ },
238
+ "metadata": {},
239
+ "output_type": "display_data"
240
+ },
241
+ {
242
+ "data": {
243
+ "application/vnd.jupyter.widget-view+json": {
244
+ "model_id": "e60a6df28292441bb5317ef80c9de795",
245
+ "version_major": 2,
246
+ "version_minor": 0
247
+ },
248
+ "text/plain": [
249
+ "merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]"
250
+ ]
251
+ },
252
+ "metadata": {},
253
+ "output_type": "display_data"
254
+ },
255
+ {
256
+ "data": {
257
+ "application/vnd.jupyter.widget-view+json": {
258
+ "model_id": "c7eaab50789b42a796d0deb3008f247e",
259
+ "version_major": 2,
260
+ "version_minor": 0
261
+ },
262
+ "text/plain": [
263
+ "tokenizer.json: 0%| | 0.00/1.36M [00:00<?, ?B/s]"
264
+ ]
265
+ },
266
+ "metadata": {},
267
+ "output_type": "display_data"
268
+ },
269
+ {
270
+ "name": "stderr",
271
+ "output_type": "stream",
272
+ "text": [
273
+ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n",
274
+ "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.\n"
275
+ ]
276
+ }
277
+ ],
278
+ "source": [
279
+ "from transformers import pipeline\n",
280
+ "\n",
281
+ "# Load the model and create a pipeline for zero-shot classification\n",
282
+ "classifier = pipeline(\"zero-shot-classification\", model=\"facebook/bart-base\")"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": 2,
288
+ "id": "4682d620-c9a6-40ad-ab4c-268ee0ef7212",
289
+ "metadata": {},
290
+ "outputs": [
291
+ {
292
+ "name": "stderr",
293
+ "output_type": "stream",
294
+ "text": [
295
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
296
+ ]
297
+ },
298
+ {
299
+ "name": "stdout",
300
+ "output_type": "stream",
301
+ "text": [
302
+ "{'sequence': 'I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.', 'labels': ['Travel-sized toiletries', 'Refillable water bottle', 'Aloe vera gel', 'Snorkel gear', 'Waterproof phone case', 'Packable rain jacket', 'Reusable shopping bags', 'Reusable coffee mug', 'Reusable water bottle', 'First aid kit', 'Travel insurance documents', 'Work ID badge', 'Lightweight backpack', 'Presentation materials', 'Flip-flops', 'Charging cables', 'Hiking boots', 'Comfortable shoes', 'Fanny pack', 'Trekking poles', 'Visa documents', 'Baby wipes', 'Quick-dry towel', 'Baby blanket', 'Hostel lock', 'Blanket', 'Business attire', 'Laptop', 'Beanie', 'Bug spray', 'Travel pillow', 'Baby clothes', 'Passport', 'Earplugs', 'Camping stove', 'Travel journal', 'Emergency roadside kit', 'Baby food', 'Pen', 'Bandana', 'Dress shoes', 'Snacks', 'Travel crib', 'Sunscreen', 'Ski goggles', 'Sunglasses', 'Sunglasses', 'Stroller', 'Lip balm', 'Notebook', 'Glow sticks', 'Cooler', 'Snowboard', 'Map', 'Thermal clothing', 'Neck wallet', 'Water filter', 'Travel adapter', 'Currency', 'Nursing cover', 'Snow boots', 'Pacifier', 'Sleeping bag', 'Car charger', 'Diapers', 'Flashlight', 'Ski jacket', 'Portable charger', 'Playlist', 'Swimsuit', 'Tent', 'Tent', 'SIM card', 'Compass', 'Multi-tool', 'Hat', 'Base layers', 'Energy bars', 'Toys', 'Power bank', 'Dry bag', 'Beach towel', 'Beach bag', 'Poncho', 'Headphones', 'Gloves', 'Festival tickets', 'Hand warmers', 'Language phrasebook'], 'scores': [0.014162097126245499, 0.013634984381496906, 0.013528786599636078, 0.013522890396416187, 0.013521893881261349, 0.013390542939305305, 0.013313423842191696, 0.01292099617421627, 0.01269496325403452, 0.01249685138463974, 0.012418625876307487, 0.012351310811936855, 0.012286719866096973, 0.012170663103461266, 0.01216645073145628, 0.012136084027588367, 0.012111806310713291, 0.01203493494540453, 0.011913969181478024, 0.011860690079629421, 0.01184084452688694, 0.011729727499186993, 0.0116303451359272, 0.011585962027311325, 0.011557267978787422, 0.011486714705824852, 0.011480122804641724, 0.011266479268670082, 0.011243777349591255, 0.011239712126553059, 0.011195540428161621, 0.011194570921361446, 0.01118150819092989, 0.011168110184371471, 0.011141857132315636, 0.01114004384726286, 0.011128030717372894, 0.0110848443582654, 0.01107991486787796, 0.01107126846909523, 0.011069754138588905, 0.011015287600457668, 0.01101327408105135, 0.010999458841979504, 0.010981021448969841, 0.010975920595228672, 0.010975920595228672, 0.010966054163873196, 0.010964509099721909, 0.01093060988932848, 0.010892837308347225, 0.010852692648768425, 0.010844447650015354, 0.010827522724866867, 0.010805405676364899, 0.010789167135953903, 0.010784591548144817, 0.010779209434986115, 0.010761956684291363, 0.010743752121925354, 0.010727204382419586, 0.010722712613642216, 0.010696588084101677, 0.01069594919681549, 0.010669016279280186, 0.010664715431630611, 0.010641842149198055, 0.01063066441565752, 0.010608346201479435, 0.010583184659481049, 0.010549037717282772, 0.010549037717282772, 0.010522513650357723, 0.010509520769119263, 0.010469724424183369, 0.010431424714624882, 0.010407780297100544, 0.010376540012657642, 0.01036670058965683, 0.010329049080610275, 0.010298855602741241, 0.01027328334748745, 0.010225902311503887, 0.010063442401587963, 0.01005304791033268, 0.010049044154584408, 0.009841262362897396, 0.009678435511887074, 0.009306504391133785]}\n"
303
+ ]
304
+ }
305
+ ],
306
+ "source": [
307
+ "input_text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
308
+ "\n",
309
+ "# Candidate labels\n",
310
+ "candidate_labels = [\n",
311
+ " \"Swimsuit\", \"Sunscreen\", \"Flip-flops\", \"Beach towel\", \"Sunglasses\", \n",
312
+ " \"Waterproof phone case\", \"Hat\", \"Beach bag\", \"Snorkel gear\", \"Aloe vera gel\",\n",
313
+ " \"Tent\", \"Sleeping bag\", \"Camping stove\", \"Flashlight\", \"Hiking boots\",\n",
314
+ " \"Water filter\", \"Compass\", \"First aid kit\", \"Bug spray\", \"Multi-tool\",\n",
315
+ " \"Thermal clothing\", \"Ski jacket\", \"Ski goggles\", \"Snow boots\", \"Gloves\",\n",
316
+ " \"Hand warmers\", \"Beanie\", \"Lip balm\", \"Snowboard\", \"Base layers\",\n",
317
+ " \"Passport\", \"Visa documents\", \"Travel adapter\", \"Currency\", \"Language phrasebook\",\n",
318
+ " \"SIM card\", \"Travel pillow\", \"Neck wallet\", \"Travel insurance documents\", \"Power bank\",\n",
319
+ " \"Laptop\", \"Notebook\", \"Business attire\", \"Dress shoes\", \"Charging cables\",\n",
320
+ " \"Presentation materials\", \"Work ID badge\", \"Pen\", \"Headphones\", \n",
321
+ " \"Lightweight backpack\", \"Travel-sized toiletries\", \"Packable rain jacket\",\n",
322
+ " \"Reusable water bottle\", \"Dry bag\", \"Trekking poles\", \"Hostel lock\", \"Quick-dry towel\",\n",
323
+ " \"Travel journal\", \"Energy bars\", \"Car charger\", \"Snacks\", \"Map\",\n",
324
+ " \"Sunglasses\", \"Cooler\", \"Blanket\", \"Emergency roadside kit\", \"Reusable coffee mug\",\n",
325
+ " \"Playlist\", \"Reusable shopping bags\", \"Earplugs\", \"Fanny pack\", \"Portable charger\",\n",
326
+ " \"Poncho\", \"Bandana\", \"Comfortable shoes\", \"Tent\", \"Refillable water bottle\",\n",
327
+ " \"Glow sticks\", \"Festival tickets\", \"Diapers\", \"Baby wipes\", \"Baby food\",\n",
328
+ " \"Stroller\", \"Pacifier\", \"Baby clothes\", \"Baby blanket\", \"Travel crib\",\n",
329
+ " \"Toys\", \"Nursing cover\"\n",
330
+ "]\n",
331
+ "\n",
332
+ "\n",
333
+ "# Run the classification\n",
334
+ "result = classifier(input_text, candidate_labels)\n",
335
+ "\n",
336
+ "# Print the result\n",
337
+ "print(result)"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": null,
343
+ "id": "a344a80f-7645-4c2c-b960-580aa0b345f6",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": []
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": null,
351
+ "id": "5eb705d6-c31c-406c-9739-ff45b66c7ca4",
352
+ "metadata": {},
353
+ "outputs": [],
354
+ "source": []
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": 12,
359
+ "id": "ee734de6-bbcb-427d-8987-ab41286f7907",
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "# Example text to classify\n",
364
+ "text = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july.\"\n",
365
+ "\n",
366
+ "# No prompt\n",
367
+ "no_prompt = text\n",
368
+ "no_result = classifier(no_prompt, candidate_labels)\n",
369
+ "\n",
370
+ "\n",
371
+ "# Simple prompt\n",
372
+ "simple_prompt = \"Classify the following text: \" + text\n",
373
+ "simple_result = classifier(simple_prompt, candidate_labels)\n",
374
+ "\n",
375
+ "# Primed prompt\n",
376
+ "primed_prompt = \"I like to cycle and I burn easily. I also love culture and like to post on social media about my food. I will go on a trip to italy in july. What are the most important things to pack for the trip?\"\n",
377
+ "primed_result = classifier(primed_prompt, candidate_labels)"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "execution_count": 13,
383
+ "id": "96deb877-b3b0-4048-9960-d8b0b0d56cd0",
384
+ "metadata": {},
385
+ "outputs": [
386
+ {
387
+ "name": "stdout",
388
+ "output_type": "stream",
389
+ "text": [
390
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
391
+ "| no_prompt | no_prompt | simple_prompt | simple_prompt | primed_prompt | primed_prompt |\n",
392
+ "+============================+=============+============================+=================+============================+=================+\n",
393
+ "| Travel-sized toiletries | 0.0141621 | Beanie | 0.0126489 | First aid kit | 0.0126422 |\n",
394
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
395
+ "| Refillable water bottle | 0.013635 | Baby wipes | 0.0125994 | Work ID badge | 0.0125781 |\n",
396
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
397
+ "| Aloe vera gel | 0.0135288 | Bandana | 0.0125701 | Travel insurance documents | 0.0125387 |\n",
398
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
399
+ "| Snorkel gear | 0.0135229 | Blanket | 0.0125266 | Business attire | 0.0124256 |\n",
400
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
401
+ "| Waterproof phone case | 0.0135219 | Sunglasses | 0.0123896 | Baby wipes | 0.012401 |\n",
402
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
403
+ "| Packable rain jacket | 0.0133905 | Sunglasses | 0.0123896 | Blanket | 0.0122619 |\n",
404
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
405
+ "| Reusable shopping bags | 0.0133134 | Laptop | 0.0123645 | Lightweight backpack | 0.0122291 |\n",
406
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
407
+ "| Reusable coffee mug | 0.012921 | Snacks | 0.0123038 | Sunglasses | 0.0121536 |\n",
408
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
409
+ "| Reusable water bottle | 0.012695 | Sunscreen | 0.0122985 | Sunglasses | 0.0121536 |\n",
410
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
411
+ "| First aid kit | 0.0124969 | Pen | 0.0122703 | Laptop | 0.0121034 |\n",
412
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
413
+ "| Travel insurance documents | 0.0124186 | Cooler | 0.0122299 | Passport | 0.0121023 |\n",
414
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
415
+ "| Work ID badge | 0.0123513 | Snowboard | 0.012205 | Beanie | 0.0120397 |\n",
416
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
417
+ "| Lightweight backpack | 0.0122867 | Passport | 0.0121188 | Baby clothes | 0.0120325 |\n",
418
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
419
+ "| Presentation materials | 0.0121707 | Visa documents | 0.0121176 | Snacks | 0.0119757 |\n",
420
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
421
+ "| Flip-flops | 0.0121665 | Swimsuit | 0.0120711 | Packable rain jacket | 0.011946 |\n",
422
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
423
+ "| Charging cables | 0.0121361 | Flashlight | 0.0120105 | Baby food | 0.0119228 |\n",
424
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
425
+ "| Hiking boots | 0.0121118 | Stroller | 0.0119368 | Baby blanket | 0.0118852 |\n",
426
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
427
+ "| Comfortable shoes | 0.0120349 | Map | 0.01193 | Dress shoes | 0.0118458 |\n",
428
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
429
+ "| Fanny pack | 0.011914 | First aid kit | 0.0119121 | Bug spray | 0.0118403 |\n",
430
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
431
+ "| Trekking poles | 0.0118607 | Notebook | 0.0118809 | Travel journal | 0.0118067 |\n",
432
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
433
+ "| Visa documents | 0.0118408 | Hat | 0.011833 | Travel pillow | 0.0118006 |\n",
434
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
435
+ "| Baby wipes | 0.0117297 | Currency | 0.0118279 | Visa documents | 0.0117734 |\n",
436
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
437
+ "| Quick-dry towel | 0.0116303 | Work ID badge | 0.0117867 | Emergency roadside kit | 0.0117412 |\n",
438
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
439
+ "| Baby blanket | 0.011586 | Travel insurance documents | 0.01168 | SIM card | 0.0117407 |\n",
440
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
441
+ "| Hostel lock | 0.0115573 | Business attire | 0.0116774 | Cooler | 0.0117317 |\n",
442
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
443
+ "| Blanket | 0.0114867 | Compass | 0.0116575 | Snowboard | 0.0117232 |\n",
444
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
445
+ "| Business attire | 0.0114801 | Playlist | 0.0116254 | Diapers | 0.0117056 |\n",
446
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
447
+ "| Laptop | 0.0112665 | Bug spray | 0.0115941 | Notebook | 0.011676 |\n",
448
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
449
+ "| Beanie | 0.0112438 | Tent | 0.0115531 | Bandana | 0.0116441 |\n",
450
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
451
+ "| Bug spray | 0.0112397 | Tent | 0.0115531 | Pen | 0.011614 |\n",
452
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
453
+ "| Travel pillow | 0.0111955 | Diapers | 0.0115231 | Flashlight | 0.011587 |\n",
454
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
455
+ "| Baby clothes | 0.0111946 | Travel journal | 0.0114808 | Playlist | 0.0115787 |\n",
456
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
457
+ "| Passport | 0.0111815 | Hiking boots | 0.0114734 | Sunscreen | 0.0115577 |\n",
458
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
459
+ "| Earplugs | 0.0111681 | Reusable shopping bags | 0.0114722 | Swimsuit | 0.0115468 |\n",
460
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
461
+ "| Camping stove | 0.0111419 | SIM card | 0.0114319 | Reusable coffee mug | 0.0115091 |\n",
462
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
463
+ "| Travel journal | 0.01114 | Toys | 0.0114257 | Trekking poles | 0.011476 |\n",
464
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
465
+ "| Emergency roadside kit | 0.011128 | Dress shoes | 0.0113439 | Sleeping bag | 0.0114472 |\n",
466
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
467
+ "| Baby food | 0.0110848 | Waterproof phone case | 0.0113438 | Hiking boots | 0.0114388 |\n",
468
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
469
+ "| Pen | 0.0110799 | Travel pillow | 0.0113271 | Snorkel gear | 0.0114219 |\n",
470
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
471
+ "| Bandana | 0.0110713 | Refillable water bottle | 0.0113269 | Reusable shopping bags | 0.0113664 |\n",
472
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
473
+ "| Dress shoes | 0.0110698 | Fanny pack | 0.0113193 | Portable charger | 0.0113632 |\n",
474
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
475
+ "| Snacks | 0.0110153 | Baby blanket | 0.0113175 | Fanny pack | 0.011333 |\n",
476
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
477
+ "| Travel crib | 0.0110133 | Aloe vera gel | 0.0113123 | Headphones | 0.0113156 |\n",
478
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
479
+ "| Sunscreen | 0.0109995 | Snorkel gear | 0.011283 | Currency | 0.0112893 |\n",
480
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
481
+ "| Ski goggles | 0.010981 | Pacifier | 0.0112826 | Travel adapter | 0.0112652 |\n",
482
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
483
+ "| Sunglasses | 0.0109759 | Headphones | 0.0112543 | Travel crib | 0.011224 |\n",
484
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
485
+ "| Sunglasses | 0.0109759 | Packable rain jacket | 0.0112416 | Presentation materials | 0.0112228 |\n",
486
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
487
+ "| Stroller | 0.0109661 | Poncho | 0.0112411 | Waterproof phone case | 0.0112181 |\n",
488
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
489
+ "| Lip balm | 0.0109645 | Nursing cover | 0.0112323 | Nursing cover | 0.0111811 |\n",
490
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
491
+ "| Notebook | 0.0109306 | Comfortable shoes | 0.0112138 | Beach bag | 0.0111739 |\n",
492
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
493
+ "| Glow sticks | 0.0108928 | Reusable coffee mug | 0.0112081 | Stroller | 0.0111447 |\n",
494
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
495
+ "| Cooler | 0.0108527 | Travel crib | 0.0111724 | Car charger | 0.0110935 |\n",
496
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
497
+ "| Snowboard | 0.0108444 | Baby clothes | 0.0111683 | Neck wallet | 0.0110586 |\n",
498
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
499
+ "| Map | 0.0108275 | Presentation materials | 0.0111555 | Lip balm | 0.0110534 |\n",
500
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
501
+ "| Thermal clothing | 0.0108054 | Baby food | 0.0111165 | Comfortable shoes | 0.0110398 |\n",
502
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
503
+ "| Neck wallet | 0.0107892 | Sleeping bag | 0.0110978 | Poncho | 0.0109919 |\n",
504
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
505
+ "| Water filter | 0.0107846 | Lightweight backpack | 0.011038 | Reusable water bottle | 0.0109792 |\n",
506
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
507
+ "| Travel adapter | 0.0107792 | Gloves | 0.010946 | Energy bars | 0.0109684 |\n",
508
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
509
+ "| Currency | 0.010762 | Portable charger | 0.0108962 | Map | 0.0109623 |\n",
510
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
511
+ "| Nursing cover | 0.0107438 | Trekking poles | 0.0108781 | Hostel lock | 0.0109603 |\n",
512
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
513
+ "| Snow boots | 0.0107272 | Charging cables | 0.0108504 | Power bank | 0.0109483 |\n",
514
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
515
+ "| Pacifier | 0.0107227 | Reusable water bottle | 0.0108255 | Thermal clothing | 0.0109311 |\n",
516
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
517
+ "| Sleeping bag | 0.0106966 | Neck wallet | 0.0108161 | Earplugs | 0.0109061 |\n",
518
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
519
+ "| Car charger | 0.0106959 | Beach bag | 0.0108042 | Charging cables | 0.0108819 |\n",
520
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
521
+ "| Diapers | 0.010669 | Travel-sized toiletries | 0.0107921 | Toys | 0.0108427 |\n",
522
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
523
+ "| Flashlight | 0.0106647 | Travel adapter | 0.0107415 | Ski jacket | 0.0108272 |\n",
524
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
525
+ "| Ski jacket | 0.0106418 | Hostel lock | 0.0106021 | Base layers | 0.0107343 |\n",
526
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
527
+ "| Portable charger | 0.0106307 | Thermal clothing | 0.0105911 | Glow sticks | 0.0106845 |\n",
528
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
529
+ "| Playlist | 0.0106083 | Car charger | 0.0105783 | Beach towel | 0.010634 |\n",
530
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
531
+ "| Swimsuit | 0.0105832 | Ski goggles | 0.0105752 | Water filter | 0.0106173 |\n",
532
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
533
+ "| Tent | 0.010549 | Ski jacket | 0.0105524 | Festival tickets | 0.0106124 |\n",
534
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
535
+ "| Tent | 0.010549 | Water filter | 0.010523 | Dry bag | 0.0105999 |\n",
536
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
537
+ "| SIM card | 0.0105225 | Festival tickets | 0.0105077 | Hat | 0.010555 |\n",
538
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
539
+ "| Compass | 0.0105095 | Dry bag | 0.0104999 | Tent | 0.0105432 |\n",
540
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
541
+ "| Multi-tool | 0.0104697 | Glow sticks | 0.0104861 | Tent | 0.0105432 |\n",
542
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
543
+ "| Hat | 0.0104314 | Beach towel | 0.0104595 | Refillable water bottle | 0.0105226 |\n",
544
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
545
+ "| Base layers | 0.0104078 | Earplugs | 0.0104484 | Language phrasebook | 0.0104878 |\n",
546
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
547
+ "| Energy bars | 0.0103765 | Emergency roadside kit | 0.01042 | Aloe vera gel | 0.0104495 |\n",
548
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
549
+ "| Toys | 0.0103667 | Energy bars | 0.010328 | Compass | 0.0102844 |\n",
550
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
551
+ "| Power bank | 0.010329 | Flip-flops | 0.010279 | Pacifier | 0.0102553 |\n",
552
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
553
+ "| Dry bag | 0.0102989 | Power bank | 0.0102667 | Flip-flops | 0.0102396 |\n",
554
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
555
+ "| Beach towel | 0.0102733 | Base layers | 0.0102346 | Ski goggles | 0.010229 |\n",
556
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
557
+ "| Beach bag | 0.0102259 | Multi-tool | 0.0101584 | Multi-tool | 0.0100441 |\n",
558
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
559
+ "| Poncho | 0.0100634 | Lip balm | 0.0101392 | Gloves | 0.0100095 |\n",
560
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
561
+ "| Headphones | 0.010053 | Snow boots | 0.0101161 | Hand warmers | 0.00999101 |\n",
562
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
563
+ "| Gloves | 0.010049 | Camping stove | 0.00999308 | Camping stove | 0.00982307 |\n",
564
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
565
+ "| Festival tickets | 0.00984126 | Language phrasebook | 0.00958238 | Travel-sized toiletries | 0.0097547 |\n",
566
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
567
+ "| Hand warmers | 0.00967844 | Quick-dry towel | 0.00957849 | Snow boots | 0.00964112 |\n",
568
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n",
569
+ "| Language phrasebook | 0.0093065 | Hand warmers | 0.00916433 | Quick-dry towel | 0.00960495 |\n",
570
+ "+----------------------------+-------------+----------------------------+-----------------+----------------------------+-----------------+\n"
571
+ ]
572
+ }
573
+ ],
574
+ "source": [
575
+ "from tabulate import tabulate\n",
576
+ "\n",
577
+ "\n",
578
+ "# Creating a table\n",
579
+ "table = zip(no_result[\"labels\"], no_result[\"scores\"], \n",
580
+ " simple_result[\"labels\"], simple_result[\"scores\"], \n",
581
+ " primed_result[\"labels\"], primed_result[\"scores\"])\n",
582
+ "headers = [\"no_prompt\", \"no_prompt\", \"simple_prompt\", \"simple_prompt\", \"primed_prompt\", \"primed_prompt\"]\n",
583
+ "\n",
584
+ "print(tabulate(table, headers=headers, tablefmt=\"grid\"))\n"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": null,
590
+ "id": "5ed9bda0-41f2-4c7c-b055-27c1998c1d4e",
591
+ "metadata": {},
592
+ "outputs": [],
593
+ "source": []
594
+ }
595
+ ],
596
+ "metadata": {
597
+ "kernelspec": {
598
+ "display_name": "Python (huggingface_env)",
599
+ "language": "python",
600
+ "name": "huggingface_env"
601
+ },
602
+ "language_info": {
603
+ "codemirror_mode": {
604
+ "name": "ipython",
605
+ "version": 3
606
+ },
607
+ "file_extension": ".py",
608
+ "mimetype": "text/x-python",
609
+ "name": "python",
610
+ "nbconvert_exporter": "python",
611
+ "pygments_lexer": "ipython3",
612
+ "version": "3.8.20"
613
+ }
614
+ },
615
+ "nbformat": 4,
616
+ "nbformat_minor": 5
617
+ }
space/space/space/space/space/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
space/space/space/space/space/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Packing List
3
+ emoji: 📚
4
+ colorFrom: yellow
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.20.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Receives a trip description and returns a packing list
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
test_data.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "7-Day Island Beach Holiday in Greece (Summer). I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.": [
3
+ "bathing suit", "beach towel", "beach bag", "sandals", "comfortable walking shoes", "light jacket", "sunscreen", "sunglasses", "sunhat", "entertainment for downtime (e.g. book/ebook, games, laptop, journal)", "short pants/skirts", "t-shirts/tops"
4
+ ],
5
+ "3-Day City Trip to Vienna (September). We are a couple in our thirties traveling to Vienna for a three-day city trip. We’ll be staying at a friend’s house and plan to explore the city by sightseeing, strolling through the streets, visiting markets, and trying out great restaurants and cafés. We also hope to attend a classical music concert. Our journey to Vienna will be by train.": [
6
+ "light rain jacket", "comfortable walking shoes", "thin scarf",
7
+ "1 pants", "2 shirts", "1 cardigan/sweater"
8
+ ],
9
+ "8-Day Christmas Trip to Family in the Netherlands and Germany. My partner and I are traveling to the Netherlands and Germany to spend Christmas with our family. We are in our late twenties and will start our journey with a two-hour flight to the Netherlands. From there, we will take a 5.5-hour train ride to northern Germany.": [
10
+ "light rain jacket", "midlayer insulated jacket", "scarf", "hat", "winter shoes", "2 pants", "5 shirts", "2-3 cardigan/sweater", "gifts", "entertainment for downtime (e.g. book/ebook, games, laptop, journal)"
11
+
12
+ ],
13
+ "3-Week Adventure in Peru. I’m in my twenties and will be traveling to Peru for three weeks. I’m going solo but will meet up with a friend to explore the Sacred Valley and take part in a Machu Picchu tour. We plan to hike, go rafting, and explore the remnants of the ancient Inca Empire. We’re also excited to try Peruvian cuisine and immerse ourselves in the local culture. Depending on our plans, we might also visit the rainforest region, such as Tarapoto. I’ll be flying to Peru on a long-haul flight and will be traveling in August.": [
14
+ "light rain jacket", "2 midlayer insulated jackets", "hiking shoes", "comfortable walking shoes", "warm socks", "2-3 cardigans/sweaters", "6 shirts", "3 long pants", "2 short pants", "swimming suit", "travel towel", "malaria medication", "mosquito repellant", "sunscreen", "sunglasses", "sunhat", "local currency"
15
+
16
+ ],
17
+ "We’re planning a 10-day trip to Austria in the summer, combining hiking with relaxation by the lake. We love exploring scenic trails and enjoying the outdoors, but we also want to unwind and swim in the lake. It’s the perfect mix of adventure and relaxation.": [
18
+ "swimming suit", "travel towel", "sunprotection: sunscreen", "sunhat", "sunglasses", "downtime entertainment (e.g. book/ebook, games, laptop, journal)",
19
+ "7 shirts", "3 pants", "light rain jacket", "light jacket", "hiking shoes", "comfortable walking shoes"
20
+
21
+ ]
22
+ }