{ "cells": [ { "cell_type": "markdown", "id": "e25090fa-f990-4f1a-84f3-b12159eedae8", "metadata": {}, "source": [ "# Try out gradio" ] }, { "cell_type": "markdown", "id": "afd23321-1870-44af-82ed-bb241d055dfa", "metadata": {}, "source": [] }, { "cell_type": "markdown", "id": "3bbee2e4-55c8-4b06-9929-72026edf7932", "metadata": {}, "source": [ "**Load prerequisites**" ] }, { "cell_type": "code", "execution_count": 57, "id": "f8c28d2d-8458-49fd-8ebf-5e729d6e861f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "First trip: We are a couple in our thirties traveling to Vienna for a three-day city trip. We’ll be staying at a friend’s house and plan to explore the city by sightseeing, strolling through the streets, visiting markets, and trying out great restaurants and cafés. We also hope to attend a classical music concert. Our journey to Vienna will be by train. \n", "\n", "Trip type: ['city trip', ['sightseeing'], 'variable weather / spring / autumn', 'luxury (including evening wear)', 'casual', 'indoor', 'no own vehicle', 'no special condition', '3 days']\n" ] } ], "source": [ "# Prerequisites\n", "from tabulate import tabulate\n", "from transformers import pipeline\n", "import json\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import pickle\n", "import os\n", "import time\n", "\n", "# Load the model and create a pipeline for zero-shot classification (1min loading + classifying with 89 labels)\n", "classifier = pipeline(\"zero-shot-classification\", model=\"cross-encoder/nli-deberta-v3-base\")\n", "model_name = 'model_cross-encoder-nli-deberta-v3-base'\n", "# tried:\n", "# cross-encoder/nli-deberta-v3-large gave error\n", "# MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli\n", "# facebook/bart-large-mnli\n", "# sileod/deberta-v3-base-tasksource-nli\n", "\n", "# get candidate labels\n", "with open(\"packing_label_structure.json\", \"r\") as file:\n", " candidate_labels = json.load(file)\n", "keys_list = list(candidate_labels.keys())\n", "\n", "# Load test data (in list of dictionaries)\n", "with open(\"test_data.json\", \"r\") as file:\n", " packing_data = json.load(file)\n", "# Extract all trip descriptions and trip_types\n", "trip_descriptions = [trip['description'] for trip in packing_data]\n", "trip_types = [trip['trip_types'] for trip in packing_data]\n", "\n", "# Access the first trip description\n", "first_trip = trip_descriptions[1]\n", "# Get the packing list for the secondfirst trip\n", "first_trip_type = trip_types[1]\n", "\n", "print(f\"First trip: {first_trip} \\n\")\n", "print(f\"Trip type: {first_trip_type}\")" ] }, { "cell_type": "code", "execution_count": 58, "id": "3a762755-872d-43a6-b666-874d6133488c", "metadata": {}, "outputs": [], "source": [ "# function that returns pandas data frame with predictions\n", "\n", "cut_off = 0.5 # used to choose which activities are relevant\n", "\n", "def pred_trip(trip_descr, trip_type, cut_off):\n", " # Create an empty DataFrame with specified columns\n", " df = pd.DataFrame(columns=['superclass', 'pred_class'])\n", " for i, key in enumerate(keys_list):\n", " if key == 'activities':\n", " result = classifier(trip_descr, candidate_labels[key], multi_label=True)\n", " indices = [i for i, score in enumerate(result['scores']) if score > cut_off]\n", " classes = [result['labels'][i] for i in indices]\n", " else:\n", " result = classifier(trip_descr, candidate_labels[key])\n", " classes = result[\"labels\"][0]\n", " print(result)\n", " print(classes)\n", " print(i)\n", " df.loc[i] = [key, classes]\n", " df['true_class'] = trip_type\n", " return df" ] }, { "cell_type": "code", "execution_count": 59, "id": "3b4f3193-3bdd-453c-8664-df84f955600c", "metadata": {}, "outputs": [], "source": [ "# function for accuracy, perc true classes identified and perc wrong pred classes\n", "\n", "def perf_measure(df):\n", " df['same_value'] = df['pred_class'] == df['true_class']\n", " correct = sum(df.loc[df.index != 1, 'same_value'])\n", " total = len(df['same_value'])\n", " accuracy = correct/total\n", " pred_class = df.loc[df.index == 1, 'pred_class'].iloc[0]\n", " true_class = df.loc[df.index == 1, 'true_class'].iloc[0]\n", " correct = [label for label in pred_class if label in true_class]\n", " num_correct = len(correct)\n", " correct_perc = num_correct/len(true_class)\n", " num_pred = len(pred_class)\n", " wrong_perc = (num_pred - num_correct)/num_pred\n", " df_perf = pd.DataFrame({\n", " 'accuracy': [accuracy],\n", " 'true_ident': [correct_perc],\n", " 'false_pred': [wrong_perc]\n", " })\n", " return(df_perf)" ] }, { "cell_type": "markdown", "id": "62c5c18c-58f4-465c-a188-c57cfa7ffa90", "metadata": {}, "source": [ "**Now do the same for all trips**" ] }, { "cell_type": "code", "execution_count": 60, "id": "4dd01755-be8d-4904-8494-ac28aba2fee7", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'sequence': 'I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.', 'labels': ['micro-adventure / weekend trip', 'digital nomad trip', 'beach vacation', 'festival trip', 'city trip', 'cultural exploration', 'road trip (car/camper)', 'camping trip (wild camping)', 'long-distance hike / thru-hike', 'hut trek (winter)', 'ski tour / skitour', 'snowboard / splitboard trip', 'nature escape', 'yoga / wellness retreat', 'hut trek (summer)', 'camping trip (campground)'], 'scores': [0.9722680449485779, 0.007802918087691069, 0.0075571718625724316, 0.0022959215566515923, 0.0021305829286575317, 0.001222927705384791, 0.0009879637509584427, 0.000805296644102782, 0.0007946204277686775, 0.0007107199053280056, 0.0007009899127297103, 0.0006353880744427443, 0.0005838185315951705, 0.0005424902774393559, 0.0004807499353773892, 0.0004804217896889895]}\n", "micro-adventure / weekend trip\n", "0\n", "{'sequence': 'I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.', 'labels': ['going to the beach', 'sightseeing', 'relaxing', 'hiking', 'hut-to-hut hiking', 'stand-up paddleboarding (SUP)', 'photography', 'biking', 'running', 'ski touring', 'snowshoe hiking', 'yoga', 'kayaking / canoeing', 'horseback riding', 'rafting', 'paragliding', 'cross-country skiing', 'surfing', 'skiing', 'ice climbing', 'fishing', 'snorkeling', 'swimming', 'rock climbing', 'scuba diving'], 'scores': [0.4660525321960449, 0.007281942293047905, 0.003730606520548463, 0.0001860307966126129, 0.00014064949937164783, 0.00011034693307010457, 5.2949126256862655e-05, 3.828677654382773e-05, 3.396756437723525e-05, 1.5346524378401227e-05, 9.348185812996235e-06, 8.182429155567661e-06, 6.5973340497293975e-06, 6.271920938161202e-06, 5.544673058466287e-06, 5.299102667777333e-06, 4.855380211665761e-06, 4.506250661506783e-06, 3.949530764657538e-06, 3.730233856913401e-06, 3.297281637060223e-06, 3.0508665531669976e-06, 2.933618134193239e-06, 2.6379277642263332e-06, 2.2992651338427095e-06]}\n", "[]\n", "1\n", "{'sequence': 'I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.', 'labels': ['variable weather / spring / autumn', 'warm destination / summer', 'cold destination / winter', 'dry / desert-like', 'tropical / humid'], 'scores': [0.5934922695159912, 0.17430798709392548, 0.10943299531936646, 0.07068652659654617, 0.05208020657300949]}\n", "variable weather / spring / autumn\n", "2\n", "{'sequence': 'I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.', 'labels': ['minimalist', 'ultralight', 'luxury (including evening wear)', 'lightweight (but comfortable)'], 'scores': [0.6965053081512451, 0.11270010471343994, 0.10676420480012894, 0.08403033763170242]}\n", "minimalist\n", "3\n", "{'sequence': 'I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.', 'labels': ['casual', 'formal (business trip)', 'conservative'], 'scores': [0.6362482309341431, 0.22082458436489105, 0.14292724430561066]}\n", "casual\n", "4\n", "{'sequence': 'I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.', 'labels': ['indoor', 'sleeping in a tent', 'huts with half board', 'sleeping in a car'], 'scores': [0.435793399810791, 0.20242486894130707, 0.19281964004039764, 0.16896207630634308]}\n", "indoor\n", "5\n", "{'sequence': 'I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.', 'labels': ['no own vehicle', 'own vehicle'], 'scores': [0.9987181425094604, 0.0012818538816645741]}\n", "no own vehicle\n", "6\n", "{'sequence': 'I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.', 'labels': ['self-supported (bring your own food/cooking)', 'no special conditions', 'off-grid / no electricity', 'rainy climate', 'child-friendly', 'snow and ice', 'pet-friendly', 'high alpine terrain', 'avalanche-prone terrain'], 'scores': [0.1984991431236267, 0.1695038080215454, 0.16221018135547638, 0.13200421631336212, 0.12101645022630692, 0.10550825297832489, 0.042406272143125534, 0.03797775134444237, 0.030873913317918777]}\n", "self-supported (bring your own food/cooking)\n", "7\n", "{'sequence': 'I am planning a trip to Greece with my boyfriend, where we will visit two islands. We have booked an apartment on each island for a few days and plan to spend most of our time relaxing. Our main goals are to enjoy the beach, try delicious local food, and possibly go on a hike—if it’s not too hot. We will be relying solely on public transport. We’re in our late 20s and traveling from the Netherlands.', 'labels': ['7+ days', '2 days', '1 day', '7 days', '5 days', '3 days', '6 days', '4 days'], 'scores': [0.4730822443962097, 0.1168912723660469, 0.10058756172657013, 0.0991850346326828, 0.05424537882208824, 0.053677864372730255, 0.051554784178733826, 0.050775907933712006]}\n", "7+ days\n", "8\n", " superclass pred_class \\\n", "0 activity_type micro-adventure / weekend trip \n", "1 activities [] \n", "2 climate_or_season variable weather / spring / autumn \n", "3 style_or_comfort minimalist \n", "4 dress_code casual \n", "5 accommodation indoor \n", "6 transportation no own vehicle \n", "7 special_conditions self-supported (bring your own food/cooking) \n", "8 trip_length_days 7+ days \n", "\n", " true_class \n", "0 beach vacation \n", "1 [swimming, going to the beach, relaxing, hiking] \n", "2 warm destination / summer \n", "3 lightweight (but comfortable) \n", "4 casual \n", "5 indoor \n", "6 no own vehicle \n", "7 no special conditions \n", "8 7+ days \n" ] }, { "ename": "ZeroDivisionError", "evalue": "division by zero", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mZeroDivisionError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[60], line 13\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mprint\u001b[39m(df)\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# accuracy, perc true classes identified and perc wrong pred classes\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m performance \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mconcat([performance, \u001b[43mperf_measure\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdf\u001b[49m\u001b[43m)\u001b[49m])\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28mprint\u001b[39m(performance)\n\u001b[1;32m 16\u001b[0m result_list\u001b[38;5;241m.\u001b[39mappend(df)\n", "Cell \u001b[0;32mIn[59], line 14\u001b[0m, in \u001b[0;36mperf_measure\u001b[0;34m(df)\u001b[0m\n\u001b[1;32m 12\u001b[0m correct_perc \u001b[38;5;241m=\u001b[39m num_correct\u001b[38;5;241m/\u001b[39m\u001b[38;5;28mlen\u001b[39m(true_class)\n\u001b[1;32m 13\u001b[0m num_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(pred_class)\n\u001b[0;32m---> 14\u001b[0m wrong_perc \u001b[38;5;241m=\u001b[39m \u001b[43m(\u001b[49m\u001b[43mnum_pred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mnum_correct\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43mnum_pred\u001b[49m\n\u001b[1;32m 15\u001b[0m df_perf \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame({\n\u001b[1;32m 16\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124maccuracy\u001b[39m\u001b[38;5;124m'\u001b[39m: [accuracy],\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrue_ident\u001b[39m\u001b[38;5;124m'\u001b[39m: [correct_perc],\n\u001b[1;32m 18\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfalse_pred\u001b[39m\u001b[38;5;124m'\u001b[39m: [wrong_perc]\n\u001b[1;32m 19\u001b[0m })\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m(df_perf)\n", "\u001b[0;31mZeroDivisionError\u001b[0m: division by zero" ] } ], "source": [ "result_list = []\n", "performance = pd.DataFrame(columns=['accuracy', 'true_ident', 'false_pred'])\n", "\n", "start_time = time.time()\n", "\n", "for i in range(len(trip_descriptions)):\n", " current_trip = trip_descriptions[i]\n", " current_type = trip_types[i]\n", " df = pred_trip(current_trip, current_type, cut_off = 0.5)\n", " print(df)\n", " \n", " # accuracy, perc true classes identified and perc wrong pred classes\n", " performance = pd.concat([performance, perf_measure(df)])\n", " print(performance)\n", " \n", " result_list.append(df)\n", "\n", "end_time = time.time()\n", "\n", "elapsed_time = end_time - start_time" ] }, { "cell_type": "markdown", "id": "b5c08703-7166-4d03-9d6b-ee2c12608134", "metadata": {}, "source": [ "**Compute average performance measures**" ] }, { "cell_type": "code", "execution_count": 61, "id": "eb33fd31-94e6-40b5-9c36-a32effe77c01", "metadata": {}, "outputs": [ { "ename": "IndexError", "evalue": "list index out of range", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[61], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Extract \"same_value\" column from each DataFrame\u001b[39;00m\n\u001b[1;32m 2\u001b[0m sv_columns \u001b[38;5;241m=\u001b[39m [df[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msame_value\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m df \u001b[38;5;129;01min\u001b[39;00m result_list] \u001b[38;5;66;03m# 'same' needs to be changed\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m sv_columns\u001b[38;5;241m.\u001b[39minsert(\u001b[38;5;241m0\u001b[39m, \u001b[43mresult_list\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msuperclass\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Combine into a new DataFrame (columns side-by-side)\u001b[39;00m\n\u001b[1;32m 6\u001b[0m sv_df \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mconcat(sv_columns, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n", "\u001b[0;31mIndexError\u001b[0m: list index out of range" ] } ], "source": [ "# Extract \"same_value\" column from each DataFrame\n", "sv_columns = [df['same_value'] for df in result_list] # 'same' needs to be changed\n", "sv_columns.insert(0, result_list[0]['superclass'])\n", "\n", "# Combine into a new DataFrame (columns side-by-side)\n", "sv_df = pd.concat(sv_columns, axis=1)\n", "\n", "print(sv_df)" ] }, { "cell_type": "code", "execution_count": 62, "id": "bf7546cb-79ce-49ad-8cee-54d02239220c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " superclass accuracy\n", "0 activity_type 0.8\n", "1 activities 0.0\n", "2 climate_or_season 0.5\n", "3 style_or_comfort 0.3\n", "4 dress_code 0.8\n", "5 accommodation 0.8\n", "6 transportation 0.7\n", "7 special_conditions 0.2\n", "8 trip_length_days 0.6\n" ] } ], "source": [ "# Compute accuracy per superclass (row means of same_value matrix excluding the first column)\n", "row_means = sv_df.iloc[:, 1:].mean(axis=1)\n", "\n", "df_row_means = pd.DataFrame({\n", " 'superclass': sv_df['superclass'],\n", " 'accuracy': row_means\n", "})\n", "\n", "print(df_row_means)" ] }, { "cell_type": "code", "execution_count": null, "id": "fd232953-59e8-4f28-9ce8-11515a2c310b", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Compute performance measures per trip (mean for each column of performance table)\n", "column_means = performance.mean()\n", "print(column_means)\n", "\n", "# Plot histograms for all numeric columns\n", "performance.hist(bins=10, figsize=(10, 6))\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "bd682c84-3eb1-4a8d-9621-b741e98e4537", "metadata": {}, "outputs": [], "source": [ "# save results\n", "# Structure to save\n", "model_result = {\n", " 'model': model_name,\n", " 'predictions': result_list,\n", " 'performance': performance,\n", " 'perf_summary': column_means,\n", " 'perf_superclass': df_row_means,\n", " 'elapsed_time': elapsed_time\n", "}\n", "\n", "# File path with folder\n", "filename = os.path.join('results', f'{model_name}_results.pkl')\n", "\n", "# Save the object\n", "with open(filename, 'wb') as f:\n", " pickle.dump(model_result, f)" ] }, { "cell_type": "code", "execution_count": null, "id": "f38d0924-30b6-43cd-9bfc-fe5b0dc80411", "metadata": {}, "outputs": [], "source": [ "print(elapsed_time/60)" ] }, { "cell_type": "markdown", "id": "e1cbb54e-abe6-49b6-957e-0683196f3199", "metadata": {}, "source": [ "**Load and compare results**" ] }, { "cell_type": "code", "execution_count": 54, "id": "62ca82b0-6909-4e6c-9d2c-fed87971e5b6", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: model_MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\n", "Performance Summary:\n", "accuracy 0.522222\n", "true_ident 0.841667\n", "false_pred 0.572381\n", "dtype: float64\n", "----------------------------------------\n", "Model: model_a_facebook-bart-large-mnli\n", "Performance Summary:\n", "accuracy 0.454545\n", "true_ident 0.689394\n", "false_pred 0.409091\n", "dtype: float64\n", "----------------------------------------\n", "Model: model_b_sileod-deberta-v3-base-tasksource-nli\n", "Performance Summary:\n", "accuracy 0.500000\n", "true_ident 0.666667\n", "false_pred 0.551667\n", "dtype: float64\n", "----------------------------------------\n", "Model: model_MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\n", "Performance Summary:\n", " superclass accuracy\n", "0 activity_type 0.8\n", "1 activities 0.0\n", "2 climate_or_season 0.5\n", "3 style_or_comfort 0.3\n", "4 dress_code 0.8\n", "5 accommodation 0.8\n", "6 transportation 0.7\n", "7 special_conditions 0.2\n", "8 trip_length_days 0.6\n", "----------------------------------------\n", "Model: model_a_facebook-bart-large-mnli\n", "Performance Summary:\n", " superclass accuracy\n", "0 activity_type 0.8\n", "1 activities 0.0\n", "2 climate_or_season 0.6\n", "3 style_or_comfort 0.4\n", "4 dress_code 0.7\n", "5 accommodation 0.3\n", "6 transportation 0.8\n", "7 special_conditions 0.0\n", "8 trip_length_days 0.5\n", "----------------------------------------\n", "Model: model_b_sileod-deberta-v3-base-tasksource-nli\n", "Performance Summary:\n", " superclass accuracy\n", "0 activity_type 0.7\n", "1 activities 0.1\n", "2 climate_or_season 0.6\n", "3 style_or_comfort 0.4\n", "4 dress_code 0.6\n", "5 accommodation 0.9\n", "6 transportation 0.7\n", "7 special_conditions 0.1\n", "8 trip_length_days 0.5\n", "----------------------------------------\n" ] } ], "source": [ "# Folder where your .pkl files are saved\n", "results_dir = 'results'\n", "\n", "# Dictionary to store all loaded results\n", "all_results = {}\n", "\n", "# Loop through all .pkl files in the folder\n", "for filename in os.listdir(results_dir):\n", " if filename.endswith('.pkl'):\n", " model_name = filename.replace('_results.pkl', '') # Extract model name\n", " file_path = os.path.join(results_dir, filename)\n", " \n", " # Load the result\n", " with open(file_path, 'rb') as f:\n", " result = pickle.load(f)\n", " all_results[model_name] = result\n", "\n", "# Compare performance across models\n", "for model, data in all_results.items():\n", " print(f\"Model: {model}\")\n", " print(f\"Performance Summary:\\n{data['perf_summary']}\")\n", " print(\"-\" * 40)\n", "\n", "\n", "# Compare performance across models\n", "for model, data in all_results.items():\n", " print(f\"Model: {model}\")\n", " print(f\"Performance Summary:\\n{data['perf_superclass']}\")\n", " print(\"-\" * 40)" ] }, { "cell_type": "code", "execution_count": 69, "id": "57fd150d-1cda-4be5-806b-ef380469243a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: model_MoritzLaurer-DeBERTa-v3-base-mnli-fever-anli\n", "Time in minutes for 10 trips:\n", "83.45150986512502\n", "----------------------------------------\n", "Model: model_a_facebook-bart-large-mnli\n" ] }, { "ename": "KeyError", "evalue": "'elapsed_time'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[69], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m model, data \u001b[38;5;129;01min\u001b[39;00m all_results\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTime in minutes for 10 trips:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mdata[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124melapsed_time\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m60\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m-\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m40\u001b[39m)\n", "\u001b[0;31mKeyError\u001b[0m: 'elapsed_time'" ] } ], "source": [ "# Compare across models\n", "for model, data in all_results.items():\n", " print(f\"Model: {model}\")\n", " print(f\"Time in minutes for 10 trips:\\n{data['elapsed_time']/60}\")\n", " print(\"-\" * 40)" ] }, { "cell_type": "markdown", "id": "17483df4-55c4-41cd-b8a9-61f7a5c7e8a3", "metadata": {}, "source": [ "**Use gradio for user input**" ] }, { "cell_type": "code", "execution_count": 66, "id": "cb7fd425-d0d6-458d-97ca-2150dc55f206", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7861\n", "Running on public URL: https://aa06d5d85ffadaa92b.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "