demo / backend /are.py
Pierre Andrews
use custom token for inference
6ac5bb0
raw
history blame
5.57 kB
import datetime
import logging
import os
import random
import shlex
import subprocess
import uuid
from urllib.parse import urlencode

import psutil

from backend.globals import FREE_PORTS_POOL, ORG, SPACE, STORAGE_PATH
from backend.session import UserSession
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)
def start_are_process_and_session_lite(
    model: str,
    provider: str,
    username: str,
    bearer_token: str | None,
    user_token: str | None,
    app_path: str | None,
) -> UserSession:
    """Spawn an ARE GUI server on a free port and return a session describing it.

    The server is launched as ``python -u -m are.simulation.gui.cli ...``
    inside a bash pipeline that also tees its combined stdout/stderr into
    ``{STORAGE_PATH}/log_{port}.log``. The function returns as soon as the
    subprocess is spawned; it does not wait for the server to be ready.

    Args:
        model: Model identifier forwarded to the CLI through ``-m``.
        provider: Inference provider forwarded through ``--provider``.
        username: User the session belongs to; stored on the session and
            used in the error message.
        bearer_token: Stored on the session as ``sign``.
        user_token: Hugging Face token of the user (required). Exported to
            the child as ``HF_TOKEN`` / ``HF_INFERENCE_TOKEN`` unless this
            process defines ``HF_DATASET_TOKEN`` / ``HF_INFERENCE_TOKEN``,
            which then take precedence.
        app_path: Optional path exported as ``MCP_APPS_JSON_PATH``.

    Returns:
        A ``UserSession`` holding the chosen port, the PID of the spawned
        bash shell, a fresh UUID4 session id, the log path and start time.

    Raises:
        ValueError: If ``user_token`` is missing, or no free port is left.
    """
    if not user_token:
        error_msg = (
            f"HF_TOKEN (user_token) is None for user {username}. "
            "Cannot start ARE session without Hugging Face token."
        )
        raise ValueError(error_msg)

    # Fail with an explicit message instead of random's generic
    # "Sample larger than population" error when the pool is exhausted.
    if not FREE_PORTS_POOL:
        raise ValueError(
            f"No free port available to start an ARE session for user {username}."
        )
    port = random.choice(FREE_PORTS_POOL)
    log_path = f"{STORAGE_PATH}/log_{port}.log"

    # Start from a copy of our own environment. Variables such as HF_BILL_TO
    # and LLAMA_API_KEY therefore reach the child without explicit handling.
    env_vars = dict(os.environ)
    env_vars["ARE_SERVER_HOSTNAME"] = "0.0.0.0"
    env_vars["ARE_SIMULATION_SERVER_HOSTNAME"] = "0.0.0.0"
    env_vars["ARE_SERVER_PORT"] = str(port)
    env_vars["ARE_SIMULATION_SERVER_PORT"] = str(port)
    # Server-side tokens, when configured, win over the user's own token.
    env_vars["HF_TOKEN"] = os.environ.get("HF_DATASET_TOKEN", user_token)
    env_vars["HF_INFERENCE_TOKEN"] = os.environ.get("HF_INFERENCE_TOKEN", user_token)
    env_vars["HF_DEMO_UNIVERSE"] = "universe_hf_0"
    env_vars["INTERACTIVE_SCENARIOS_TREE"] = "/app/mcp_demo_prompts.json"
    if app_path:
        env_vars["MCP_APPS_JSON_PATH"] = app_path

    # NOTE: earlier revisions mentioned "scenario_universe_hf_0",
    # "scenario_hf_0" and "universe_hf_0" as alternative scenario ids.
    are_args = [
        "python", "-u", "-m", "are.simulation.gui.cli",
        "-a", "default",
        "-m", model,
        "--provider", provider,
        "-s", "scenario_hf_demo_mcp",
        "--ui_view", "playground",
    ]
    # The pipeline needs a shell, so every argument is quoted: model and
    # provider are caller-supplied and must not be interpreted by bash.
    command = f"{shlex.join(are_args)} 2>&1 | tee {shlex.quote(log_path)}"
    process = subprocess.Popen(
        command,
        env=env_vars,
        shell=True,
        executable="/bin/bash",
    )

    # Reserve the port by mutating the shared list in place. Rebinding the
    # imported name would only change this module's reference and leave the
    # list object owned by backend.globals out of sync.
    FREE_PORTS_POOL[:] = [free for free in FREE_PORTS_POOL if free != port]

    return UserSession(
        port=int(port),
        pid=process.pid,  # PID of the bash shell that owns the pipeline
        sid=str(uuid.uuid4()),
        model=model,
        provider=provider,
        log_path=log_path,
        start_time=str(datetime.datetime.now()),
        user=username,
        sign=bearer_token,
    )
