# TheAmateur / app.py (CineAI, commit c0fce6a: "Update app.py")
import gradio as gr
import os
import pandas as pd
from PIL import Image
import folium
from src.rainbolt_parody import RainboltParody
from src.map_gen import MapGenerator
from src.geo_info import GeoInfo
# --- Constants and System Prompt ---
# Prompt text handed to RainboltParody.get_info() by pinpoint_location().
# It asks the model to reason internally in three stages and to reply only with
# a JSON object containing the keys: continent, country, region, city,
# coordinates, confidence_score and reasoning. pinpoint_location() reads the
# "coordinates" and "reasoning" keys from the parsed response, so renaming
# those keys here requires a matching change there.
SYSTEM_PROMPT = """
# ROLE
You are a world-class GeoGuessr expert. Your task is to pinpoint the location of the provided image with extreme precision.
# METHODOLOGY
You will follow a strict, three-stage analytical process.
## STAGE 1: OBSERVATION & CLUE EXTRACTION
Internally, create a structured list of all visual evidence. Do not output this list.
- **Road & Infrastructure:** Road Lines, Bollards, Poles, Signage (Language, Script, Style)
- **Environment & Nature:** Sun Position & Climate, Vegetation & Trees, Soil & Topography
- **Human & Cultural Markers:** Architecture Style, Vehicle Models & License Plates
- **Meta Clues:** Google Car Generation/Antenna, Image Quality/Season
## STAGE 2: DEDUCTION & SYNTHESIS
Internally, reason through the clues from Stage 1.
1. **Broad Localization:** Use major clues (driving side, language, sun) to determine the country or large region.
2. **Narrowing Down:** Use secondary clues (bollards, architecture, area codes) to narrow it down to a specific state, province, or city.
3. **Verification:** Cross-reference your findings. Do all clues point to the same place?
4. **Confidence Assessment:** Based on the consistency of the evidence, determine a confidence score.
## STAGE 3: FINAL OUTPUT
After completing your internal analysis, provide your answer *only* in the following JSON format. The reasoning section must be a concise summary of your Stage 2 synthesis.
{
"continent": "Your final guess",
"country": "Your final guess",
"region": "Your final guess (e.g., state, province, county)",
"city": "Your final guess (or nearest town if rural)",
"coordinates": "Your best estimate for latitude and longitude (e.g., 40.7128, -74.0060)",
"confidence_score": "A score from 0 to 100",
"reasoning": "Summarize your deduction process. Mention the key 3-5 clues and how they led to your conclusion."
}
"""
# --- Backend Functions ---
def handle_error(message, e=None):
    """Print an error to the console and raise it as a ``gr.Error``.

    Args:
        message: Description of the failure shown to the user.
        e: Optional underlying exception (or any truthy detail object); when
            supplied, its text is appended to the log line and to the message.

    Raises:
        gr.Error: Always; this helper never returns normally.
    """
    # Guard clause: no extra detail means the plain message is logged and raised.
    if not e:
        print(f"Error: {message}")
        raise gr.Error(message)

    print(f"Error: {message} - {e}")
    raise gr.Error(f"{message}. Details: {str(e)}")
def pinpoint_location(api_key, uploaded_image):
    """Ask the AI model where an image was taken and build the UI outputs.

    Args:
        api_key: Google AI API key, forwarded to ``RainboltParody``.
        uploaded_image: Image supplied by the user (the UI's image input is
            configured with ``type="pil"``).

    Returns:
        A 6-tuple matching the outputs wired to the "Pinpoint Location" button:
        an Attribute/Value DataFrame, the reasoning text, the map HTML, the raw
        model response (kept as state), and two ``gr.Accordion`` updates that
        make the error and save sections visible.

    Raises:
        gr.Error: If an input is missing, or if the model call or map
            rendering fails.
    """
    # Pasted keys often carry stray whitespace; a blank key counts as missing.
    if isinstance(api_key, str):
        api_key = api_key.strip()
    if not api_key:
        handle_error("Google AI API Key is required.")
    if uploaded_image is None:
        handle_error("Please upload an image.")

    try:
        rp = RainboltParody(title="Location Analysis", api_key=api_key)
        response, _, _ = rp.get_info(SYSTEM_PROMPT=SYSTEM_PROMPT, img=uploaded_image)

        info_data = [[key, value] for key, value in response.items()]
        info_df = pd.DataFrame(info_data, columns=["Attribute", "Value"])

        # ``or`` (rather than a .get default) also covers keys that are present
        # but empty or None, which would otherwise reach the map generator.
        reasoning = response.get('reasoning') or 'No reasoning provided.'
        coords_str = response.get('coordinates') or '0,0'

        mgk = MapGenerator()
        single_point_map = mgk.get_map(location1=coords_str)
        map_html_output = single_point_map._repr_html_()

        # Exactly six values: one per component in the click handler's outputs.
        return (
            info_df,                      # info_output_df
            reasoning,                    # reasoning_output
            map_html_output,              # map_output_single
            response,                     # prediction_state
            gr.Accordion(visible=True),   # error_accordion
            gr.Accordion(visible=True),   # save_accordion
        )
    except Exception as e:
        handle_error("Failed to process the image with the AI model", e)
