Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import gradio as gr | |
| from transformers import CLIPProcessor, CLIPModel | |
| import torch | |
| import itertools | |
| import os | |
| import plotly.graph_objects as go | |
| import hashlib | |
| from PIL import Image | |
| import json | |
| import random | |
# NOTE(review): PYTHONHASHSEED is read only when the interpreter starts, so
# assigning it here does not change str hashing (and therefore set ordering)
# for this already-running process; it only propagates to child processes.
os.environ["PYTHONHASHSEED"] = "42"

# Choose the torch device once and report which CUDA hardware is visible.
CUDA_AVAILABLE = torch.cuda.is_available()
print(f"CUDA={CUDA_AVAILABLE}")
if CUDA_AVAILABLE:
    device = "cuda"
    print(f"count={torch.cuda.device_count()}")
    _active_gpu = torch.cuda.current_device()
    print(f"current={torch.cuda.get_device_name(_active_gpu)}")
else:
    device = "cpu"
# All three checkpoints are fetched with the same access token, taken from the
# "token" environment variable (presumably required because the repos are
# private or gated — confirm).
_hf_token = os.getenv("token")
continent_model = CLIPModel.from_pretrained(
    "jrheiner/thesis-clip-geoloc-continent", token=_hf_token
)
country_model = CLIPModel.from_pretrained(
    "jrheiner/thesis-clip-geoloc-country", token=_hf_token
)
# A single processor, loaded from the continent repo, is used for both models.
processor = CLIPProcessor.from_pretrained(
    "jrheiner/thesis-clip-geoloc-continent", token=_hf_token
)
# Move both models onto the device chosen at start-up (continent first).
continent_model, country_model = (
    clip_model.to(device) for clip_model in (continent_model, country_model)
)
# Candidate countries for each continent. The country model only ranks the
# countries of the continent predicted first, so a few transcontinental places
# are listed twice: Russia, Turkey and Cyprus (Asia and Europe) and Greenland
# (Europe and North America).
countries_per_continent = {
    "Africa": [
        "Botswana", "Eswatini", "Ghana", "Kenya", "Lesotho",
        "Nigeria", "Senegal", "South Africa", "Rwanda", "Uganda",
        "Tanzania", "Madagascar", "Djibouti", "Mali", "Libya",
        "Morocco", "Somalia", "Tunisia", "Egypt", "Réunion",
    ],
    "Asia": [
        "Bangladesh", "Bhutan", "Cambodia", "China", "India",
        "Indonesia", "Israel", "Japan", "Jordan", "Kyrgyzstan",
        "Laos", "Malaysia", "Mongolia", "Nepal", "Palestine",
        "Philippines", "Singapore", "South Korea", "Sri Lanka", "Taiwan",
        "Thailand", "United Arab Emirates", "Vietnam", "Afghanistan", "Azerbaijan",
        "Cyprus", "Iran", "Syria", "Tajikistan", "Turkey",
        "Russia", "Pakistan", "Hong Kong",
    ],
    "Europe": [
        "Albania", "Andorra", "Austria", "Belgium", "Bulgaria",
        "Croatia", "Czechia", "Denmark", "Estonia", "Finland",
        "France", "Germany", "Greece", "Hungary", "Iceland",
        "Ireland", "Italy", "Latvia", "Lithuania", "Luxembourg",
        "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland",
        "Portugal", "Romania", "Russia", "Serbia", "Slovakia",
        "Slovenia", "Spain", "Sweden", "Switzerland", "Ukraine",
        "United Kingdom", "Bosnia and Herzegovina", "Cyprus", "Turkey", "Greenland",
        "Faroe Islands",
    ],
    "North America": [
        "Canada", "Dominican Republic", "Guatemala", "Mexico", "United States",
        "Bahamas", "Cuba", "Panama", "Puerto Rico", "Bermuda",
        "Greenland",
    ],
    "Oceania": [
        "Australia", "New Zealand", "Fiji", "Papua New Guinea", "Solomon Islands",
        "Vanuatu",
    ],
    "South America": [
        "Argentina", "Bolivia", "Brazil", "Chile", "Colombia",
        "Ecuador", "Paraguay", "Peru", "Uruguay",
    ],
}

# Continent labels, in the dict's insertion order:
# Africa, Asia, Europe, North America, Oceania, South America.
continents = list(countries_per_continent)

