Spaces:
Runtime error
Runtime error
Jonas Rheiner
commited on
Commit
Β·
e20beac
1
Parent(s):
8b18a0c
Reformat
Browse files
app.py
CHANGED
|
@@ -18,52 +18,158 @@ device = "cuda" if CUDA_AVAILABLE else "cpu"
|
|
| 18 |
print(f"count={torch.cuda.device_count()}")
|
| 19 |
print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
|
| 20 |
|
| 21 |
-
continent_model = CLIPModel.from_pretrained(
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
continent_model = continent_model.to(device)
|
| 25 |
country_model = country_model.to(device)
|
| 26 |
|
| 27 |
|
| 28 |
-
continents = ["Africa", "Asia", "Europe",
|
| 29 |
-
"North America", "Oceania", "South America"]
|
| 30 |
countries_per_continent = {
|
| 31 |
"Africa": [
|
| 32 |
-
"Botswana",
|
| 33 |
-
"
|
| 34 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
],
|
| 36 |
"Asia": [
|
| 37 |
-
"Bangladesh",
|
| 38 |
-
"
|
| 39 |
-
"
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
],
|
| 44 |
"Europe": [
|
| 45 |
-
"Albania",
|
| 46 |
-
"
|
| 47 |
-
"
|
| 48 |
-
"
|
| 49 |
-
"
|
| 50 |
-
"
|
| 51 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
],
|
| 53 |
"North America": [
|
| 54 |
-
"Canada",
|
| 55 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
],
|
| 57 |
"Oceania": [
|
| 58 |
-
"Australia",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
],
|
| 60 |
"South America": [
|
| 61 |
-
"Argentina",
|
| 62 |
-
"
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
}
|
| 65 |
-
countries = list(set(itertools.chain.from_iterable(
|
| 66 |
-
countries_per_continent.values())))
|
| 67 |
|
| 68 |
country_to_center_coords = {
|
| 69 |
"Indonesia": (-2.4833826, 117.8902853),
|
|
@@ -181,7 +287,7 @@ country_to_center_coords = {
|
|
| 181 |
"Djibouti": (11.8145966, 42.8453061),
|
| 182 |
"Senegal": (14.4750607, -14.4529612),
|
| 183 |
"Bermuda": (32.3040273, -64.7563086),
|
| 184 |
-
"United States": (39.7837304, -100.445882)
|
| 185 |
}
|
| 186 |
|
| 187 |
INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
|
|
@@ -191,29 +297,35 @@ INITAL_VERSUS_STATE = {
|
|
| 191 |
"country": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[1],
|
| 192 |
"lat": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[2],
|
| 193 |
"lon": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[3],
|
| 194 |
-
"score": {
|
| 195 |
-
|
| 196 |
-
"AI": 0
|
| 197 |
-
},
|
| 198 |
-
"idx": 0
|
| 199 |
}
|
| 200 |
|
| 201 |
|
| 202 |
def predict(input_img):
|
| 203 |
-
inputs = processor(
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
inputs = inputs.to(device)
|
| 206 |
with torch.no_grad():
|
| 207 |
outputs = continent_model(**inputs)
|
| 208 |
logits_per_image = outputs.logits_per_image
|
| 209 |
probs = logits_per_image.softmax(dim=-1)
|
| 210 |
pred_id = probs.argmax().cpu().item()
|
| 211 |
-
continent_probs = {
|
| 212 |
-
|
|
|
|
| 213 |
model_continent = continents[pred_id]
|
| 214 |
predicted_continent_countries = countries_per_continent[model_continent]
|
| 215 |
-
inputs = processor(
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
inputs = inputs.to(device)
|
| 218 |
with torch.no_grad():
|
| 219 |
outputs = country_model(**inputs)
|
|
@@ -221,26 +333,37 @@ def predict(input_img):
|
|
| 221 |
probs = logits_per_image.softmax(dim=-1)
|
| 222 |
pred_id = probs.argmax().cpu().item()
|
| 223 |
model_country = predicted_continent_countries[pred_id]
|
| 224 |
-
country_probs = {
|
| 225 |
-
predicted_continent_countries, probs.tolist()[0])
|
| 226 |
-
|
|
|
|
| 227 |
hash = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
|
| 228 |
metadata_block = gr.Accordion(visible=False)
|
| 229 |
metadata_map = None
|
| 230 |
if hash in EXAMPLE_METADATA.keys():
|
| 231 |
model_result = ""
|
| 232 |
-
if
|
|
|
|
|
|
|
|
|
|
| 233 |
model_result = "The AI π€ correctly guessed continent and country β
β
."
|
| 234 |
-
elif model_continent == EXAMPLE_METADATA[hash][
|
| 235 |
model_result = "The AI π€ only guessed the correct continent β β
."
|
| 236 |
-
elif
|
|
|
|
|
|
|
|
|
|
| 237 |
model_result = "The AI π€ only guessed the correct country β
β."
|
| 238 |
else:
|
| 239 |
model_result = "The AI π€ failed to guess country and continent β β."
|
| 240 |
-
metadata_block = gr.Accordion(
|
|
|
|
|
|
|
|
|
|
| 241 |
metadata_map = make_versus_map(None, model_country, EXAMPLE_METADATA[hash])
|
| 242 |
return continent_probs, country_probs, metadata_block, metadata_map
|
| 243 |
|
|
|
|
| 244 |
def make_versus_map(human_country, model_country, versus_state):
|
| 245 |
if human_country:
|
| 246 |
human_coordinates = country_to_center_coords[human_country]
|
|
@@ -248,64 +371,66 @@ def make_versus_map(human_country, model_country, versus_state):
|
|
| 248 |
human_coordinates = (None, None)
|
| 249 |
model_coordinates = country_to_center_coords[model_country]
|
| 250 |
fig = go.Figure()
|
| 251 |
-
fig.add_trace(
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
versus_state['continent']}"],
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
showlegend=True,
|
| 260 |
-
name="π· Photo Location"
|
| 261 |
-
))
|
| 262 |
-
if human_country == model_country:
|
| 263 |
-
fig.add_trace(go.Scattermapbox(
|
| 264 |
-
lat=[human_coordinates[0], model_coordinates[0]],
|
| 265 |
-
lon=[human_coordinates[1], model_coordinates[1]],
|
| 266 |
-
text=f"π§ π€ Human & AI guess {human_country}",
|
| 267 |
-
mode='markers',
|
| 268 |
-
hoverinfo='text',
|
| 269 |
-
marker=dict(size=14, color='#FF9500'),
|
| 270 |
showlegend=True,
|
| 271 |
-
name="
|
| 272 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
else:
|
| 274 |
if human_country:
|
| 275 |
-
fig.add_trace(
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
showlegend=True,
|
| 283 |
-
name="
|
| 284 |
-
)
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
lon=[model_coordinates[1]],
|
| 288 |
-
text=[f"π€ AI guesses {model_country}"],
|
| 289 |
-
mode='markers',
|
| 290 |
-
hoverinfo='text',
|
| 291 |
-
marker=dict(size=14, color='#474747'),
|
| 292 |
-
showlegend=True,
|
| 293 |
-
name="π€ AI Guess"
|
| 294 |
-
))
|
| 295 |
-
|
| 296 |
fig.update_layout(
|
| 297 |
mapbox=dict(
|
| 298 |
style="carto-positron",
|
| 299 |
center=dict(lat=float(versus_state["lat"]), lon=float(versus_state["lon"])),
|
| 300 |
-
zoom=2
|
| 301 |
),
|
| 302 |
margin={"r": 0, "t": 0, "l": 0, "b": 0},
|
| 303 |
-
legend=dict(
|
| 304 |
-
yanchor="bottom",
|
| 305 |
-
y=0.01,
|
| 306 |
-
xanchor="left",
|
| 307 |
-
x=0.01
|
| 308 |
-
)
|
| 309 |
)
|
| 310 |
return fig
|
| 311 |
|
|
@@ -323,12 +448,13 @@ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
|
|
| 323 |
human_points += 1
|
| 324 |
else:
|
| 325 |
continent_result = "β"
|
| 326 |
-
human_result = f"The photo is from **{versus_state['country']}** {
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
|
|
|
| 330 |
|
| 331 |
-
continent_probs, country_probs, _,_ = predict(input_img)
|
| 332 |
model_country = max(country_probs, key=country_probs.get)
|
| 333 |
model_continent = max(continent_probs, key=continent_probs.get)
|
| 334 |
if model_country == versus_state["country"]:
|
|
@@ -341,11 +467,16 @@ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
|
|
| 341 |
model_points += 1
|
| 342 |
else:
|
| 343 |
model_continent_result = "β"
|
| 344 |
-
model_score_update =
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
map = make_versus_map(human_country, model_country, versus_state)
|
| 348 |
-
return
|
|
|
|
| 349 |
## {human_result}
|
| 350 |
### The AI π€ thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
|
| 351 |
|
|
@@ -353,7 +484,12 @@ def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
|
|
| 353 |
π€ {model_score_update}
|
| 354 |
|
| 355 |
### Score π§ {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} π€
|
| 356 |
-
""",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
|
| 358 |
|
| 359 |
def get_example_images(dir):
|
|
@@ -393,45 +529,65 @@ for img_path in example_images:
|
|
| 393 |
|
| 394 |
demo = gr.Blocks(title="Thesis Demo")
|
| 395 |
with demo:
|
| 396 |
-
gr.HTML(
|
|
|
|
| 397 |
<h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
|
| 398 |
|
| 399 |
<h3> This Demo showcases the developed models and allows interacting with the optimized prototype.</h3>
|
| 400 |
<p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example image the ground truth is available and will be displayed together with the model predictions.</p>
|
| 401 |
<p>In the <b>"Versus Mode"</b> tab to play against the AI, guessing the country and continent where images where taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?
|
| 402 |
|
| 403 |
-
"""
|
| 404 |
-
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
with gr.Tab("Image Geolocation Demo"):
|
| 407 |
with gr.Row():
|
| 408 |
with gr.Column():
|
| 409 |
-
image = gr.Image(
|
| 410 |
-
|
|
|
|
| 411 |
predict_btn = gr.Button("Predict")
|
| 412 |
example_images = get_example_images("kerger-test-images")
|
| 413 |
# example_images.extend(get_example_images("versus_images"))
|
| 414 |
-
gr.Examples(examples=example_images,
|
| 415 |
-
inputs=image, examples_per_page=24)
|
| 416 |
with gr.Column():
|
| 417 |
with gr.Accordion(visible=False) as metadata_block:
|
| 418 |
map = gr.Plot(label="Locations")
|
| 419 |
with gr.Group():
|
| 420 |
continents_label = gr.Label(label="Continents")
|
| 421 |
-
country_label = gr.Label(
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
|
|
|
|
|
|
| 425 |
|
| 426 |
with gr.Tab("Versus Mode"):
|
| 427 |
versus_state = gr.State(value=INITAL_VERSUS_STATE)
|
| 428 |
with gr.Row():
|
| 429 |
with gr.Column():
|
| 430 |
-
versus_image = gr.Image(
|
| 431 |
-
INITAL_VERSUS_STATE["image"], interactive=False)
|
| 432 |
continent_selection = gr.Radio(
|
| 433 |
-
continents,
|
| 434 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
with gr.Row():
|
| 436 |
next_img_btn = gr.Button("Try new image")
|
| 437 |
versus_btn = gr.Button("Submit guess")
|
|
@@ -443,11 +599,28 @@ with demo:
|
|
| 443 |
with gr.Group():
|
| 444 |
continents_label = gr.Label(label="Continents")
|
| 445 |
country_label = gr.Label(
|
| 446 |
-
num_top_classes=5, label="Top countries"
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
|
| 453 |
if __name__ == "__main__":
|
|
|
|
| 18 |
print(f"count={torch.cuda.device_count()}")
|
| 19 |
print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
|
| 20 |
|
| 21 |
+
continent_model = CLIPModel.from_pretrained(
|
| 22 |
+
"jrheiner/thesis-clip-geoloc-continent",
|
| 23 |
+
token=os.getenv("token"),
|
| 24 |
+
)
|
| 25 |
+
country_model = CLIPModel.from_pretrained(
|
| 26 |
+
"jrheiner/thesis-clip-geoloc-country",
|
| 27 |
+
token=os.getenv("token"),
|
| 28 |
+
)
|
| 29 |
+
processor = CLIPProcessor.from_pretrained(
|
| 30 |
+
"jrheiner/thesis-clip-geoloc-continent",
|
| 31 |
+
token=os.getenv("token"),
|
| 32 |
+
)
|
| 33 |
continent_model = continent_model.to(device)
|
| 34 |
country_model = country_model.to(device)
|
| 35 |
|
| 36 |
|
| 37 |
+
continents = ["Africa", "Asia", "Europe", "North America", "Oceania", "South America"]
|
|
|
|
| 38 |
countries_per_continent = {
|
| 39 |
"Africa": [
|
| 40 |
+
"Botswana",
|
| 41 |
+
"Eswatini",
|
| 42 |
+
"Ghana",
|
| 43 |
+
"Kenya",
|
| 44 |
+
"Lesotho",
|
| 45 |
+
"Nigeria",
|
| 46 |
+
"Senegal",
|
| 47 |
+
"South Africa",
|
| 48 |
+
"Rwanda",
|
| 49 |
+
"Uganda",
|
| 50 |
+
"Tanzania",
|
| 51 |
+
"Madagascar",
|
| 52 |
+
"Djibouti",
|
| 53 |
+
"Mali",
|
| 54 |
+
"Libya",
|
| 55 |
+
"Morocco",
|
| 56 |
+
"Somalia",
|
| 57 |
+
"Tunisia",
|
| 58 |
+
"Egypt",
|
| 59 |
+
"RΓ©union",
|
| 60 |
],
|
| 61 |
"Asia": [
|
| 62 |
+
"Bangladesh",
|
| 63 |
+
"Bhutan",
|
| 64 |
+
"Cambodia",
|
| 65 |
+
"China",
|
| 66 |
+
"India",
|
| 67 |
+
"Indonesia",
|
| 68 |
+
"Israel",
|
| 69 |
+
"Japan",
|
| 70 |
+
"Jordan",
|
| 71 |
+
"Kyrgyzstan",
|
| 72 |
+
"Laos",
|
| 73 |
+
"Malaysia",
|
| 74 |
+
"Mongolia",
|
| 75 |
+
"Nepal",
|
| 76 |
+
"Palestine",
|
| 77 |
+
"Philippines",
|
| 78 |
+
"Singapore",
|
| 79 |
+
"South Korea",
|
| 80 |
+
"Sri Lanka",
|
| 81 |
+
"Taiwan",
|
| 82 |
+
"Thailand",
|
| 83 |
+
"United Arab Emirates",
|
| 84 |
+
"Vietnam",
|
| 85 |
+
"Afghanistan",
|
| 86 |
+
"Azerbaijan",
|
| 87 |
+
"Cyprus",
|
| 88 |
+
"Iran",
|
| 89 |
+
"Syria",
|
| 90 |
+
"Tajikistan",
|
| 91 |
+
"Turkey",
|
| 92 |
+
"Russia",
|
| 93 |
+
"Pakistan",
|
| 94 |
+
"Hong Kong",
|
| 95 |
],
|
| 96 |
"Europe": [
|
| 97 |
+
"Albania",
|
| 98 |
+
"Andorra",
|
| 99 |
+
"Austria",
|
| 100 |
+
"Belgium",
|
| 101 |
+
"Bulgaria",
|
| 102 |
+
"Croatia",
|
| 103 |
+
"Czechia",
|
| 104 |
+
"Denmark",
|
| 105 |
+
"Estonia",
|
| 106 |
+
"Finland",
|
| 107 |
+
"France",
|
| 108 |
+
"Germany",
|
| 109 |
+
"Greece",
|
| 110 |
+
"Hungary",
|
| 111 |
+
"Iceland",
|
| 112 |
+
"Ireland",
|
| 113 |
+
"Italy",
|
| 114 |
+
"Latvia",
|
| 115 |
+
"Lithuania",
|
| 116 |
+
"Luxembourg",
|
| 117 |
+
"Montenegro",
|
| 118 |
+
"Netherlands",
|
| 119 |
+
"North Macedonia",
|
| 120 |
+
"Norway",
|
| 121 |
+
"Poland",
|
| 122 |
+
"Portugal",
|
| 123 |
+
"Romania",
|
| 124 |
+
"Russia",
|
| 125 |
+
"Serbia",
|
| 126 |
+
"Slovakia",
|
| 127 |
+
"Slovenia",
|
| 128 |
+
"Spain",
|
| 129 |
+
"Sweden",
|
| 130 |
+
"Switzerland",
|
| 131 |
+
"Ukraine",
|
| 132 |
+
"United Kingdom",
|
| 133 |
+
"Bosnia and Herzegovina",
|
| 134 |
+
"Cyprus",
|
| 135 |
+
"Turkey",
|
| 136 |
+
"Greenland",
|
| 137 |
+
"Faroe Islands",
|
| 138 |
],
|
| 139 |
"North America": [
|
| 140 |
+
"Canada",
|
| 141 |
+
"Dominican Republic",
|
| 142 |
+
"Guatemala",
|
| 143 |
+
"Mexico",
|
| 144 |
+
"United States",
|
| 145 |
+
"Bahamas",
|
| 146 |
+
"Cuba",
|
| 147 |
+
"Panama",
|
| 148 |
+
"Puerto Rico",
|
| 149 |
+
"Bermuda",
|
| 150 |
+
"Greenland",
|
| 151 |
],
|
| 152 |
"Oceania": [
|
| 153 |
+
"Australia",
|
| 154 |
+
"New Zealand",
|
| 155 |
+
"Fiji",
|
| 156 |
+
"Papua New Guinea",
|
| 157 |
+
"Solomon Islands",
|
| 158 |
+
"Vanuatu",
|
| 159 |
],
|
| 160 |
"South America": [
|
| 161 |
+
"Argentina",
|
| 162 |
+
"Bolivia",
|
| 163 |
+
"Brazil",
|
| 164 |
+
"Chile",
|
| 165 |
+
"Colombia",
|
| 166 |
+
"Ecuador",
|
| 167 |
+
"Paraguay",
|
| 168 |
+
"Peru",
|
| 169 |
+
"Uruguay",
|
| 170 |
+
],
|
| 171 |
}
|
| 172 |
+
countries = list(set(itertools.chain.from_iterable(countries_per_continent.values())))
|
|
|
|
| 173 |
|
| 174 |
country_to_center_coords = {
|
| 175 |
"Indonesia": (-2.4833826, 117.8902853),
|
|
|
|
| 287 |
"Djibouti": (11.8145966, 42.8453061),
|
| 288 |
"Senegal": (14.4750607, -14.4529612),
|
| 289 |
"Bermuda": (32.3040273, -64.7563086),
|
| 290 |
+
"United States": (39.7837304, -100.445882),
|
| 291 |
}
|
| 292 |
|
| 293 |
INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
|
|
|
|
| 297 |
"country": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[1],
|
| 298 |
"lat": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[2],
|
| 299 |
"lon": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[3],
|
| 300 |
+
"score": {"HUMAN": 0, "AI": 0},
|
| 301 |
+
"idx": 0,
|
|
|
|
|
|
|
|
|
|
| 302 |
}
|
| 303 |
|
| 304 |
|
| 305 |
def predict(input_img):
|
| 306 |
+
inputs = processor(
|
| 307 |
+
text=[f"A photo from {geo}." for geo in continents],
|
| 308 |
+
images=input_img,
|
| 309 |
+
return_tensors="pt",
|
| 310 |
+
padding=True,
|
| 311 |
+
)
|
| 312 |
inputs = inputs.to(device)
|
| 313 |
with torch.no_grad():
|
| 314 |
outputs = continent_model(**inputs)
|
| 315 |
logits_per_image = outputs.logits_per_image
|
| 316 |
probs = logits_per_image.softmax(dim=-1)
|
| 317 |
pred_id = probs.argmax().cpu().item()
|
| 318 |
+
continent_probs = {
|
| 319 |
+
label: prob for label, prob in zip(continents, probs.tolist()[0])
|
| 320 |
+
}
|
| 321 |
model_continent = continents[pred_id]
|
| 322 |
predicted_continent_countries = countries_per_continent[model_continent]
|
| 323 |
+
inputs = processor(
|
| 324 |
+
text=[f"A photo from {geo}." for geo in predicted_continent_countries],
|
| 325 |
+
images=input_img,
|
| 326 |
+
return_tensors="pt",
|
| 327 |
+
padding=True,
|
| 328 |
+
)
|
| 329 |
inputs = inputs.to(device)
|
| 330 |
with torch.no_grad():
|
| 331 |
outputs = country_model(**inputs)
|
|
|
|
| 333 |
probs = logits_per_image.softmax(dim=-1)
|
| 334 |
pred_id = probs.argmax().cpu().item()
|
| 335 |
model_country = predicted_continent_countries[pred_id]
|
| 336 |
+
country_probs = {
|
| 337 |
+
label: prob for label, prob in zip(predicted_continent_countries, probs.tolist()[0])
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
hash = hashlib.sha1(np.asarray(input_img).data.tobytes()).hexdigest()
|
| 341 |
metadata_block = gr.Accordion(visible=False)
|
| 342 |
metadata_map = None
|
| 343 |
if hash in EXAMPLE_METADATA.keys():
|
| 344 |
model_result = ""
|
| 345 |
+
if (
|
| 346 |
+
model_continent == EXAMPLE_METADATA[hash]["continent"]
|
| 347 |
+
and model_country == EXAMPLE_METADATA[hash]["country"]
|
| 348 |
+
):
|
| 349 |
model_result = "The AI π€ correctly guessed continent and country β
β
."
|
| 350 |
+
elif model_continent == EXAMPLE_METADATA[hash]["continent"]:
|
| 351 |
model_result = "The AI π€ only guessed the correct continent β β
."
|
| 352 |
+
elif (
|
| 353 |
+
model_country == EXAMPLE_METADATA[hash]["country"]
|
| 354 |
+
and model_continent != EXAMPLE_METADATA[hash]["continent"]
|
| 355 |
+
):
|
| 356 |
model_result = "The AI π€ only guessed the correct country β
β."
|
| 357 |
else:
|
| 358 |
model_result = "The AI π€ failed to guess country and continent β β."
|
| 359 |
+
metadata_block = gr.Accordion(
|
| 360 |
+
visible=True,
|
| 361 |
+
label=f"This photo was taken in {EXAMPLE_METADATA[hash]['country']}, {EXAMPLE_METADATA[hash]['continent']}.\n{model_result}",
|
| 362 |
+
)
|
| 363 |
metadata_map = make_versus_map(None, model_country, EXAMPLE_METADATA[hash])
|
| 364 |
return continent_probs, country_probs, metadata_block, metadata_map
|
| 365 |
|
| 366 |
+
|
| 367 |
def make_versus_map(human_country, model_country, versus_state):
|
| 368 |
if human_country:
|
| 369 |
human_coordinates = country_to_center_coords[human_country]
|
|
|
|
| 371 |
human_coordinates = (None, None)
|
| 372 |
model_coordinates = country_to_center_coords[model_country]
|
| 373 |
fig = go.Figure()
|
| 374 |
+
fig.add_trace(
|
| 375 |
+
go.Scattermapbox(
|
| 376 |
+
lon=[versus_state["lon"]],
|
| 377 |
+
lat=[versus_state["lat"]],
|
| 378 |
+
text=[f"π· Photo taken in {versus_state['country']}, {versus_state['continent']}"],
|
| 379 |
+
mode="markers",
|
| 380 |
+
hoverinfo="text",
|
| 381 |
+
marker=dict(size=14, color="#0C5DA5"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
showlegend=True,
|
| 383 |
+
name="π· Photo Location",
|
| 384 |
+
)
|
| 385 |
+
)
|
| 386 |
+
if human_country == model_country:
|
| 387 |
+
fig.add_trace(
|
| 388 |
+
go.Scattermapbox(
|
| 389 |
+
lat=[human_coordinates[0], model_coordinates[0]],
|
| 390 |
+
lon=[human_coordinates[1], model_coordinates[1]],
|
| 391 |
+
text=f"π§ π€ Human & AI guess {human_country}",
|
| 392 |
+
mode="markers",
|
| 393 |
+
hoverinfo="text",
|
| 394 |
+
marker=dict(size=14, color="#FF9500"),
|
| 395 |
+
showlegend=True,
|
| 396 |
+
name="π§ π€ Human & AI Guess",
|
| 397 |
+
)
|
| 398 |
+
)
|
| 399 |
else:
|
| 400 |
if human_country:
|
| 401 |
+
fig.add_trace(
|
| 402 |
+
go.Scattermapbox(
|
| 403 |
+
lat=[human_coordinates[0]],
|
| 404 |
+
lon=[human_coordinates[1]],
|
| 405 |
+
text=[f"π§ Human guesses {human_country}"],
|
| 406 |
+
mode="markers",
|
| 407 |
+
hoverinfo="text",
|
| 408 |
+
marker=dict(size=14, color="#FF9500"),
|
| 409 |
+
showlegend=True,
|
| 410 |
+
name="π§ Human Guess",
|
| 411 |
+
)
|
| 412 |
+
)
|
| 413 |
+
fig.add_trace(
|
| 414 |
+
go.Scattermapbox(
|
| 415 |
+
lat=[model_coordinates[0]],
|
| 416 |
+
lon=[model_coordinates[1]],
|
| 417 |
+
text=[f"π€ AI guesses {model_country}"],
|
| 418 |
+
mode="markers",
|
| 419 |
+
hoverinfo="text",
|
| 420 |
+
marker=dict(size=14, color="#474747"),
|
| 421 |
showlegend=True,
|
| 422 |
+
name="π€ AI Guess",
|
| 423 |
+
)
|
| 424 |
+
)
|
| 425 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
fig.update_layout(
|
| 427 |
mapbox=dict(
|
| 428 |
style="carto-positron",
|
| 429 |
center=dict(lat=float(versus_state["lat"]), lon=float(versus_state["lon"])),
|
| 430 |
+
zoom=2,
|
| 431 |
),
|
| 432 |
margin={"r": 0, "t": 0, "l": 0, "b": 0},
|
| 433 |
+
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.01),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
)
|
| 435 |
return fig
|
| 436 |
|
|
|
|
| 448 |
human_points += 1
|
| 449 |
else:
|
| 450 |
continent_result = "β"
|
| 451 |
+
human_result = f"The photo is from **{versus_state['country']}** {country_result} in **{versus_state['continent']}** {continent_result}"
|
| 452 |
+
human_score_update = (
|
| 453 |
+
f"+{human_points} points" if human_points > 0 else "0 Points..."
|
| 454 |
+
)
|
| 455 |
+
versus_state["score"]["HUMAN"] += human_points
|
| 456 |
|
| 457 |
+
continent_probs, country_probs, _, _ = predict(input_img)
|
| 458 |
model_country = max(country_probs, key=country_probs.get)
|
| 459 |
model_continent = max(continent_probs, key=continent_probs.get)
|
| 460 |
if model_country == versus_state["country"]:
|
|
|
|
| 467 |
model_points += 1
|
| 468 |
else:
|
| 469 |
model_continent_result = "β"
|
| 470 |
+
model_score_update = (
|
| 471 |
+
f"+{model_points} points"
|
| 472 |
+
if model_points > 0
|
| 473 |
+
else "0 Points... The model was completely wrong, it seems the world is not doomed yet."
|
| 474 |
+
)
|
| 475 |
+
versus_state["score"]["AI"] += model_points
|
| 476 |
|
| 477 |
map = make_versus_map(human_country, model_country, versus_state)
|
| 478 |
+
return (
|
| 479 |
+
f"""
|
| 480 |
## {human_result}
|
| 481 |
### The AI π€ thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
|
| 482 |
|
|
|
|
| 484 |
π€ {model_score_update}
|
| 485 |
|
| 486 |
### Score π§ {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} π€
|
| 487 |
+
""",
|
| 488 |
+
continent_probs,
|
| 489 |
+
country_probs,
|
| 490 |
+
map,
|
| 491 |
+
versus_state,
|
| 492 |
+
)
|
| 493 |
|
| 494 |
|
| 495 |
def get_example_images(dir):
|
|
|
|
| 529 |
|
| 530 |
demo = gr.Blocks(title="Thesis Demo")
|
| 531 |
with demo:
|
| 532 |
+
gr.HTML(
|
| 533 |
+
"""
|
| 534 |
<h1 style="text-align: center; margin-bottom: 1rem">Image Geolocation Thesis Demo</h1>
|
| 535 |
|
| 536 |
<h3> This Demo showcases the developed models and allows interacting with the optimized prototype.</h3>
|
| 537 |
<p>Try the <b>"Image Geolocation Demo"</b> tab with your own images or with one of the examples. For all example image the ground truth is available and will be displayed together with the model predictions.</p>
|
| 538 |
<p>In the <b>"Versus Mode"</b> tab to play against the AI, guessing the country and continent where images where taken. Images in the versus mode are from the <a href="http://graphics.cs.cmu.edu/projects/im2gps/"><code>Im2GPS</code></a> and <a href="https://arxiv.org/abs/1705.04838"><code>Im2GPS3k</code></a> geolocation literature benchmarks. Can you beat the AI?
|
| 539 |
|
| 540 |
+
"""
|
| 541 |
+
)
|
| 542 |
+
with gr.Accordion(
|
| 543 |
+
label="The demo currently encompasses 116 countries from 6 continents π",
|
| 544 |
+
open=False,
|
| 545 |
+
):
|
| 546 |
+
gr.Code(
|
| 547 |
+
json.dumps(countries_per_continent, indent=2, ensure_ascii=False),
|
| 548 |
+
label="countries_per_continent.json",
|
| 549 |
+
language="json",
|
| 550 |
+
interactive=False,
|
| 551 |
+
)
|
| 552 |
with gr.Tab("Image Geolocation Demo"):
|
| 553 |
with gr.Row():
|
| 554 |
with gr.Column():
|
| 555 |
+
image = gr.Image(
|
| 556 |
+
label="Image", type="pil", sources=["upload", "clipboard"]
|
| 557 |
+
)
|
| 558 |
predict_btn = gr.Button("Predict")
|
| 559 |
example_images = get_example_images("kerger-test-images")
|
| 560 |
# example_images.extend(get_example_images("versus_images"))
|
| 561 |
+
gr.Examples(examples=example_images, inputs=image, examples_per_page=24)
|
|
|
|
| 562 |
with gr.Column():
|
| 563 |
with gr.Accordion(visible=False) as metadata_block:
|
| 564 |
map = gr.Plot(label="Locations")
|
| 565 |
with gr.Group():
|
| 566 |
continents_label = gr.Label(label="Continents")
|
| 567 |
+
country_label = gr.Label(num_top_classes=5, label="Top countries")
|
| 568 |
+
predict_btn.click(
|
| 569 |
+
predict,
|
| 570 |
+
inputs=image,
|
| 571 |
+
outputs=[continents_label, country_label, metadata_block, map],
|
| 572 |
+
)
|
| 573 |
|
| 574 |
with gr.Tab("Versus Mode"):
|
| 575 |
versus_state = gr.State(value=INITAL_VERSUS_STATE)
|
| 576 |
with gr.Row():
|
| 577 |
with gr.Column():
|
| 578 |
+
versus_image = gr.Image(INITAL_VERSUS_STATE["image"], interactive=False)
|
|
|
|
| 579 |
continent_selection = gr.Radio(
|
| 580 |
+
continents,
|
| 581 |
+
label="Continents",
|
| 582 |
+
info="Where was this image taken? (1 Point)",
|
| 583 |
+
)
|
| 584 |
+
country_selection = (
|
| 585 |
+
gr.Dropdown(
|
| 586 |
+
countries,
|
| 587 |
+
label="Countries",
|
| 588 |
+
info="Can you guess the exact country? (2 Points)",
|
| 589 |
+
),
|
| 590 |
+
)
|
| 591 |
with gr.Row():
|
| 592 |
next_img_btn = gr.Button("Try new image")
|
| 593 |
versus_btn = gr.Button("Submit guess")
|
|
|
|
| 599 |
with gr.Group():
|
| 600 |
continents_label = gr.Label(label="Continents")
|
| 601 |
country_label = gr.Label(
|
| 602 |
+
num_top_classes=5, label="Top countries"
|
| 603 |
+
)
|
| 604 |
+
next_img_btn.click(
|
| 605 |
+
next_versus_image,
|
| 606 |
+
inputs=[versus_state],
|
| 607 |
+
outputs=[
|
| 608 |
+
versus_image,
|
| 609 |
+
versus_state,
|
| 610 |
+
continent_selection,
|
| 611 |
+
country_selection[0],
|
| 612 |
+
],
|
| 613 |
+
)
|
| 614 |
+
versus_btn.click(
|
| 615 |
+
versus_mode_inputs,
|
| 616 |
+
inputs=[
|
| 617 |
+
versus_image,
|
| 618 |
+
continent_selection,
|
| 619 |
+
country_selection[0],
|
| 620 |
+
versus_state,
|
| 621 |
+
],
|
| 622 |
+
outputs=[versus_output, continents_label, country_label, map, versus_state],
|
| 623 |
+
)
|
| 624 |
|
| 625 |
|
| 626 |
if __name__ == "__main__":
|