File size: 85,181 Bytes
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
7be2c11
 
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
7be2c11
 
6e82d28
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e82d28
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
7be2c11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e82d28
 
7be2c11
 
 
 
6e82d28
 
 
7be2c11
 
 
 
6e82d28
 
 
7be2c11
 
 
 
6e82d28
 
 
7be2c11
 
 
 
6e82d28
 
 
7be2c11
 
 
 
6e82d28
 
 
7be2c11
 
 
 
6e82d28
 
 
7be2c11
 
 
 
6e82d28
 
 
7be2c11
 
 
 
6e82d28
 
 
7be2c11
 
 
 
6e82d28
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
7be2c11
 
 
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
 
 
 
 
 
 
 
 
 
 
 
 
6e82d28
 
 
7be2c11
6e82d28
 
7be2c11
 
6e82d28
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
7be2c11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e82d28
 
 
 
 
7be2c11
 
 
 
 
 
6e82d28
7be2c11
 
 
6e82d28
7be2c11
 
 
 
 
 
 
6e82d28
7be2c11
 
 
6e82d28
 
 
 
7be2c11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e82d28
 
 
 
 
 
 
 
 
 
7be2c11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e82d28
 
 
 
 
 
7be2c11
 
6e82d28
 
 
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
7be2c11
6e82d28
 
7be2c11
6e82d28
 
 
 
7be2c11
 
6e82d28
7be2c11
 
 
 
6e82d28
 
7be2c11
 
6e82d28
 
7be2c11
6e82d28
 
 
 
 
 