def _release_port(port: int) -> None:
    """Put *port* back into the free pool unless it is already there."""
    if port not in FREE_PORTS_POOL:
        FREE_PORTS_POOL.append(port)


def kill_are_process(session: UserSession) -> None:
    """Kill the ARE process of *session* together with all its descendants.

    Children are killed (and waited for) before the main process so that
    nothing is left behind. The session's port is returned to the free pool
    once the main process has been sent SIGKILL, or when the main process no
    longer exists. The function is best-effort: failures are logged, not
    raised.

    Args:
        session: Session whose ``pid`` identifies the main process and whose
            ``port`` is released afterwards.
    """
    try:
        main_process = psutil.Process(session.pid)
        children = main_process.children(recursive=True)

        # Kill every descendant first. One child that cannot be killed must
        # not stop the cleanup of the others.
        for child in children:
            try:
                child.kill()
                logger.info(f"Killed child process PID {child.pid}")
            except psutil.NoSuchProcess:
                logger.info(f"Child process PID {child.pid} already terminated")
            except psutil.AccessDenied:
                # psutil errors do not derive from OSError, so this case
                # needs its own handler.
                logger.warning(
                    f"Not permitted to kill child process PID {child.pid}"
                )
            except OSError:
                logger.warning(
                    f"Could not kill child process PID {child.pid}",
                    exc_info=True,
                )

        # Give each child up to 5 seconds to disappear.
        for child in children:
            try:
                child.wait(timeout=5)
            except psutil.TimeoutExpired:
                logger.warning(
                    f"Child process PID {child.pid} did not terminate within timeout"
                )
            except psutil.NoSuchProcess:
                pass

        main_process.kill()
        logger.info(f"Sent SIGKILL to main PID {session.pid}")

        try:
            main_process.wait(timeout=5)
        except psutil.TimeoutExpired:
            logger.warning(
                f"Main process PID {session.pid} did not terminate within timeout"
            )
        except psutil.NoSuchProcess:
            pass

        _release_port(session.port)
        logger.info(
            f"Killed session {session.sid} PID {session.pid} on port {session.port} for user {session.user}"
        )
    except psutil.NoSuchProcess:
        logger.info(f"Process PID {session.pid} not found - may already be terminated")
        _release_port(session.port)
    except (psutil.AccessDenied, OSError):
        # The process may still be alive and bound to the port, so the port
        # is deliberately not returned to the pool here.
        logger.error(
            f"COULD NOT KILL ARE on port {session.port} for user {session.user}",
            exc_info=True,
        )
def get_are_url(session: UserSession, server: str) -> str:
    """Build the URL under which the ARE server of *session* is reachable.

    When the ``FLASK_ENV`` environment variable equals ``"development"`` the
    URL points at ``localhost`` on the session's port. Otherwise (the default
    is ``"production"``) it points at the Hugging Face Space host derived
    from ``ORG``, ``SPACE`` and the port.

    Args:
        session: Session providing ``port``, ``sid`` (sent as the ``sid``
            query parameter) and ``sign`` (sent as ``__sign``).
        server: Path segment of the URL; expected to be ``"are"`` or
            ``"graphql"`` (not validated here).

    Returns:
        The full URL, including the URL-encoded query string.
    """
    # Encode the query values so that characters such as "+", "&", "=" or
    # spaces in the signature cannot corrupt the query string.
    query = urlencode({"sid": session.sid, "__sign": session.sign})

    if os.environ.get("FLASK_ENV", "production") == "development":
        # Development: talk to the ARE server directly on its local port.
        base_url = f"http://localhost:{session.port}"
    else:
        # Production: use the Hugging Face Space host name for this port.
        base_url = f"https://{ORG.lower()}-{SPACE.lower()}--{session.port}.hf.space"

    return f"{base_url}/{server}?{query}"