def calculate_error_distance(prediction_state, url, manual_coords_str):
    """Compare the predicted location with the actual one and draw both.

    Args:
        prediction_state: Model response stored by ``pinpoint_location``.
        url: Google Maps URL of the actual location; takes precedence over
            ``manual_coords_str`` when both are given.
        manual_coords_str: Actual coordinates typed by the user ("lat, lon").

    Returns:
        A 5-tuple matching the outputs wired to the "Calculate Distance"
        button: the summary text, a visible ``gr.HTML`` holding the comparison
        map, the error distance (``error_km``, presumably kilometres), the
        direction, and the URL that was used (``""`` when manual coordinates
        were used).

    Raises:
        gr.Error: If no prediction exists, no actual location was supplied, or
            the calculation / map rendering fails.
    """
    if not prediction_state:
        handle_error("You must pinpoint a location first.")

    # Treat whitespace-only text boxes as empty so that a stray space in the
    # URL field cannot shadow valid manual coordinates.
    if isinstance(url, str):
        url = url.strip()
    if isinstance(manual_coords_str, str):
        manual_coords_str = manual_coords_str.strip()
    if not url and not manual_coords_str:
        handle_error("Please provide either a Google Maps URL or manual coordinates for the actual location.")

    try:
        if url:
            gi = GeoInfo(response=prediction_state, url=url)
        else:
            gi = GeoInfo(response=prediction_state, coordinates=manual_coords_str)

        error_km, lat_actual, lon_actual, _, _ = gi.calculate_error()
        bearing = gi.calculate_bearing()
        direction = gi.get_direction(angle=bearing)
        info_text = gi.combine_info(error=error_km, direction=direction)

        actual_coords = f"{lat_actual},{lon_actual}"
        mgk = MapGenerator()
        map_with_line = mgk.get_map(
            # Same '0,0' fallback as the single-point map, so both maps place
            # the prediction at the same spot.
            location1=prediction_state.get('coordinates') or '0,0',
            location2=actual_coords,
            error=error_km,
            direction=direction,
        )
        map_html = map_with_line._repr_html_()

        return info_text, gr.HTML(value=map_html, visible=True), error_km, direction, url or ""
    except Exception as e:
        handle_error("Failed to calculate the distance", e)
def save_and_download(prediction_state, error_km, direction, original_url, filename):
    """Append the current result to a CSV file and expose it for download.

    Args:
        prediction_state: Model response stored by ``pinpoint_location``.
        error_km: Error distance from the comparison step, or ``None``.
        direction: Direction from the comparison step, or ``None``/empty.
        original_url: URL of the actual location, or ``None``/empty.
        filename: Name of the CSV file; ``.csv`` is appended when missing.
            Only the final path component is used.

    Returns:
        A 2-tuple: a visible ``gr.DataFrame`` previewing the whole CSV and a
        visible ``gr.File`` pointing at the written file.

    Raises:
        gr.Error: If there is no prediction, no usable filename, or reading /
            writing the CSV fails.
    """
    if not prediction_state:
        handle_error("No prediction data to save.")

    # NOTE(security): the filename comes straight from a text box. Keeping only
    # the final path component stops it from reading or overwriting files
    # outside the working directory (e.g. "../../other.csv").
    filename = os.path.basename(filename.strip()) if isinstance(filename, str) else ""
    if not filename:
        handle_error("Please provide a filename for the CSV.")
    # Case-insensitive so "RESULTS.CSV" is not turned into "RESULTS.CSV.csv".
    if not filename.lower().endswith('.csv'):
        filename += '.csv'

    try:
        data_to_add = prediction_state.copy()
        data_to_add['error'] = f"{error_km:.2f}" if error_km is not None else 'N/A'
        data_to_add['direction'] = direction if direction else 'N/A'
        data_to_add['actual_location_url'] = original_url if original_url else 'N/A'
        new_df = pd.DataFrame([data_to_add])

        existing_df = None
        if os.path.exists(filename):
            try:
                existing_df = pd.read_csv(filename)
            except pd.errors.EmptyDataError:
                # A zero-byte file has nothing to append to; start fresh.
                existing_df = None

        if existing_df is not None:
            final_df = pd.concat([existing_df, new_df], ignore_index=True)
        else:
            final_df = new_df

        final_df.to_csv(filename, index=False)
        return gr.DataFrame(value=final_df, visible=True), gr.File(value=filename, label="Download CSV", visible=True)
    except Exception as e:
        handle_error("Failed to save data to CSV", e)