# Every distinct country across all continents. Sorted because bare set
# iteration order depends on per-process string hash randomization, which
# would reshuffle the country dropdown on every restart (setting
# PYTHONHASHSEED at runtime does not prevent that).
countries = sorted(set(itertools.chain.from_iterable(countries_per_continent.values())))
# (country, latitude, longitude) marker positions used when drawing a country
# guess on the versus map. Row order is preserved in the resulting dict.
_COUNTRY_CENTER_ROWS = (
    ("Indonesia", -2.4833826, 117.8902853),
    ("Egypt", 26.2540493, 29.2675469),
    ("Dominican Republic", 19.0974031, -70.3028026),
    ("Russia", 64.6863136, 97.7453061),
    ("Denmark", 55.670249, 10.3333283),
    ("Latvia", 56.8406494, 24.7537645),
    ("Hong Kong", 22.350627, 114.1849161),
    ("Brazil", -10.3333333, -53.2),
    ("Turkey", 38.9597594, 34.9249653),
    ("Paraguay", -23.3165935, -58.1693445),
    ("Nigeria", 9.6000359, 7.9999721),
    ("United Kingdom", 54.7023545, -3.2765753),
    ("Argentina", -34.9964963, -64.9672817),
    ("United Arab Emirates", 24.0002488, 53.9994829),
    ("Estonia", 58.7523778, 25.3319078),
    ("Greenland", 69.6354163, -42.1736914),
    ("Canada", 61.0666922, -107.991707),
    ("Andorra", 42.5407167, 1.5732033),
    ("Czechia", 49.7439047, 15.3381061),
    ("Australia", -24.7761086, 134.755),
    ("Azerbaijan", 40.3936294, 47.7872508),
    ("Cambodia", 12.5433216, 104.8144914),
    ("Peru", -6.8699697, -75.0458515),
    ("Slovakia", 48.7411522, 19.4528646),
    ("Réunion", -21.130737949999997, 55.536480112992315),
    ("France", 46.603354, 1.8883335),
    ("Israel", 30.8124247, 34.8594762),
    ("China", 35.000074, 104.999927),
    ("Ecuador", -1.3397668, -79.3666965),
    ("Poland", 52.215933, 19.134422),
    ("Switzerland", 46.7985624, 8.2319736),
    ("Singapore", 1.357107, 103.8194992),
    ("Kenya", 1.4419683, 38.4313975),
    ("Bhutan", 27.549511, 90.5119273),
    ("Laos", 20.0171109, 103.378253),
    ("Vietnam", 15.9266657, 107.9650855),
    ("Puerto Rico", 18.2247706, -66.4858295),
    ("Germany", 51.1638175, 10.4478313),
    ("Tanzania", -6.5247123, 35.7878438),
    ("Colombia", 4.099917, -72.9088133),
    ("Italy", 42.6384261, 12.674297),
    ("Bahamas", 24.7736546, -78.0000547),
    ("Panama", 8.559559, -81.1308434),
    ("Bulgaria", 42.6073975, 25.4856617),
    ("Solomon Islands", -8.7053941, 159.1070693851845),
    ("Afghanistan", 33.7680065, 66.2385139),
    ("Tajikistan", 38.6281733, 70.8156541),
    ("Portugal", 39.6621648, -8.1353519),
    ("Tunisia", 36.8002068, 10.1857757),
    ("Bolivia", -17.0568696, -64.9912286),
    ("Malaysia", 4.5693754, 102.2656823),
    ("Lithuania", 55.3500003, 23.7499997),
    ("Sweden", 59.6749712, 14.5208584),
    ("Belgium", 50.6402809, 4.6667145),
    ("Libya", 26.8234472, 18.1236723),
    ("Guatemala", 15.5855545, -90.345759),
    ("India", 22.3511148, 78.6677428),
    ("Sri Lanka", 7.5554942, 80.7137847),
    ("New Zealand", -41.5000831, 172.8344077),
    ("Iceland", 64.9841821, -18.1059013),
    ("Somalia", 8.3676771, 49.083416),
    ("Croatia", 45.3658443, 15.6575209),
    ("Bosnia and Herzegovina", 44.3053476, 17.5961467),
    ("Greece", 38.9953683, 21.9877132),
    ("Rwanda", -1.9646631, 30.0644358),
    ("Hungary", 47.1817585, 19.5060937),
    ("Eswatini", -26.5624806, 31.3991317),
    ("Kyrgyzstan", 41.5089324, 74.724091),
    ("Bangladesh", 23.6943117, 90.344352),
    ("Morocco", 28.3347722, -10.371337908392647),
    ("Finland", 63.2467777, 25.9209164),
    ("Luxembourg", 49.6112768, 6.129799),
    ("North Macedonia", 41.6171214, 21.7168387),
    ("Uruguay", -32.8755548, -56.0201525),
    ("Chile", -31.7613365, -71.3187697),
    ("Spain", 39.3260685, -4.8379791),
    ("South Korea", 36.638392, 127.6961188),
    ("Botswana", -23.1681782, 24.5928742),
    ("Uganda", 1.5333554, 32.2166578),
    ("Papua New Guinea", -5.6816069, 144.2489081),
    ("Mali", 16.3700359, -2.2900239),
    ("Philippines", 12.7503486, 122.7312101),
    ("Norway", 64.5731537, 11.52803643954819),
    ("Thailand", 14.8971921, 100.83273),
    ("Mongolia", 46.8651082, 103.8347844),
    ("Japan", 36.5748441, 139.2394179),
    ("Montenegro", 42.7044223, 19.3957785),
    ("Austria", 47.59397, 14.12456),
    ("Taiwan", 23.6978, 120.9605),
    ("Netherlands", 52.2434979, 5.6343227),
    ("Ukraine", 49.4871968, 31.2718321),
    ("Fiji", -18.1239696, 179.0122737),
    ("Ghana", 8.0300284, -1.0800271),
    ("Cuba", 23.0131338, -80.8328748),
    ("Nepal", 28.3780464, 83.9999901),
    ("Faroe Islands", 62.0448724, -7.0322972),
    ("Slovenia", 46.1199444, 14.8153333),
    ("Cyprus", 34.9174159, 32.889902651331866),
    ("Serbia", 44.024322850000004, 21.07657433209902),
    ("Madagascar", -18.9249604, 46.4416422),
    ("Pakistan", 30.3308401, 71.247499),
    ("Syria", 34.6401861, 39.0494106),
    ("Iran", 32.6475314, 54.5643516),
    ("Ireland", 52.865196, -7.9794599),
    ("South Africa", -28.8166236, 24.991639),
    ("Albania", 41.1529058, 20.1605717),
    ("Lesotho", -29.6039267, 28.3350193),
    ("Romania", 45.9852129, 24.6859225),
    ("Palestine", 31.947351, 35.227163),
    ("Vanuatu", -16.5255069, 168.1069154),
    ("Mexico", 19.4326296, -99.1331785),
    ("Jordan", 31.279862, 37.1297454),
    ("Djibouti", 11.8145966, 42.8453061),
    ("Senegal", 14.4750607, -14.4529612),
    ("Bermuda", 32.3040273, -64.7563086),
    ("United States", 39.7837304, -100.445882),
)