7be2c11
 
 
 
 
6e82d28
7be2c11
6e82d28
 
 
 
 
7be2c11
6e82d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e25090fa-f990-4f1a-84f3-b12159eedae8",
   "metadata": {},
   "source": [
    "# Implement and test prediction model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bbee2e4-55c8-4b06-9929-72026edf7932",
   "metadata": {},
   "source": [
    "## Prerequisites"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f8c28d2d-8458-49fd-8ebf-5e729d6e861f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prerequisites\n",
    "from tabulate import tabulate\n",
    "from transformers import pipeline\n",
    "import json\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import os\n",
    "import time\n",
    "import math\n",
    "\n",
    "# get candidate labels\n",
    "with open(\"packing_label_structure.json\", \"r\") as file:\n",
    "    candidate_labels = json.load(file)\n",
    "keys_list = list(candidate_labels.keys())\n",
    "\n",
    "# Load test data (in list of dictionaries)\n",
    "with open(\"test_data.json\", \"r\") as file:\n",
    "    packing_data = json.load(file)\n",
    "# Extract all trip descriptions and trip_types\n",
    "trip_descriptions = [trip['description'] for trip in packing_data]\n",
    "trip_types = [trip['trip_types'] for trip in packing_data]\n",
    "\n",
    "# Access the first trip description\n",
    "first_trip = trip_descriptions[0]\n",
    "# Get the trip types for the first trip\n",
    "first_trip_type = trip_types[0]\n",
    "\n",
    "# print(f\"First trip: {first_trip} \\n\")\n",
    "# print(f\"Trip type: {first_trip_type}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cf4f76f-0035-44e8-93af-52eafaec686e",
   "metadata": {},
   "source": [
    "**All trip descriptions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "89d42ca7-e871-4eda-b428-69e9bd965428",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 . I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands. \n",
      "\n",
      "beach vacation\n",
      "['swimming', 'going to the beach', 'relaxing', 'hiking']\n",
      "warm destination / summer\n",
      "lightweight (but comfortable)\n",
      "casual\n",
      "indoor\n",
      "no own vehicle\n",
      "no special conditions to consider\n",
      "7+ days\n",
      "\n",
      "\n",
      "1 . We are a couple in our thirties traveling to Vienna for a three-day city trip. We’ll be staying at a friend’s house and plan to explore the city by sightseeing, strolling through the streets, visiting markets, and trying out great restaurants and cafés. We also hope to attend a classical music concert. Our journey to Vienna will be by train. \n",
      "\n",
      "city trip\n",
      "['sightseeing']\n",
      "variable weather / spring / autumn\n",
      "luxury (including evening wear)\n",
      "casual\n",
      "indoor\n",
      "no own vehicle\n",
      "no special conditions to consider\n",
      "3 days\n",
      "\n",
      "\n",
      "2 . My partner and I are traveling to the Netherlands and Germany to spend Christmas with our family. We are in our late twenties and will start our journey with a two-hour flight to the Netherlands. From there, we will take a 5.5-hour train ride to northern Germany. \n",
      "\n",
      "city trip\n",
      "['relaxing']\n",
      "cold destination / winter\n",
      "lightweight (but comfortable)\n",
      "casual\n",
      "indoor\n",
      "no own vehicle\n",
      "no special conditions to consider\n",
      "7+ days\n",
      "\n",
      "\n",
      "3 . I’m in my twenties and will be traveling to Peru for three weeks. I’m going solo but will meet up with a friend to explore the Sacred Valley and take part in a Machu Picchu tour. We plan to hike, go rafting, and explore the remnants of the ancient Inca Empire. We’re also excited to try Peruvian cuisine and immerse ourselves in the local culture. Depending on our plans, we might also visit the rainforest region, such as Tarapoto. I’ll be flying to Peru on a long-haul flight and will be traveling in August. \n",
      "\n",
      "cultural exploration\n",
      "['sightseeing', 'hiking', 'rafting']\n",
      "variable weather / spring / autumn\n",
      "lightweight (but comfortable)\n",
      "casual\n",
      "indoor\n",
      "no own vehicle\n",
      "rainy climate\n",
      "7+ days\n",
      "\n",
      "\n",
      "4 . We’re planning a 10-day trip to Austria in the summer, combining hiking with relaxation by the lake. We love exploring scenic trails and enjoying the outdoors, but we also want to unwind and swim in the lake. It’s the perfect mix of adventure and relaxation. \n",
      "\n",
      "nature escape\n",
      "['swimming', 'relaxing', 'hiking']\n",
      "warm destination / summer\n",
      "lightweight (but comfortable)\n",
      "casual\n",
      "indoor\n",
      "no own vehicle\n",
      "no special conditions to consider\n",
      "7+ days\n",
      "\n",
      "\n",
      "5 . I am going on a multiple day hike and passing though mountains and the beach in Croatia. I like to pack light and will stay in refugios/huts with half board and travel to the start of the hike by car. It will be 6-7 days. \n",
      "\n",
      "long-distance hike / thru-hike\n",
      "['going to the beach']\n",
      "tropical / humid\n",
      "minimalist\n",
      "casual\n",
      "huts with half board\n",
      "own vehicle\n",
      "off-grid / no electricity\n",
      "6 days\n",
      "\n",
      "\n",
      "6 . I will go with a friend on a beach holiday and we will do stand-up paddling, and surfing in the North of Spain. The destination is windy and can get cold, but is generally sunny. We will go by car and bring a tent to sleep in. It will be two weeks. \n",
      "\n",
      "beach vacation\n",
      "['stand-up paddleboarding (SUP)', 'surfing']\n",
      "cold destination / winter\n",
      "ultralight\n",
      "casual\n",
      "sleeping in a tent\n",
      "own vehicle\n",
      "off-grid / no electricity\n",
      "6 days\n",
      "\n",
      "\n",
      "7 . We will go to Sweden in the winter, to go for a yoga and sauna/wellness retreat. I prefer lightweight packing and also want clothes to go for fancy dinners and maybe on a winter hike. We stay in hotels. \n",
      "\n",
      "yoga / wellness retreat\n",
      "['hiking', 'yoga']\n",
      "cold destination / winter\n",
      "lightweight (but comfortable)\n",
      "casual\n",
      "indoor\n",
      "no own vehicle\n",
      "snow and ice\n",
      "7 days\n",
      "\n",
      "\n",
      "8 . I go on a skitouring trip where we also make videos/photos and the destination is Japan. Mainly sports clothes and isolation are needed (it is winter). We stay in a guesthouse. It will be 10 days. \n",
      "\n",
      "ski tour / skitour\n",
      "['ski touring', 'photography', 'swimming']\n",
      "cold destination / winter\n",
      "minimalist\n",
      "conservative\n",
      "indoor\n",
      "no own vehicle\n",
      "avalanche-prone terrain\n",
      "5 days\n",
      "\n",
      "\n",
      "9 . We plan a wild camping trip with activities such as snorkeling, kayaking and canoeing. It is a warm place and we want to bring little stuff. We stay in tents and hammocks and travel with a car, it will be 3 days. \n",
      "\n",
      "camping trip (wild camping)\n",
      "['scuba diving', 'kayaking / canoeing']\n",
      "tropical / humid\n",
      "lightweight (but comfortable)\n",
      "conservative\n",
      "sleeping in a tent\n",
      "own vehicle\n",
      "no special conditions to consider\n",
      "3 days\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for i, item in enumerate(trip_descriptions):\n",
    "    print(i, \".\", item, \"\\n\")\n",
    "    for elem in trip_types[i]:\n",
    "        print(elem)\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f60c54b-affc-4d9a-acf1-da70f68c5578",
   "metadata": {},
   "source": [
    "**Functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fac51224-9575-4b4b-8567-4ad4e759ecc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function that returns pandas data frame with predictions and true values\n",
    "\n",
    "cut_off = 0.5  # used to choose which activities are relevant\n",
    "\n",
    "def pred_trip(model_name, trip_descr, trip_type, cut_off):\n",
    "    classifier = pipeline(\"zero-shot-classification\", model=model_name)\n",
    "    # Create an empty DataFrame with specified columns\n",
    "    df = pd.DataFrame(columns=['superclass', 'pred_class'])\n",
    "    for i, key in enumerate(keys_list):\n",
    "        if key == 'activities':\n",
    "            result = classifier(trip_descr, candidate_labels[key], multi_label=True)\n",
    "            indices = [i for i, score in enumerate(result['scores']) if score > cut_off]\n",
    "            classes = [result['labels'][i] for i in indices]\n",
    "        else:\n",
    "            result = classifier(trip_descr, candidate_labels[key])\n",
    "            classes = result[\"labels\"][0]\n",
    "        print(result)\n",
    "        print(classes)\n",
    "        print(i)\n",
    "        df.loc[i] = [key, classes]\n",
    "    df['true_class'] = trip_type\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b36ab806-2f35-4950-ac5a-7c192190cba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function returning accuracy, fraction of true classes identified and fraction of wrongly predicted classes\n",
    "\n",
    "def perf_measure(df):\n",
    "    df['same_value'] = df['pred_class'] == df['true_class']\n",
    "    correct = sum(df.loc[df.index != 1, 'same_value'])\n",
    "    total = len(df['same_value'])\n",
    "    accuracy = correct/total\n",
    "    pred_class = df.loc[df.index == 1, 'pred_class'].iloc[0]\n",
    "    true_class = df.loc[df.index == 1, 'true_class'].iloc[0]\n",
    "    correct = [label for label in pred_class if label in true_class]\n",
    "    num_correct = len(correct)\n",
    "    correct_perc = num_correct/len(true_class)\n",
    "    num_pred = len(pred_class)\n",
    "    if num_pred == 0:\n",
    "        wrong_perc = math.nan\n",
    "    else:\n",
    "        wrong_perc = (num_pred - num_correct)/num_pred\n",
    "    df_perf = pd.DataFrame({\n",
    "    'accuracy': [accuracy],\n",
    "    'true_ident': [correct_perc],\n",
    "    'false_pred': [wrong_perc]\n",
    "    })\n",
    "    return(df_perf)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c10aa57d-d7ed-45c7-bdf5-29af193c7fd5",
   "metadata": {},
   "source": [
    "## Make predictions for many models and trip descriptions\n",
    "\n",
    "Provide a list of candidate models and apply them to the test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dd7869a8-b436-40de-9ea0-28eb4b7d3248",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Using model: facebook/bart-large-mnli\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 25\u001b[0m\n\u001b[1;32m     23\u001b[0m current_trip \u001b[38;5;241m=\u001b[39m trip_descriptions[i]\n\u001b[1;32m     24\u001b[0m current_type \u001b[38;5;241m=\u001b[39m trip_types[i]\n\u001b[0;32m---> 25\u001b[0m df \u001b[38;5;241m=\u001b[39m \u001b[43mpred_trip\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcurrent_trip\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcurrent_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcut_off\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.5\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     26\u001b[0m \u001b[38;5;28mprint\u001b[39m(df)\n\u001b[1;32m     27\u001b[0m \u001b[38;5;66;03m# accuracy, perc true classes identified and perc wrong pred classes\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[3], line 15\u001b[0m, in \u001b[0;36mpred_trip\u001b[0;34m(model_name, trip_descr, trip_type, cut_off)\u001b[0m\n\u001b[1;32m     13\u001b[0m     classes \u001b[38;5;241m=\u001b[39m [result[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m'\u001b[39m][i] \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m indices]\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 15\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[43mclassifier\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrip_descr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcandidate_labels\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     16\u001b[0m     classes \u001b[38;5;241m=\u001b[39m result[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m     17\u001b[0m \u001b[38;5;28mprint\u001b[39m(result)\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/pipelines/zero_shot_classification.py:206\u001b[0m, in \u001b[0;36mZeroShotClassificationPipeline.__call__\u001b[0;34m(self, sequences, *args, **kwargs)\u001b[0m\n\u001b[1;32m    203\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    204\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnable to understand extra arguments \u001b[39m\u001b[38;5;132;01m{\u001b[39;00margs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 206\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msequences\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/pipelines/base.py:1294\u001b[0m, in \u001b[0;36mPipeline.__call__\u001b[0;34m(self, inputs, num_workers, batch_size, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1292\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39miterate(inputs, preprocess_params, forward_params, postprocess_params)\n\u001b[1;32m   1293\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mframework \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m, ChunkPipeline):\n\u001b[0;32m-> 1294\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1295\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43miter\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1296\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_iterator\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1297\u001b[0m \u001b[43m                \u001b[49m\u001b[43m[\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_workers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreprocess_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mforward_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpostprocess_params\u001b[49m\n\u001b[1;32m   1298\u001b[0m \u001b[43m            \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1299\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1300\u001b[0m \u001b[43m    
\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1301\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1302\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_single(inputs, preprocess_params, forward_params, postprocess_params)\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/pipelines/pt_utils.py:124\u001b[0m, in \u001b[0;36mPipelineIterator.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    121\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloader_batch_item()\n\u001b[1;32m    123\u001b[0m \u001b[38;5;66;03m# We're out of items within a batch\u001b[39;00m\n\u001b[0;32m--> 124\u001b[0m item \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    125\u001b[0m processed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfer(item, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparams)\n\u001b[1;32m    126\u001b[0m \u001b[38;5;66;03m# We now have a batch of \"inferred things\".\u001b[39;00m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/pipelines/pt_utils.py:269\u001b[0m, in \u001b[0;36mPipelinePackIterator.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    266\u001b[0m             \u001b[38;5;28;01mreturn\u001b[39;00m accumulator\n\u001b[1;32m    268\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_last:\n\u001b[0;32m--> 269\u001b[0m     processed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minfer\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    270\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloader_batch_size \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    271\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(processed, torch\u001b[38;5;241m.\u001b[39mTensor):\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/pipelines/base.py:1209\u001b[0m, in \u001b[0;36mPipeline.forward\u001b[0;34m(self, model_inputs, **forward_params)\u001b[0m\n\u001b[1;32m   1207\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m inference_context():\n\u001b[1;32m   1208\u001b[0m         model_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ensure_tensor_on_device(model_inputs, device\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m-> 1209\u001b[0m         model_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mforward_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1210\u001b[0m         model_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ensure_tensor_on_device(model_outputs, device\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m   1211\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/pipelines/zero_shot_classification.py:229\u001b[0m, in \u001b[0;36mZeroShotClassificationPipeline._forward\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m    227\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muse_cache\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m inspect\u001b[38;5;241m.\u001b[39msignature(model_forward)\u001b[38;5;241m.\u001b[39mparameters\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m    228\u001b[0m     model_inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muse_cache\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m--> 229\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    231\u001b[0m model_outputs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m    232\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcandidate_label\u001b[39m\u001b[38;5;124m\"\u001b[39m: candidate_label,\n\u001b[1;32m    233\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msequence\u001b[39m\u001b[38;5;124m\"\u001b[39m: sequence,\n\u001b[1;32m    234\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis_last\u001b[39m\u001b[38;5;124m\"\u001b[39m: inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis_last\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m    235\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39moutputs,\n\u001b[1;32m    236\u001b[0m }\n\u001b[1;32m    237\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model_outputs\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:1763\u001b[0m, in \u001b[0;36mBartForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1758\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   1759\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\n\u001b[1;32m   1760\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPassing input embeddings is currently not supported for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1761\u001b[0m     )\n\u001b[0;32m-> 1763\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1764\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1765\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1766\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mdecoder_input_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_input_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1767\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdecoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1768\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1769\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdecoder_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1770\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1771\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1772\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1773\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdecoder_inputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_inputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1774\u001b[0m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1775\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1776\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1777\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1778\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1779\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]  \u001b[38;5;66;03m# last hidden state\u001b[39;00m\n\u001b[1;32m   1781\u001b[0m eos_mask \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39meos_token_id)\u001b[38;5;241m.\u001b[39mto(hidden_states\u001b[38;5;241m.\u001b[39mdevice)\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:1528\u001b[0m, in \u001b[0;36mBartModel.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1521\u001b[0m     encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m   1522\u001b[0m         last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m   1523\u001b[0m         hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1524\u001b[0m         attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1525\u001b[0m     )\n\u001b[1;32m   1527\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1528\u001b[0m decoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_input_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1530\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1531\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_outputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1532\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1533\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1534\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1535\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1536\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecoder_inputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1537\u001b[0m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1538\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1539\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1540\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1541\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1543\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_dict:\n\u001b[1;32m   1544\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m decoder_outputs \u001b[38;5;241m+\u001b[39m encoder_outputs\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:1380\u001b[0m, in \u001b[0;36mBartDecoder.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask, cross_attn_head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1367\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m   1368\u001b[0m         decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m   1369\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1377\u001b[0m         use_cache,\n\u001b[1;32m   1378\u001b[0m     )\n\u001b[1;32m   1379\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1380\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1381\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1382\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1383\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1384\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1385\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1386\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcross_attn_layer_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1387\u001b[0m \u001b[43m            \u001b[49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mcross_attn_head_mask\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\n\u001b[1;32m   1388\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1389\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1390\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1391\u001b[0m \u001b[43m        \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1392\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1393\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m 
layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m   1395\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:666\u001b[0m, in \u001b[0;36mBartDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, past_key_value, output_attentions, use_cache)\u001b[0m\n\u001b[1;32m    664\u001b[0m self_attn_past_key_value \u001b[38;5;241m=\u001b[39m past_key_value[:\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m past_key_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    665\u001b[0m \u001b[38;5;66;03m# add present self-attn cache to positions 1,2 of present_key_value tuple\u001b[39;00m\n\u001b[0;32m--> 666\u001b[0m hidden_states, self_attn_weights, present_key_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    667\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    668\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_attn_past_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    669\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    670\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlayer_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    671\u001b[0m \u001b[43m    
\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    672\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    673\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mdropout(hidden_states, p\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout, training\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining)\n\u001b[1;32m    674\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py:450\u001b[0m, in \u001b[0;36mBartSdpaAttention.forward\u001b[0;34m(self, hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions)\u001b[0m\n\u001b[1;32m    447\u001b[0m bsz, tgt_len, _ \u001b[38;5;241m=\u001b[39m hidden_states\u001b[38;5;241m.\u001b[39msize()\n\u001b[1;32m    449\u001b[0m \u001b[38;5;66;03m# get query proj\u001b[39;00m\n\u001b[0;32m--> 450\u001b[0m query_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mq_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    451\u001b[0m \u001b[38;5;66;03m# get key, value proj\u001b[39;00m\n\u001b[1;32m    452\u001b[0m \u001b[38;5;66;03m# `past_key_value[0].shape[2] == key_value_states.shape[1]`\u001b[39;00m\n\u001b[1;32m    453\u001b[0m \u001b[38;5;66;03m# is checking that the `sequence_length` of the `past_key_value` is the same as\u001b[39;00m\n\u001b[1;32m    454\u001b[0m \u001b[38;5;66;03m# the provided `key_value_states` to support prefix tuning\u001b[39;00m\n\u001b[1;32m    455\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m    456\u001b[0m     is_cross_attention\n\u001b[1;32m    457\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m past_key_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    458\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m past_key_value[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;241m==\u001b[39m key_value_states\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m    459\u001b[0m ):\n\u001b[1;32m    460\u001b[0m     \u001b[38;5;66;03m# reuse k,v, cross_attentions\u001b[39;00m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/huggingface_env/lib/python3.8/site-packages/torch/nn/modules/linear.py:116\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    115\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 116\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# List of Hugging Face model names\n",
    "model_names = [\n",
    "    \"facebook/bart-large-mnli\",\n",
    "    \"MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli\",\n",
    "    ##\"cross-encoder/nli-deberta-v3-base\",\n",
    "    \"cross-encoder/nli-deberta-v3-large\",\n",
    "    \"MoritzLaurer/mDeBERTa-v3-base-mnli-xnli\",\n",
    "    ##\"joeddav/bart-large-mnli-yahoo-answers\",\n",
    "    \"MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli\",\n",
    "    \"MoritzLaurer/deberta-v3-large-zeroshot-v2.0\",\n",
    "    \"valhalla/distilbart-mnli-12-1\",\n",
    "    #\"joeddav/xlm-roberta-large-xnli\" # keeps giving errors\n",
    "]\n",
    "\n",
    "# Apply each model to the test data\n",
    "for model_name in model_names:\n",
    "    print(f\"\\nUsing model: {model_name}\")\n",
    "    result_list = []\n",
    "    performance = pd.DataFrame(columns=['accuracy', 'true_ident', 'false_pred'])\n",
    "    \n",
    "    start_time = time.time()\n",
    "    for i in range(len(trip_descriptions)):\n",
    "        current_trip = trip_descriptions[i]\n",
    "        current_type = trip_types[i]\n",
    "        df = pred_trip(model_name, current_trip, current_type, cut_off = 0.5)\n",
    "        print(df)\n",
    "        # accuracy, perc true classes identified and perc wrong pred classes\n",
    "        performance = pd.concat([performance, perf_measure(df)])\n",
    "        print(performance)\n",
    "        \n",
    "        result_list.append(df)\n",
    "    end_time = time.time()\n",
    "    elapsed_time = end_time - start_time\n",
    "    # Extract \"same_value\" column from each DataFrame\n",
    "    sv_columns = [df['same_value'] for df in result_list]  # 'same' needs to be changed\n",
    "    sv_columns.insert(0, result_list[0]['superclass'])\n",
    "    # Combine into a new DataFrame (columns side-by-side)\n",
    "    sv_df = pd.concat(sv_columns, axis=1)\n",
    "    print(sv_df)\n",
    "    # Compute accuracy per superclass (row means of same_value matrix excluding the first column)\n",
    "    row_means = sv_df.iloc[:, 1:].mean(axis=1)\n",
    "    df_row_means = pd.DataFrame({\n",
    "        'superclass': sv_df['superclass'],\n",
    "        'accuracy': row_means\n",
    "    })\n",
    "    print(df_row_means)\n",
    "    # Compute performance measures per trip (mean for each column of performance table)\n",
    "    column_means = performance.mean()\n",
    "    print(column_means)\n",
    "    # save results\n",
    "    model = model_name.replace(\"/\", \"-\")\n",
    "    model_result = {\n",
    "        'model': model,\n",
    "        'predictions': result_list,\n",
    "        'performance': performance,\n",
    "        'perf_summary': column_means,\n",
    "        'perf_superclass': df_row_means,\n",
    "        'elapsed_time': elapsed_time\n",
    "    }\n",
    "    # File path with folder\n",
    "    filename = os.path.join('results', f'{model}_results.pkl')\n",
    "    # Save the object\n",
    "    with open(filename, 'wb') as f:\n",
    "        pickle.dump(model_result, f)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1cbb54e-abe6-49b6-957e-0683196f3199",
   "metadata": {},
   "source": [
    "## Load and compare results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "62ca82b0-6909-4e6c-9d2c-fed87971e5b6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: cross-encoder-nli-deberta-v3-base\n",
      "Performance Summary:\n",
      "accuracy      0.444444\n",
      "true_ident    0.533333\n",
      "false_pred    0.712500\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: joeddav-bart-large-mnli-yahoo-answers\n",
      "Performance Summary:\n",
      "accuracy      0.344444\n",
      "true_ident    0.650000\n",
      "false_pred    0.553792\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: cross-encoder-nli-deberta-v3-large\n",
      "Performance Summary:\n",
      "accuracy      0.466667\n",
      "true_ident    0.566667\n",
      "false_pred    0.541667\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: MoritzLaurer-DeBERTa-v3-large-mnli-fever-anli-ling-wanli\n",
      "Performance Summary:\n",
      "accuracy      0.566667\n",
      "true_ident    0.841667\n",
      "false_pred    0.546667\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: MoritzLaurer-mDeBERTa-v3-base-mnli-xnli\n",
      "Performance Summary:\n",
      "accuracy      0.466667\n",
      "true_ident    0.408333\n",
      "false_pred    0.481250\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: MoritzLaurer-deberta-v3-large-zeroshot-v2.0\n",
      "Performance Summary:\n",
      "accuracy      0.455556\n",
      "true_ident    0.325000\n",
      "false_pred    0.500000\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: facebook-bart-large-mnli\n",
      "Performance Summary:\n",
      "accuracy      0.466667\n",
      "true_ident    0.708333\n",
      "false_pred    0.400000\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: valhalla-distilbart-mnli-12-1\n",
      "Performance Summary:\n",
      "accuracy      0.522222\n",
      "true_ident    0.300000\n",
      "false_pred    0.533333\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\n",
      "Performance Summary:\n",
      "accuracy      0.522222\n",
      "true_ident    0.841667\n",
      "false_pred    0.572381\n",
      "dtype: float64\n",
      "----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Folder where .pkl files are saved\n",
    "results_dir = 'results/before'\n",
    "\n",
    "# Dictionary to store all loaded results\n",
    "all_results = {}\n",
    "\n",
    "# Loop through all .pkl files in the folder\n",
    "for filename in os.listdir(results_dir):\n",
    "    if filename.endswith('.pkl'):\n",
    "        model_name = filename.replace('_results.pkl', '')  # Extract model name\n",
    "        file_path = os.path.join(results_dir, filename)\n",
    "        \n",
    "        # Load the result\n",
    "        with open(file_path, 'rb') as f:\n",
    "            result = pickle.load(f)\n",
    "            all_results[model_name] = result\n",
    "\n",
    "# Compare performance across models\n",
    "for model, data in all_results.items():\n",
    "    print(f\"Model: {model}\")\n",
    "    print(f\"Performance Summary:\\n{data['perf_summary']}\")\n",
    "    print(\"-\" * 40)\n",
    "\n",
    "# Compare per-superclass performance across models (currently disabled)\n",
    "#for model, data in all_results.items():\n",
    "#    print(f\"Model: {model}\")\n",
    "#    print(f\"Performance Summary:\\n{data['perf_superclass']}\")\n",
    "#    print(\"-\" * 40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "37849e0b-864e-4377-b06c-0ac70c3861f9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: cross-encoder-nli-deberta-v3-base\n",
      "Performance Summary:\n",
      "accuracy      0.444444\n",
      "true_ident    0.533333\n",
      "false_pred    0.712500\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: joeddav-bart-large-mnli-yahoo-answers\n",
      "Performance Summary:\n",
      "accuracy      0.355556\n",
      "true_ident    0.650000\n",
      "false_pred    0.553792\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: cross-encoder-nli-deberta-v3-large\n",
      "Performance Summary:\n",
      "accuracy      0.466667\n",
      "true_ident    0.566667\n",
      "false_pred    0.541667\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: MoritzLaurer-DeBERTa-v3-large-mnli-fever-anli-ling-wanli\n",
      "Performance Summary:\n",
      "accuracy      0.611111\n",
      "true_ident    0.841667\n",
      "false_pred    0.546667\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: MoritzLaurer-mDeBERTa-v3-base-mnli-xnli\n",
      "Performance Summary:\n",
      "accuracy      0.455556\n",
      "true_ident    0.408333\n",
      "false_pred    0.481250\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: MoritzLaurer-deberta-v3-large-zeroshot-v2.0\n",
      "Performance Summary:\n",
      "accuracy      0.500\n",
      "true_ident    0.325\n",
      "false_pred    0.500\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: facebook-bart-large-mnli\n",
      "Performance Summary:\n",
      "accuracy      0.466667\n",
      "true_ident    0.708333\n",
      "false_pred    0.400000\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: valhalla-distilbart-mnli-12-1\n",
      "Performance Summary:\n",
      "accuracy      0.500000\n",
      "true_ident    0.300000\n",
      "false_pred    0.533333\n",
      "dtype: float64\n",
      "----------------------------------------\n",
      "Model: MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\n",
      "Performance Summary:\n",
      "accuracy      0.522222\n",
      "true_ident    0.841667\n",
      "false_pred    0.572381\n",
      "dtype: float64\n",
      "----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Folder where .pkl files are saved\n",
    "results_dir = 'results'\n",
    "\n",
    "# Dictionary to store all loaded results\n",
    "all_results = {}\n",
    "\n",
    "# Loop through all .pkl files in the folder\n",
    "for filename in os.listdir(results_dir):\n",
    "    if filename.endswith('.pkl'):\n",
    "        model_name = filename.replace('_results.pkl', '')  # Extract model name\n",
    "        file_path = os.path.join(results_dir, filename)\n",
    "        \n",
    "        # Load the result\n",
    "        with open(file_path, 'rb') as f:\n",
    "            result = pickle.load(f)\n",
    "            all_results[model_name] = result\n",
    "\n",
    "# Compare performance across models\n",
    "for model, data in all_results.items():\n",
    "    print(f\"Model: {model}\")\n",
    "    print(f\"Performance Summary:\\n{data['perf_summary']}\")\n",
    "    print(\"-\" * 40)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f65e5b1-bc32-42c2-bbe9-9e3a6ffc72c1",
   "metadata": {},
   "source": [
    "**Identify trips that are difficult to predict**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "040055c9-5df4-49b0-921a-5bf98ff01a69",
   "metadata": {},
   "source": [
    "Per model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "57fd150d-1cda-4be5-806b-ef380469243a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cross-encoder-nli-deberta-v3-base: Index([0, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int64')\n",
      "\n",
      "joeddav-bart-large-mnli-yahoo-answers: RangeIndex(start=0, stop=10, step=1)\n",
      "\n",
      "cross-encoder-nli-deberta-v3-large: Index([0, 1, 2, 3, 4, 6, 7, 8, 9], dtype='int64')\n",
      "\n",
      "MoritzLaurer-DeBERTa-v3-large-mnli-fever-anli-ling-wanli: Index([2, 3, 5, 6, 7, 8, 9], dtype='int64')\n",
      "\n",
      "MoritzLaurer-mDeBERTa-v3-base-mnli-xnli: RangeIndex(start=0, stop=10, step=1)\n",
      "\n",
      "MoritzLaurer-deberta-v3-large-zeroshot-v2.0: Index([1, 2, 3, 5, 6, 7, 9], dtype='int64')\n",
      "\n",
      "facebook-bart-large-mnli: RangeIndex(start=0, stop=10, step=1)\n",
      "\n",
      "valhalla-distilbart-mnli-12-1: Index([0, 1, 2, 3, 4, 7, 9], dtype='int64')\n",
      "\n",
      "MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli: Index([0, 2, 3, 4, 6, 7], dtype='int64')\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def get_difficult_trips(model_result, cut_off = 0.6):\n",
    "    # model_result is a dict with dict_keys(['model', 'predictions', \n",
    "    # 'performance', 'perf_summary', 'perf_superclass', 'elapsed_time'])\n",
    "    # get performance dataframe and repair index\n",
    "    df = model_result['performance'].reset_index(drop=True)\n",
    "    # find index of trips whose accuracy is below cut_off\n",
    "    index_result = df[df['accuracy'] < cut_off].index\n",
    "    return(index_result)\n",
    "\n",
    "# dictionary mapping each model to the trips whose accuracy is below the default cut_off\n",
    "difficult_trips_dict = {}\n",
    "for model, data in all_results.items():\n",
    "    difficult_trips_dict[data[\"model\"]] = get_difficult_trips(data)\n",
    "\n",
    "for key, value in difficult_trips_dict.items():\n",
    "    print(f\"{key}: {value}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d91fb932-c5aa-472a-9b8d-a0cfc83a87f8",
   "metadata": {},
   "source": [
    "For all models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a2754cb7-59b9-4f1d-ab74-1bf711b3eba2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 . My partner and I are traveling to the Netherlands and Germany to spend Christmas with our family. We are in our late twenties and will start our journey with a two-hour flight to the Netherlands. From there, we will take a 5.5-hour train ride to northern Germany. \n",
      "\n",
      "city trip\n",
      "['relaxing']\n",
      "cold destination / winter\n",
      "lightweight (but comfortable)\n",
      "casual\n",
      "indoor\n",
      "no own vehicle\n",
      "no special conditions to consider\n",
      "7+ days\n",
      "\n",
      "\n",
      "3 . I’m in my twenties and will be traveling to Peru for three weeks. I’m going solo but will meet up with a friend to explore the Sacred Valley and take part in a Machu Picchu tour. We plan to hike, go rafting, and explore the remnants of the ancient Inca Empire. We’re also excited to try Peruvian cuisine and immerse ourselves in the local culture. Depending on our plans, we might also visit the rainforest region, such as Tarapoto. I’ll be flying to Peru on a long-haul flight and will be traveling in August. \n",
      "\n",
      "cultural exploration\n",
      "['sightseeing', 'hiking', 'rafting']\n",
      "variable weather / spring / autumn\n",
      "lightweight (but comfortable)\n",
      "casual\n",
      "indoor\n",
      "no own vehicle\n",
      "rainy climate\n",
      "7+ days\n",
      "\n",
      "\n",
      "7 . We will go to Sweden in the winter, to go for a yoga and sauna/wellness retreat. I prefer lightweight packing and also want clothes to go for fancy dinners and maybe on a winter hike. We stay in hotels. \n",
      "\n",
      "yoga / wellness retreat\n",
      "['hiking', 'yoga']\n",
      "cold destination / winter\n",
      "lightweight (but comfortable)\n",
      "casual\n",
      "indoor\n",
      "no own vehicle\n",
      "snow and ice\n",
      "7 days\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Which trips are difficult for all models\n",
    "common = set.intersection(*(set(v) for v in difficult_trips_dict.values()))\n",
    "for index in common:\n",
    "    print(index, \".\", trip_descriptions[index], \"\\n\")\n",
    "    for item in trip_types[index]:\n",
    "        print(item)\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be58d66f-a491-4f47-98df-2c0aa4af38e7",
   "metadata": {},
   "source": [
    "**Identify superclasses that are difficult to predict**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e833c2d-9356-4d40-9b20-0a1eb6628a30",
   "metadata": {},
   "source": [
    "Per model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "adb491b1-3ac3-4c32-934f-5eb6171f2ec9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cross-encoder-nli-deberta-v3-base: ['activities', 'climate_or_season', 'style_or_comfort', 'special_conditions']\n",
      "\n",
      "joeddav-bart-large-mnli-yahoo-answers: ['activities', 'climate_or_season', 'style_or_comfort', 'dress_code', 'accommodation', 'transportation', 'special_conditions']\n",
      "\n",
      "cross-encoder-nli-deberta-v3-large: ['activities', 'climate_or_season', 'style_or_comfort', 'transportation', 'special_conditions']\n",
      "\n",
      "MoritzLaurer-DeBERTa-v3-large-mnli-fever-anli-ling-wanli: ['activities', 'style_or_comfort']\n",
      "\n",
      "MoritzLaurer-mDeBERTa-v3-base-mnli-xnli: ['activities', 'style_or_comfort', 'accommodation', 'special_conditions', 'trip_length_days']\n",
      "\n",
      "MoritzLaurer-deberta-v3-large-zeroshot-v2.0: ['activities', 'climate_or_season', 'style_or_comfort', 'accommodation', 'special_conditions']\n",
      "\n",
      "facebook-bart-large-mnli: ['activities', 'style_or_comfort', 'accommodation', 'special_conditions']\n",
      "\n",
      "valhalla-distilbart-mnli-12-1: ['activities', 'climate_or_season', 'style_or_comfort', 'accommodation', 'special_conditions']\n",
      "\n",
      "MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli: ['activities', 'climate_or_season', 'style_or_comfort', 'special_conditions']\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def get_difficult_superclasses(model_result, cut_off = 0.6):\n",
    "    # model_result is a dict with dict_keys(['model', 'predictions', \n",
    "    # 'performance', 'perf_summary', 'perf_superclass', 'elapsed_time'])\n",
    "    df = model_result[\"perf_superclass\"]\n",
    "    # find superclass whose accuracy is below cut_off\n",
    "    diff_spc = list(df[df['accuracy'] < cut_off][\"superclass\"])\n",
    "    return(diff_spc)\n",
    "\n",
    "# dictionary mapping each model to the superclasses whose accuracy is below the default cut_off\n",
    "difficult_superclass_dict = {}\n",
    "for model, data in all_results.items():\n",
    "    difficult_superclass_dict[data[\"model\"]] = get_difficult_superclasses(data)\n",
    "\n",
    "for key, value in difficult_superclass_dict.items():\n",
    "    print(f\"{key}: {value}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbcebdf8-0975-45cb-96f5-15b4645aa7f6",
   "metadata": {},
   "source": [
    "For all models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4e51c11b-9a0a-4f9d-b20c-a6feda2d5a3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'style_or_comfort', 'activities'}\n"
     ]
    }
   ],
   "source": [
    "# Which superclasses are difficult for all models\n",
    "common = set.intersection(*(set(v) for v in difficult_superclass_dict.values()))\n",
    "print(common)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f0e31e2c-e87d-4776-b781-991919492430",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look at particular predictions in detail\n",
    "# print(all_results[\"joeddav-bart-large-mnli-yahoo-answers\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01e24355-4aac-4ad6-b50c-96f75585ce45",
   "metadata": {},
   "source": [
    "**Comparing models**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b020f584-1468-4c84-9dac-7ca7fac6e8ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n",
      "<class 'list'>\n",
      "   accuracy true_ident false_pred  \\\n",
      "0  0.444444   0.533333     0.7125   \n",
      "1  0.355556       0.65   0.553792   \n",
      "2  0.466667   0.566667   0.541667   \n",
      "3  0.611111   0.841667   0.546667   \n",
      "4  0.455556   0.408333    0.48125   \n",
      "5       0.5      0.325        0.5   \n",
      "6  0.466667   0.708333        0.4   \n",
      "7       0.5        0.3   0.533333   \n",
      "8  0.522222   0.841667   0.572381   \n",
      "\n",
      "                                               model  \n",
      "0                  cross-encoder-nli-deberta-v3-base  \n",
      "1              joeddav-bart-large-mnli-yahoo-answers  \n",
      "2                 cross-encoder-nli-deberta-v3-large  \n",
      "3  MoritzLaurer-DeBERTa-v3-large-mnli-fever-anli-...  \n",
      "4            MoritzLaurer-mDeBERTa-v3-base-mnli-xnli  \n",
      "5        MoritzLaurer-deberta-v3-large-zeroshot-v2.0  \n",
      "6                           facebook-bart-large-mnli  \n",
      "7                      valhalla-distilbart-mnli-12-1  \n",
      "8       MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli  \n"
     ]
    }
   ],
   "source": [
    "# Make table of 'perf_summary' for all models (TODO: include time elapsed)\n",
    "#print(type(all_results))\n",
    "#print(type(all_results[\"MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\"]))\n",
    "#print(all_results[\"MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\"].keys())\n",
    "#print(type(all_results[\"MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\"][\"perf_summary\"]))\n",
    "#print(all_results[\"MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\"][\"perf_summary\"])\n",
    "#print(all_results[\"MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\"][\"perf_summary\"][\"accuracy\"])\n",
    "# make empty list that collects one single-row data frame per model\n",
    "perf_table = []\n",
    "print(perf_table)\n",
    "\n",
    "# fill in for loop with perf_summary per model\n",
    "for model, result in all_results.items():\n",
    "    row = pd.DataFrame(result[\"perf_summary\"]).T\n",
    "    #print(row.shape)\n",
    "    row[\"model\"] = model\n",
    "    perf_table.append(row)\n",
    "# Concatenate all into one table\n",
    "df_all = pd.concat(perf_table, ignore_index=True)\n",
    "\n",
    "print(df_all)\n",
    "#print(type(df_all))\n",
    "    \n",
    "\n",
    "# Make ranking from that table for each category\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "0f7376bd-a50b-47cc-8055-48a6de5dfee6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'dict'>\n"
     ]
    }
   ],
   "source": [
    "print(type(all_results))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17483df4-55c4-41cd-b8a9-61f7a5c7e8a3",
   "metadata": {},
   "source": [
    "# Use gradio for user input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5bf23e10-0a93-4b2f-9508-34bb0974d24c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prerequisites\n",
    "from transformers import pipeline\n",
    "import json\n",
    "import pandas as pd\n",
    "import gradio as gr\n",
    "\n",
    "# get candidate labels\n",
    "with open(\"packing_label_structure.json\", \"r\") as file:\n",
    "    candidate_labels = json.load(file)\n",
    "keys_list = list(candidate_labels.keys())\n",
    "\n",
    "# Load test data (in list of dictionaries)\n",
    "with open(\"test_data.json\", \"r\") as file:\n",
    "    packing_data = json.load(file)\n",
    "# Extract all trip descriptions and trip_types\n",
    "# trip_descriptions = [trip['description'] for trip in packing_data]\n",
    "# trip_types = [trip['trip_types'] for trip in packing_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "61ebbe99-2563-4c99-ba65-d2312c9d5844",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7860\n",
      "Running on public URL: https://91493051ab28db8a5a.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://91493051ab28db8a5a.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
     ]
    }
   ],
   "source": [
    "# function and gradio app\n",
    "cut_off = 0.5  # used to choose which activities are relevant\n",
    "\n",
    "def classify(model_name, trip_descr, cut_off):\n",
    "    classifier = pipeline(\"zero-shot-classification\", model=model_name)\n",
    "    # Create an empty DataFrame with specified columns\n",
    "    df = pd.DataFrame(columns=['superclass', 'pred_class'])\n",
    "    for i, key in enumerate(keys_list):\n",
    "        if key == 'activities':\n",
    "            result = classifier(trip_descr, candidate_labels[key], multi_label=True)\n",
    "            indices = [i for i, score in enumerate(result['scores']) if score > cut_off]\n",
    "            classes = [result['labels'][i] for i in indices]\n",
    "        else:\n",
    "            result = classifier(trip_descr, candidate_labels[key])\n",
    "            classes = result[\"labels\"][0]\n",
    "        df.loc[i] = [key, classes]\n",
    "    return df\n",
    "\n",
    "demo = gr.Interface(\n",
    "    fn=classify,\n",
    "    inputs=[\n",
    "        gr.Textbox(label=\"Model name\", value = \"facebook/bart-large-mnli\"),\n",
    "        gr.Textbox(label=\"Trip description\"),\n",
    "        gr.Number(label=\"Activity cut-off\", value = 0.5),\n",
    "    ],\n",
    "    outputs=\"dataframe\",\n",
    "    title=\"Trip classification\",\n",
    "    description=\"Enter a text describing your trip\",\n",
    ")\n",
    "\n",
    "# Launch the Gradio app\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch(share=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (huggingface_env)",
   "language": "python",
   "name": "huggingface_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}