# --- Gradio UI ---
# Builds the page (inputs and textual results on the left, maps and the
# optional follow-up steps on the right) and wires the three buttons to the
# backend functions above.
# NOTE(review): the nesting was reconstructed from a copy of this file that had
# lost its indentation; the two accordions are assumed to live in the
# right-hand column - confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), title="GeoGuessr AI") as app:
    gr.Markdown("# GeoGuessr AI 🔎")
    gr.Markdown("Pinpoint the location of any Google Street View image.")

    # Per-session values passed between the three steps.
    prediction_state = gr.State(None)  # model response from step 1
    error_state = gr.State(None)       # error distance from step 2
    direction_state = gr.State(None)   # direction from step 2
    url_state = gr.State(None)         # actual-location URL from step 2

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 1. Pinpoint Location")
            api_key_input = gr.Textbox(label="Google AI API Key", placeholder="Enter your API key here...", type="password")
            image_input = gr.Image(type="pil", label="Upload Street View Image")
            pinpoint_btn = gr.Button("Pinpoint Location", variant="primary")
            gr.Markdown("### Prediction Results")
            info_output_df = gr.DataFrame(label="Location Information", headers=["Attribute", "Value"])
            reasoning_output = gr.Textbox(label="AI Reasoning", lines=5, interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("### Map View")
            map_output_single = gr.HTML(label="Predicted Location")
            # Hidden until a comparison has been calculated.
            map_output_comparison = gr.HTML(visible=False)

            # Both accordions start hidden; pinpoint_location reveals them.
            with gr.Accordion("2. Calculate Error (Optional)", open=False, visible=False) as error_accordion:
                with gr.Row():
                    url_input = gr.Textbox(label="Google Maps URL of Actual Location", placeholder="Paste URL here...")
                    manual_coords_input = gr.Textbox(label="Or, Enter Manual Coordinates (lat, lon)", placeholder="e.g., 51.5074, -0.1278")
                calculate_btn = gr.Button("Calculate Distance", variant="secondary")
                distance_output = gr.Textbox(label="Result", interactive=False)

            with gr.Accordion("3. Save Results (Optional)", open=False, visible=False) as save_accordion:
                filename_input = gr.Textbox(label="CSV Filename", value="geo_results.csv")
                save_btn = gr.Button("Save to CSV & Prepare Download", variant="secondary")
                with gr.Row():
                    csv_preview_df = gr.DataFrame(label="CSV Content Preview", visible=False)
                    download_file = gr.File(label="Download CSV", visible=False)

    # --- Event Handlers ---
    # The order of each outputs list must match the tuple returned by its handler.
    pinpoint_btn.click(
        fn=pinpoint_location,
        inputs=[api_key_input, image_input],
        outputs=[info_output_df, reasoning_output, map_output_single, prediction_state, error_accordion, save_accordion]
    )
    calculate_btn.click(
        fn=calculate_error_distance,
        inputs=[prediction_state, url_input, manual_coords_input],
        outputs=[distance_output, map_output_comparison, error_state, direction_state, url_state]
    )
    save_btn.click(
        fn=save_and_download,
        inputs=[prediction_state, error_state, direction_state, url_state, filename_input],
        outputs=[csv_preview_df, download_file]
    )

    # Disabled example gallery. REPO_DIR is not defined in the visible code,
    # so it must be defined before this is re-enabled.
    # gr.Examples(
    #     examples=[
    #         [os.path.join(REPO_DIR, "examples/example_1.png")],
    #         [os.path.join(REPO_DIR, "examples/example_2.png")],
    #         [os.path.join(REPO_DIR, "examples/example_3.png")],
    #     ],
    #     inputs=image_input,
    #     label="Example Images"
    # )

if __name__ == "__main__":
    app.launch(debug=True)