# country name -> (lat, lon)
country_to_center_coords = {
    country_name: (lat, lon) for country_name, lat, lon in _COUNTRY_CENTER_ROWS
}
def _rank_labels(model, labels, image):
    """Score ``image`` against one "A photo from {label}." prompt per label.

    Returns a ``{label: probability}`` dict (softmax over the prompts, in the
    order of ``labels``) and the highest-scoring label.
    """
    inputs = processor(
        text=[f"A photo from {label}." for label in labels],
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Exactly one image per call, so row 0 is the distribution over the prompts.
    probs = outputs.logits_per_image.softmax(dim=-1)[0]
    label_probs = dict(zip(labels, probs.tolist()))
    best_label = labels[int(probs.argmax().item())]
    return label_probs, best_label


def predict(input_img):
    """Geolocate an image with the two-stage continent -> country CLIP pipeline.

    The continent model ranks ``continents``. The country model then ranks only
    the countries listed for the winning continent in ``countries_per_continent``.
    If the image's pixel SHA-1 matches a bundled example in ``EXAMPLE_METADATA``,
    the ground truth and a comparison map are revealed as well.

    Args:
        input_img: The image to classify. The demo tab passes a PIL image. The
            versus tab's ``gr.Image`` sets no ``type``, so it presumably passes a
            NumPy array; either works with ``np.asarray``.

    Returns:
        ``(continent_probs, country_probs, metadata_block, metadata_map)``:
        label->probability dicts for continents and for the predicted continent's
        countries, a ``gr.Accordion`` update (visible only for known examples),
        and a plotly map figure or ``None``.

    Raises:
        gr.Error: If no image was supplied (e.g. "Predict" clicked on an empty
            input), so the user sees a readable message instead of a crash.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")

    continent_probs, model_continent = _rank_labels(continent_model, continents, input_img)
    country_probs, model_country = _rank_labels(
        country_model, countries_per_continent[model_continent], input_img
    )

    # Same pixel-byte digest as used to build EXAMPLE_METADATA.
    digest = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
    truth = EXAMPLE_METADATA.get(digest)
    if truth is None:
        return continent_probs, country_probs, gr.Accordion(visible=False), None

    continent_ok = model_continent == truth["continent"]
    country_ok = model_country == truth["country"]
    if country_ok and continent_ok:
        model_result = "The AI 🤖 correctly guessed continent and country ✅ ✅."
    elif continent_ok:
        model_result = "The AI 🤖 only guessed the correct continent ❌ ✅."
    elif country_ok:
        # Possible for countries listed under two continents (e.g. Russia).
        model_result = "The AI 🤖 only guessed the correct country ✅ ❌."
    else:
        model_result = "The AI 🤖 failed to guess country and continent ❌ ❌."

    metadata_block = gr.Accordion(
        visible=True,
        label=f"This photo was taken in {truth['country']}, {truth['continent']}.\n{model_result}",
    )
    metadata_map = make_versus_map(None, model_country, truth)
    return continent_probs, country_probs, metadata_block, metadata_map
def make_versus_map(human_country, model_country, versus_state):
    """Build a Mapbox figure comparing the photo location with the guesses.

    Args:
        human_country: The player's country guess, or a falsy value when there
            is no human guess (``predict`` passes ``None``).
        model_country: The model's predicted country; must be a key of
            ``country_to_center_coords``.
        versus_state: Mapping with the photo's ``"lat"``, ``"lon"``,
            ``"country"`` and ``"continent"``. ``lat``/``lon`` may be numeric
            strings, since they are parsed from image file names.

    Returns:
        A ``plotly.graph_objects.Figure`` centred on the photo, with one marker
        for the photo and one for each distinct guess.

    Raises:
        KeyError: If a guessed country has no entry in ``country_to_center_coords``.
    """
    # Convert once so the marker and the map centre use the same numeric values.
    photo_lat = float(versus_state["lat"])
    photo_lon = float(versus_state["lon"])
    model_lat, model_lon = country_to_center_coords[model_country]

    def _marker(lat, lon, text, color, name):
        # A single hoverable point with its own legend entry.
        return go.Scattermapbox(
            lat=[lat],
            lon=[lon],
            text=[text],
            mode="markers",
            hoverinfo="text",
            marker=dict(size=14, color=color),
            showlegend=True,
            name=name,
        )

    fig = go.Figure()
    fig.add_trace(
        _marker(
            photo_lat,
            photo_lon,
            f"📷 Photo taken in {versus_state['country']}, {versus_state['continent']}",
            "#0C5DA5",
            "📷 Photo Location",
        )
    )
    if human_country == model_country:
        # Same guess: draw one shared marker rather than two stacked identical ones.
        fig.add_trace(
            _marker(
                model_lat,
                model_lon,
                f"🧑 🤖 Human & AI guess {human_country}",
                "#FF9500",
                "🧑 🤖 Human & AI Guess",
            )
        )
    else:
        if human_country:
            human_lat, human_lon = country_to_center_coords[human_country]
            fig.add_trace(
                _marker(
                    human_lat,
                    human_lon,
                    f"🧑 Human guesses {human_country}",
                    "#FF9500",
                    "🧑 Human Guess",
                )
            )
        fig.add_trace(
            _marker(
                model_lat,
                model_lon,
                f"🤖 AI guesses {model_country}",
                "#474747",
                "🤖 AI Guess",
            )
        )
    fig.update_layout(
        mapbox=dict(
            style="carto-positron",
            center=dict(lat=photo_lat, lon=photo_lon),
            zoom=2,
        ),
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
    )
    return fig
def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
    """Score one versus-mode round for both the player and the model.

    A correct country is worth 2 points and a correct continent 1 point. Points
    are added to ``versus_state["score"]`` in place; the player's score is
    updated before the model is run.

    Args:
        input_img: The image currently shown in the versus tab.
        human_continent: The player's continent choice (may be ``None``).
        human_country: The player's country choice (may be ``None``).
        versus_state: Round state holding the ground truth (``country``,
            ``continent``, ``lat``, ``lon``) and the running ``score`` dict.

    Returns:
        ``(markdown_report, continent_probs, country_probs, figure, versus_state)``.
    """
    truth_country = versus_state["country"]
    truth_continent = versus_state["continent"]
    scoreboard = versus_state["score"]

    def grade(country_guess, continent_guess):
        # -> (points, country mark, continent mark)
        country_hit = country_guess == truth_country
        continent_hit = continent_guess == truth_continent
        return (
            2 * country_hit + continent_hit,
            "✅" if country_hit else "❌",
            "✅" if continent_hit else "❌",
        )

    human_points, human_country_mark, human_continent_mark = grade(
        human_country, human_continent
    )
    human_result = (
        f"The photo is from **{truth_country}** {human_country_mark} "
        f"in **{truth_continent}** {human_continent_mark}"
    )
    if human_points > 0:
        human_score_update = f"+{human_points} points"
    else:
        human_score_update = "0 Points..."
    scoreboard["HUMAN"] += human_points

    continent_probs, country_probs, _, _ = predict(input_img)
    model_country = max(country_probs, key=country_probs.get)
    model_continent = max(continent_probs, key=continent_probs.get)
    model_points, model_country_mark, model_continent_mark = grade(
        model_country, model_continent
    )
    if model_points > 0:
        model_score_update = f"+{model_points} points"
    else:
        model_score_update = "0 Points... The model was completely wrong, it seems the world is not doomed yet."
    scoreboard["AI"] += model_points

    figure = make_versus_map(human_country, model_country, versus_state)
    report = "\n".join(
        [
            "",
            f"## {human_result}",
            f"### The AI 🤖 thinks this photo is from **{model_country}** {model_country_mark} in **{model_continent}** {model_continent_mark}",
            f"🧑 {human_score_update}",
            f"🤖 {model_score_update}",
            f"### Score 🧑 {scoreboard['HUMAN']} : {scoreboard['AI']} 🤖",
            "",
        ]
    )
    return report, continent_probs, country_probs, figure, versus_state
def get_example_images(dir):
    """Return the paths of all image files found under ``dir``, recursively.

    A file counts as an image when its name ends in ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive). The paths are sorted because ``os.walk`` yields
    entries in arbitrary, filesystem-dependent order. Sorting keeps the examples
    gallery and the versus-mode image pool stable across restarts.

    Args:
        dir: Root directory to search. A missing directory yields an empty list
            (``os.walk`` ignores errors by default).

    Returns:
        A sorted list of paths joined onto ``dir``.
    """
    image_extensions = (".jpg", ".jpeg", ".png")
    return sorted(
        os.path.join(root, file_name)
        for root, _subdirs, file_names in os.walk(dir)
        for file_name in file_names
        if file_name.lower().endswith(image_extensions)
    )
def next_versus_image(versus_state):
    """Draw a random versus-mode image and record its ground truth in the state.

    Image file names are expected to look like
    ``<continent>_<country>_<lat>_<lon>_...``; lat/lon are kept as strings.

    Returns:
        ``(image_path, versus_state, None, None)``. The two ``None`` values clear
        the continent and country selectors in the UI.
    """
    (picked,) = random.sample(versus_state["images"], 1)
    name_parts = picked.rsplit("/", 1)[-1].split("_")
    for key, position in (("continent", 0), ("country", 1), ("lat", 2), ("lon", 3)):
        versus_state[key] = name_parts[position]
    versus_state["image"] = picked
    return picked, versus_state, None, None
example_images = get_example_images("kerger-test-images")

# Ground truth for the bundled example images, keyed by the SHA-1 of their raw
# pixel bytes; predict() hashes its input the same way to recognise an example.
# File names are expected to look like "<continent>_<country>_<lat>_<lon>_...".
EXAMPLE_METADATA = {}
for _example_path in example_images:
    # Use the image as a context manager so its file handle is closed once the
    # pixels have been read (np.asarray forces the lazy load inside the block).
    with Image.open(_example_path) as _example_img:
        _example_digest = hashlib.sha1(
            np.asarray(_example_img).data.tobytes()
        ).hexdigest()
    _example_fields = _example_path.split("/")[-1].split("_")
    EXAMPLE_METADATA[_example_digest] = {
        "continent": _example_fields[0],
        "country": _example_fields[1],
        "lat": _example_fields[2],
        "lon": _example_fields[3],
    }
def set_up_intial_state():
    """Create the initial versus-mode state.

    The first round always shows the same Germany photo, whose ground truth is
    parsed from its file name (``<continent>_<country>_<lat>_<lon>_...``).
    The score starts at 0:0, and ``images`` lists every image under
    ``versus_images`` for later random draws.
    """
    first_image = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
    name_fields = first_image.split("/")[-1].split("_")
    state = {"image": first_image}
    # zip stops after the four named fields; the trailing dataset tag is ignored.
    state.update(zip(("continent", "country", "lat", "lon"), name_fields))
    state["score"] = {"HUMAN": 0, "AI": 0}
    state["images"] = get_example_images("versus_images")
    return state
demo = gr.Blocks(title="Thesis Demo")
with demo:
    gr.HTML(
        """
        <h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
        <h3> This Demo showcases the developed models and allows interacting with the optimized prototype.</h3>
        <p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example images the ground truth is available and will be displayed together with the model predictions.</p>
        <p>In the <b>"Versus Mode"</b> tab you can play against the AI, guessing the country and continent where images were taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/" target="_blank" rel="noopener noreferrer"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838" target="_blank" rel="noopener noreferrer"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?</p>
        <div style="font-style: italic; font-size: smaller;">Note that inference in this publicly hosted version is very slow due to the limited and shared hardware. This demo runs on the <a href="https://huggingface.co/pricing#spaces" style="color: inherit;" target="_blank" rel="noopener noreferrer">Hugging Face free tier</a> without GPU acceleration. Running the demo with a GPU allows for inference times between 0.5-2 seconds per image.</div>
        """
    )
    # Counts are derived from the data so the label cannot drift from the lists.
    with gr.Accordion(
        label=f"The demo currently encompasses {len(countries)} countries from {len(continents)} continents 🌍",
        open=False,
    ):
        gr.Code(
            json.dumps(countries_per_continent, indent=2, ensure_ascii=False),
            label="countries_per_continent.json",
            language="json",
            interactive=False,
        )

    with gr.Tab("Image Geolocation Demo"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(
                    label="Image",
                    type="pil",
                    sources=["upload", "clipboard"],
                    show_fullscreen_button=True,
                )
                predict_btn = gr.Button("Predict")
                # Reuse the list already scanned for EXAMPLE_METADATA so the gallery
                # and the ground-truth lookup come from the same files.
                gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
            with gr.Column():
                # Hidden until predict() recognises one of the bundled examples.
                with gr.Accordion(visible=False) as metadata_block:
                    demo_map = gr.Plot(label="Locations")
                with gr.Group():
                    continents_label = gr.Label(label="Continents")
                    country_label = gr.Label(num_top_classes=5, label="Top countries")
        predict_btn.click(
            predict,
            inputs=image,
            outputs=[continents_label, country_label, metadata_block, demo_map],
        )

    with gr.Tab("Versus Mode"):
        versus_state = gr.State(value=set_up_intial_state())
        with gr.Row():
            with gr.Column():
                versus_image = gr.Image(
                    versus_state.value["image"],
                    interactive=False,
                    show_download_button=False,
                    show_share_button=False,
                    show_fullscreen_button=True,
                )
                continent_selection = gr.Radio(
                    continents,
                    label="Continents",
                    info="Where was this image taken? (1 Point)",
                )
                # A plain component; the original wrapped it in a 1-tuple by accident.
                country_selection = gr.Dropdown(
                    countries,
                    label="Countries",
                    info="Can you guess the exact country? (2 Points)",
                )
                with gr.Row():
                    next_img_btn = gr.Button("Try new image")
                    versus_btn = gr.Button("Submit guess")
            with gr.Column():
                versus_output = gr.Markdown()
                versus_map = gr.Plot(label="Locations")
                with gr.Accordion("Full Model Output", open=False):
                    with gr.Group():
                        versus_continents_label = gr.Label(label="Continents")
                        versus_country_label = gr.Label(
                            num_top_classes=5, label="Top countries"
                        )
        next_img_btn.click(
            next_versus_image,
            inputs=[versus_state],
            outputs=[versus_image, versus_state, continent_selection, country_selection],
        )
        versus_btn.click(
            versus_mode_inputs,
            inputs=[versus_image, continent_selection, country_selection, versus_state],
            outputs=[
                versus_output,
                versus_continents_label,
                versus_country_label,
                versus_map,
                versus_state,
            ],
        )
def _launch_demo():
    """Serve the Gradio app with its API usage page hidden (``show_api=False``)."""
    demo.launch(show_api=False)


if __name__ == "__main__":
    _launch_demo()