Spaces:
Sleeping
Sleeping
Andy Lee
commited on
Commit
·
f2b6ded
1
Parent(s):
6fda968
feat: app.py
Browse files
app.py
CHANGED
|
@@ -6,94 +6,95 @@ from io import BytesIO
|
|
| 6 |
from PIL import Image
|
| 7 |
from typing import Dict, List, Any
|
| 8 |
|
| 9 |
-
#
|
| 10 |
from geo_bot import (
|
| 11 |
GeoBot,
|
| 12 |
AGENT_PROMPT_TEMPLATE,
|
| 13 |
BENCHMARK_PROMPT,
|
| 14 |
-
)
|
| 15 |
from benchmark import MapGuesserBenchmark
|
| 16 |
from config import MODELS_CONFIG, DATA_PATHS, SUCCESS_THRESHOLD_KM
|
| 17 |
from langchain_openai import ChatOpenAI
|
| 18 |
from langchain_anthropic import ChatAnthropic
|
| 19 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 20 |
|
| 21 |
-
# ---
|
| 22 |
st.set_page_config(page_title="MapCrunch AI Agent", layout="wide")
|
| 23 |
st.title("🗺️ MapCrunch AI Agent")
|
| 24 |
-
st.caption(
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
# --- Sidebar
|
| 27 |
with st.sidebar:
|
| 28 |
-
st.header("⚙️
|
| 29 |
|
| 30 |
-
#
|
| 31 |
os.environ["OPENAI_API_KEY"] = st.secrets.get("OPENAI_API_KEY", "")
|
| 32 |
os.environ["ANTHROPIC_API_KEY"] = st.secrets.get("ANTHROPIC_API_KEY", "")
|
| 33 |
-
# 添加其他你可能需要的API密钥
|
| 34 |
# os.environ['GOOGLE_API_KEY'] = st.secrets.get("GOOGLE_API_KEY", "")
|
| 35 |
|
| 36 |
-
model_choice = st.selectbox("
|
| 37 |
steps_per_sample = st.slider(
|
| 38 |
-
"
|
| 39 |
)
|
| 40 |
|
| 41 |
-
#
|
| 42 |
try:
|
| 43 |
with open(DATA_PATHS["golden_labels"], "r", encoding="utf-8") as f:
|
| 44 |
golden_labels = json.load(f).get("samples", [])
|
| 45 |
total_samples = len(golden_labels)
|
| 46 |
num_samples_to_run = st.slider(
|
| 47 |
-
"
|
| 48 |
)
|
| 49 |
except FileNotFoundError:
|
| 50 |
-
st.error(
|
|
|
|
|
|
|
| 51 |
golden_labels = []
|
| 52 |
num_samples_to_run = 0
|
| 53 |
|
| 54 |
start_button = st.button(
|
| 55 |
-
"🚀
|
| 56 |
)
|
| 57 |
|
| 58 |
-
# --- Agent
|
| 59 |
if start_button:
|
| 60 |
-
#
|
| 61 |
test_samples = golden_labels[:num_samples_to_run]
|
| 62 |
|
| 63 |
config = MODELS_CONFIG.get(model_choice)
|
| 64 |
model_class = globals()[config["class"]]
|
| 65 |
model_instance_name = config["model_name"]
|
| 66 |
|
| 67 |
-
#
|
| 68 |
benchmark_helper = MapGuesserBenchmark()
|
| 69 |
all_results = []
|
| 70 |
|
| 71 |
st.info(
|
| 72 |
-
f"
|
| 73 |
)
|
| 74 |
|
| 75 |
-
|
| 76 |
-
overall_progress_bar = st.progress(0, text="总进度")
|
| 77 |
|
| 78 |
-
#
|
| 79 |
-
|
| 80 |
-
|
| 81 |
bot = GeoBot(model=model_class, model_name=model_instance_name, headless=True)
|
| 82 |
|
| 83 |
-
#
|
| 84 |
for i, sample in enumerate(test_samples):
|
| 85 |
sample_id = sample.get("id", "N/A")
|
| 86 |
st.divider()
|
| 87 |
-
st.header(f"▶️
|
| 88 |
|
| 89 |
-
# 加载地图位置
|
| 90 |
if not bot.controller.load_location_from_data(sample):
|
| 91 |
-
st.error(f"
|
| 92 |
continue
|
| 93 |
|
| 94 |
bot.controller.setup_clean_environment()
|
| 95 |
|
| 96 |
-
#
|
| 97 |
col1, col2 = st.columns([2, 3])
|
| 98 |
with col1:
|
| 99 |
image_placeholder = st.empty()
|
|
@@ -101,25 +102,25 @@ if start_button:
|
|
| 101 |
reasoning_placeholder = st.empty()
|
| 102 |
action_placeholder = st.empty()
|
| 103 |
|
| 104 |
-
# ---
|
| 105 |
history = []
|
| 106 |
final_guess = None
|
| 107 |
|
| 108 |
for step in range(steps_per_sample):
|
| 109 |
step_num = step + 1
|
| 110 |
reasoning_placeholder.info(
|
| 111 |
-
f"
|
| 112 |
)
|
| 113 |
action_placeholder.empty()
|
| 114 |
|
| 115 |
-
#
|
| 116 |
bot.controller.label_arrows_on_screen()
|
| 117 |
screenshot_bytes = bot.controller.take_street_view_screenshot()
|
| 118 |
image_placeholder.image(
|
| 119 |
screenshot_bytes, caption=f"Step {step_num} View", use_column_width=True
|
| 120 |
)
|
| 121 |
|
| 122 |
-
#
|
| 123 |
history.append(
|
| 124 |
{
|
| 125 |
"image_b64": bot.pil_to_base64(
|
|
@@ -129,7 +130,7 @@ if start_button:
|
|
| 129 |
}
|
| 130 |
)
|
| 131 |
|
| 132 |
-
#
|
| 133 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
| 134 |
remaining_steps=steps_per_sample - step,
|
| 135 |
history_text="\n".join(
|
|
@@ -157,12 +158,12 @@ if start_button:
|
|
| 157 |
)
|
| 158 |
action_placeholder.success(f"**AI Action:** `{action}`")
|
| 159 |
|
| 160 |
-
#
|
| 161 |
if step_num == steps_per_sample and action != "GUESS":
|
| 162 |
-
st.warning("
|
| 163 |
action = "GUESS"
|
| 164 |
|
| 165 |
-
#
|
| 166 |
if action == "GUESS":
|
| 167 |
lat, lon = (
|
| 168 |
decision.get("action_details", {}).get("lat"),
|
|
@@ -171,10 +172,10 @@ if start_button:
|
|
| 171 |
if lat is not None and lon is not None:
|
| 172 |
final_guess = (lat, lon)
|
| 173 |
else:
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
break #
|
| 178 |
|
| 179 |
elif action == "MOVE_FORWARD":
|
| 180 |
bot.controller.move("forward")
|
|
@@ -185,9 +186,9 @@ if start_button:
|
|
| 185 |
elif action == "PAN_RIGHT":
|
| 186 |
bot.controller.pan_view("right")
|
| 187 |
|
| 188 |
-
time.sleep(1) #
|
| 189 |
|
| 190 |
-
# ---
|
| 191 |
true_coords = {"lat": sample.get("lat"), "lng": sample.get("lng")}
|
| 192 |
distance_km = None
|
| 193 |
is_success = False
|
|
@@ -197,23 +198,23 @@ if start_button:
|
|
| 197 |
if distance_km is not None:
|
| 198 |
is_success = distance_km <= SUCCESS_THRESHOLD_KM
|
| 199 |
|
| 200 |
-
st.subheader("🎯
|
| 201 |
res_col1, res_col2, res_col3 = st.columns(3)
|
| 202 |
res_col1.metric(
|
| 203 |
-
"
|
| 204 |
)
|
| 205 |
res_col2.metric(
|
| 206 |
-
"
|
| 207 |
f"{true_coords['lat']:.3f}, {true_coords['lng']:.3f}",
|
| 208 |
)
|
| 209 |
res_col3.metric(
|
| 210 |
-
"
|
| 211 |
f"{distance_km:.1f} km" if distance_km is not None else "N/A",
|
| 212 |
-
delta=f"{'
|
| 213 |
delta_color=("inverse" if is_success else "off"),
|
| 214 |
)
|
| 215 |
else:
|
| 216 |
-
st.error("Agent
|
| 217 |
|
| 218 |
all_results.append(
|
| 219 |
{
|
|
@@ -226,22 +227,27 @@ if start_button:
|
|
| 226 |
}
|
| 227 |
)
|
| 228 |
|
| 229 |
-
#
|
| 230 |
overall_progress_bar.progress(
|
| 231 |
-
(i + 1) / num_samples_to_run,
|
|
|
|
| 232 |
)
|
| 233 |
|
| 234 |
-
# ---
|
| 235 |
-
bot.close() #
|
| 236 |
st.divider()
|
| 237 |
-
st.header("🏁 Benchmark
|
| 238 |
|
| 239 |
summary = benchmark_helper.generate_summary(all_results)
|
| 240 |
if summary and model_choice in summary:
|
| 241 |
stats = summary[model_choice]
|
| 242 |
sum_col1, sum_col2 = st.columns(2)
|
| 243 |
-
sum_col1.metric(
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
else:
|
| 247 |
-
st.warning("
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
from typing import Dict, List, Any
|
| 8 |
|
| 9 |
+
# Import core logic and configurations from the project
|
| 10 |
from geo_bot import (
|
| 11 |
GeoBot,
|
| 12 |
AGENT_PROMPT_TEMPLATE,
|
| 13 |
BENCHMARK_PROMPT,
|
| 14 |
+
)
|
| 15 |
from benchmark import MapGuesserBenchmark
|
| 16 |
from config import MODELS_CONFIG, DATA_PATHS, SUCCESS_THRESHOLD_KM
|
| 17 |
from langchain_openai import ChatOpenAI
|
| 18 |
from langchain_anthropic import ChatAnthropic
|
| 19 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 20 |
|
| 21 |
+
# --- Page UI Setup ---
|
| 22 |
st.set_page_config(page_title="MapCrunch AI Agent", layout="wide")
|
| 23 |
st.title("🗺️ MapCrunch AI Agent")
|
| 24 |
+
st.caption(
|
| 25 |
+
"An AI agent that explores and identifies geographic locations through multi-step interaction."
|
| 26 |
+
)
|
| 27 |
|
| 28 |
+
# --- Sidebar for Configuration ---
|
| 29 |
with st.sidebar:
|
| 30 |
+
st.header("⚙️ Agent Configuration")
|
| 31 |
|
| 32 |
+
# Get API keys from HF Secrets (must be set in Space settings when deploying)
|
| 33 |
os.environ["OPENAI_API_KEY"] = st.secrets.get("OPENAI_API_KEY", "")
|
| 34 |
os.environ["ANTHROPIC_API_KEY"] = st.secrets.get("ANTHROPIC_API_KEY", "")
|
|
|
|
| 35 |
# os.environ['GOOGLE_API_KEY'] = st.secrets.get("GOOGLE_API_KEY", "")
|
| 36 |
|
| 37 |
+
model_choice = st.selectbox("Select AI Model", list(MODELS_CONFIG.keys()))
|
| 38 |
steps_per_sample = st.slider(
|
| 39 |
+
"Max Exploration Steps per Sample", min_value=3, max_value=20, value=10
|
| 40 |
)
|
| 41 |
|
| 42 |
+
# Load golden labels for selection
|
| 43 |
try:
|
| 44 |
with open(DATA_PATHS["golden_labels"], "r", encoding="utf-8") as f:
|
| 45 |
golden_labels = json.load(f).get("samples", [])
|
| 46 |
total_samples = len(golden_labels)
|
| 47 |
num_samples_to_run = st.slider(
|
| 48 |
+
"Number of Samples to Test", min_value=1, max_value=total_samples, value=3
|
| 49 |
)
|
| 50 |
except FileNotFoundError:
|
| 51 |
+
st.error(
|
| 52 |
+
f"Data file '{DATA_PATHS['golden_labels']}' not found. Please prepare the data."
|
| 53 |
+
)
|
| 54 |
golden_labels = []
|
| 55 |
num_samples_to_run = 0
|
| 56 |
|
| 57 |
start_button = st.button(
|
| 58 |
+
"🚀 Start Agent Benchmark", disabled=(num_samples_to_run == 0), type="primary"
|
| 59 |
)
|
| 60 |
|
| 61 |
+
# --- Agent Execution Logic ---
|
| 62 |
if start_button:
|
| 63 |
+
# Prepare the environment
|
| 64 |
test_samples = golden_labels[:num_samples_to_run]
|
| 65 |
|
| 66 |
config = MODELS_CONFIG.get(model_choice)
|
| 67 |
model_class = globals()[config["class"]]
|
| 68 |
model_instance_name = config["model_name"]
|
| 69 |
|
| 70 |
+
# Initialize helpers and result lists
|
| 71 |
benchmark_helper = MapGuesserBenchmark()
|
| 72 |
all_results = []
|
| 73 |
|
| 74 |
st.info(
|
| 75 |
+
f"Starting Agent Benchmark... Model: {model_choice}, Steps: {steps_per_sample}, Samples: {num_samples_to_run}"
|
| 76 |
)
|
| 77 |
|
| 78 |
+
overall_progress_bar = st.progress(0, text="Overall Progress")
|
|
|
|
| 79 |
|
| 80 |
+
# Initialize the bot outside the loop to reuse the browser instance for efficiency
|
| 81 |
+
with st.spinner("Initializing browser and AI model..."):
|
| 82 |
+
# Note: Must run in headless mode on HF Spaces
|
| 83 |
bot = GeoBot(model=model_class, model_name=model_instance_name, headless=True)
|
| 84 |
|
| 85 |
+
# Main loop to iterate through all selected test samples
|
| 86 |
for i, sample in enumerate(test_samples):
|
| 87 |
sample_id = sample.get("id", "N/A")
|
| 88 |
st.divider()
|
| 89 |
+
st.header(f"▶️ Running Sample {i + 1}/{num_samples_to_run} (ID: {sample_id})")
|
| 90 |
|
|
|
|
| 91 |
if not bot.controller.load_location_from_data(sample):
|
| 92 |
+
st.error(f"Failed to load location for sample {sample_id}. Skipping.")
|
| 93 |
continue
|
| 94 |
|
| 95 |
bot.controller.setup_clean_environment()
|
| 96 |
|
| 97 |
+
# Create the visualization layout for the current sample
|
| 98 |
col1, col2 = st.columns([2, 3])
|
| 99 |
with col1:
|
| 100 |
image_placeholder = st.empty()
|
|
|
|
| 102 |
reasoning_placeholder = st.empty()
|
| 103 |
action_placeholder = st.empty()
|
| 104 |
|
| 105 |
+
# --- Inner agent exploration loop ---
|
| 106 |
history = []
|
| 107 |
final_guess = None
|
| 108 |
|
| 109 |
for step in range(steps_per_sample):
|
| 110 |
step_num = step + 1
|
| 111 |
reasoning_placeholder.info(
|
| 112 |
+
f"Thinking... (Step {step_num}/{steps_per_sample})"
|
| 113 |
)
|
| 114 |
action_placeholder.empty()
|
| 115 |
|
| 116 |
+
# Observe and label arrows
|
| 117 |
bot.controller.label_arrows_on_screen()
|
| 118 |
screenshot_bytes = bot.controller.take_street_view_screenshot()
|
| 119 |
image_placeholder.image(
|
| 120 |
screenshot_bytes, caption=f"Step {step_num} View", use_column_width=True
|
| 121 |
)
|
| 122 |
|
| 123 |
+
# Update history
|
| 124 |
history.append(
|
| 125 |
{
|
| 126 |
"image_b64": bot.pil_to_base64(
|
|
|
|
| 130 |
}
|
| 131 |
)
|
| 132 |
|
| 133 |
+
# Think
|
| 134 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
| 135 |
remaining_steps=steps_per_sample - step,
|
| 136 |
history_text="\n".join(
|
|
|
|
| 158 |
)
|
| 159 |
action_placeholder.success(f"**AI Action:** `{action}`")
|
| 160 |
|
| 161 |
+
# Force a GUESS on the last step
|
| 162 |
if step_num == steps_per_sample and action != "GUESS":
|
| 163 |
+
st.warning("Max steps reached. Forcing a GUESS action.")
|
| 164 |
action = "GUESS"
|
| 165 |
|
| 166 |
+
# Act
|
| 167 |
if action == "GUESS":
|
| 168 |
lat, lon = (
|
| 169 |
decision.get("action_details", {}).get("lat"),
|
|
|
|
| 172 |
if lat is not None and lon is not None:
|
| 173 |
final_guess = (lat, lon)
|
| 174 |
else:
|
| 175 |
+
st.error(
|
| 176 |
+
"GUESS action was missing coordinates. Guess failed for this sample."
|
| 177 |
+
)
|
| 178 |
+
break # End exploration for the current sample
|
| 179 |
|
| 180 |
elif action == "MOVE_FORWARD":
|
| 181 |
bot.controller.move("forward")
|
|
|
|
| 186 |
elif action == "PAN_RIGHT":
|
| 187 |
bot.controller.pan_view("right")
|
| 188 |
|
| 189 |
+
time.sleep(1) # A brief pause between steps for better visualization
|
| 190 |
|
| 191 |
+
# --- End of single sample run, calculate and display results ---
|
| 192 |
true_coords = {"lat": sample.get("lat"), "lng": sample.get("lng")}
|
| 193 |
distance_km = None
|
| 194 |
is_success = False
|
|
|
|
| 198 |
if distance_km is not None:
|
| 199 |
is_success = distance_km <= SUCCESS_THRESHOLD_KM
|
| 200 |
|
| 201 |
+
st.subheader("🎯 Round Result")
|
| 202 |
res_col1, res_col2, res_col3 = st.columns(3)
|
| 203 |
res_col1.metric(
|
| 204 |
+
"Final Guess (Lat, Lon)", f"{final_guess[0]:.3f}, {final_guess[1]:.3f}"
|
| 205 |
)
|
| 206 |
res_col2.metric(
|
| 207 |
+
"Ground Truth (Lat, Lon)",
|
| 208 |
f"{true_coords['lat']:.3f}, {true_coords['lng']:.3f}",
|
| 209 |
)
|
| 210 |
res_col3.metric(
|
| 211 |
+
"Distance Error",
|
| 212 |
f"{distance_km:.1f} km" if distance_km is not None else "N/A",
|
| 213 |
+
delta=f"{'Success' if is_success else 'Failure'}",
|
| 214 |
delta_color=("inverse" if is_success else "off"),
|
| 215 |
)
|
| 216 |
else:
|
| 217 |
+
st.error("Agent failed to make a final guess.")
|
| 218 |
|
| 219 |
all_results.append(
|
| 220 |
{
|
|
|
|
| 227 |
}
|
| 228 |
)
|
| 229 |
|
| 230 |
+
# Update overall progress bar
|
| 231 |
overall_progress_bar.progress(
|
| 232 |
+
(i + 1) / num_samples_to_run,
|
| 233 |
+
text=f"Overall Progress: {i + 1}/{num_samples_to_run}",
|
| 234 |
)
|
| 235 |
|
| 236 |
+
# --- End of all samples, display final summary ---
|
| 237 |
+
bot.close() # Close the browser
|
| 238 |
st.divider()
|
| 239 |
+
st.header("🏁 Benchmark Summary")
|
| 240 |
|
| 241 |
summary = benchmark_helper.generate_summary(all_results)
|
| 242 |
if summary and model_choice in summary:
|
| 243 |
stats = summary[model_choice]
|
| 244 |
sum_col1, sum_col2 = st.columns(2)
|
| 245 |
+
sum_col1.metric(
|
| 246 |
+
"Overall Success Rate", f"{stats.get('success_rate', 0) * 100:.1f} %"
|
| 247 |
+
)
|
| 248 |
+
sum_col2.metric(
|
| 249 |
+
"Average Distance Error", f"{stats.get('average_distance_km', 0):.1f} km"
|
| 250 |
+
)
|
| 251 |
+
st.dataframe(all_results) # Display the detailed results table
|
| 252 |
else:
|
| 253 |
+
st.warning("Not enough results to generate a summary.")
|