File size: 5,568 Bytes
f52d137 6ac5bb0 f52d137 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import datetime
import logging
import os
import random
import shlex
import subprocess
import uuid

import psutil

from backend.globals import FREE_PORTS_POOL, ORG, SPACE, STORAGE_PATH
from backend.session import UserSession
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def start_are_process_and_session_lite(
    model: str,
    provider: str,
    username: str,
    bearer_token: str | None,
    user_token: str | None,
    app_path: str | None,
) -> UserSession:
    """Launch an ARE GUI server in a subprocess and describe it as a session.

    A port is drawn at random from ``FREE_PORTS_POOL``, the
    ``are.simulation.gui.cli`` module is started through bash with its combined
    stdout/stderr piped through ``tee`` into ``{STORAGE_PATH}/log_{port}.log``,
    and the chosen port is then taken out of the pool.

    Args:
        model: Model name forwarded to the CLI via ``-m``.
        provider: Provider name forwarded to the CLI via ``--provider``.
        username: Stored on the returned session as ``user``; also used in
            error messages.
        bearer_token: Stored on the returned session as ``sign``.
        user_token: Hugging Face token. Used for ``HF_TOKEN`` and
            ``HF_INFERENCE_TOKEN`` in the child environment unless
            ``HF_DATASET_TOKEN`` / ``HF_INFERENCE_TOKEN`` are set in this
            process's environment, in which case those win.
        app_path: When truthy, exported to the child as ``MCP_APPS_JSON_PATH``.

    Returns:
        A ``UserSession`` carrying the port, the PID of the spawned shell, a
        fresh UUID4 session id, the model/provider, the log path, the start
        time, the user and the bearer token.

    Raises:
        ValueError: If ``user_token`` is empty/``None`` or if no free port is
            left in the pool.
    """
    global FREE_PORTS_POOL  # rebound below once the port has been taken

    if not user_token:
        error_msg = (
            f"HF_TOKEN (user_token) is None for user {username}. "
            "Cannot start ARE session without Hugging Face token."
        )
        raise ValueError(error_msg)

    # Fail with an explicit message instead of random.sample's generic
    # "Sample larger than population" error when the pool is exhausted.
    if not FREE_PORTS_POOL:
        raise ValueError(
            f"No free port available to start an ARE session for user {username}."
        )

    port = random.sample(FREE_PORTS_POOL, k=1)[0]
    log_path = f"{STORAGE_PATH}/log_{port}.log"

    # The child inherits the current environment plus the ARE-specific settings.
    env_vars = dict(os.environ)
    env_vars["ARE_SERVER_HOSTNAME"] = "0.0.0.0"
    env_vars["ARE_SIMULATION_SERVER_HOSTNAME"] = "0.0.0.0"
    env_vars["ARE_SERVER_PORT"] = str(port)
    env_vars["ARE_SIMULATION_SERVER_PORT"] = str(port)
    # Server-side tokens take precedence; the user's token is the fallback.
    env_vars["HF_TOKEN"] = os.environ.get("HF_DATASET_TOKEN", user_token)
    env_vars["HF_INFERENCE_TOKEN"] = os.environ.get("HF_INFERENCE_TOKEN", user_token)
    env_vars["HF_DEMO_UNIVERSE"] = "universe_hf_0"  # alternative: "universe_hf"

    # Already present via dict(os.environ) when set; the explicit copies are
    # kept so the set of forwarded variables stays visible here.
    bill_to = os.environ.get("HF_BILL_TO")
    if bill_to:
        env_vars["HF_BILL_TO"] = bill_to
    llama_key = os.environ.get("LLAMA_API_KEY")
    if llama_key:
        env_vars["LLAMA_API_KEY"] = llama_key

    env_vars["INTERACTIVE_SCENARIOS_TREE"] = "/app/mcp_demo_prompts.json"
    if app_path:
        env_vars["MCP_APPS_JSON_PATH"] = app_path

    cli_args = [
        "python", "-u", "-m", "are.simulation.gui.cli",
        "-a", "default",
        "-m", model,
        "--provider", provider,
        # Other scenario ids noted by the original author:
        # "scenario_universe_hf_0", "scenario_hf_0", "universe_hf_0".
        "-s", "scenario_hf_demo_mcp",
        "--ui_view", "playground",
    ]
    # The pipeline needs a shell, so every argument is quoted: a model or
    # provider value containing spaces or shell metacharacters must reach the
    # CLI as a single argument and never be interpreted by bash.
    command = f"{shlex.join(cli_args)} 2>&1 | tee {shlex.quote(log_path)}"

    process = subprocess.Popen(
        command,
        env=env_vars,
        shell=True,
        executable="/bin/bash",
    )

    # NOTE(review): this rebinds only this module's FREE_PORTS_POOL name; any
    # other module that imported the list from backend.globals keeps the old
    # object -- confirm that is intended.
    FREE_PORTS_POOL = [
        free_port for free_port in FREE_PORTS_POOL if free_port != port
    ]

    user_session = UserSession(
        port=int(port),
        # PID of the bash wrapper; the python and tee processes are its children.
        pid=process.pid,
        sid=str(uuid.uuid4()),
        model=model,
        provider=provider,
        log_path=log_path,
        start_time=str(datetime.datetime.now()),
        user=username,
        sign=bearer_token,
    )
    return user_session
def kill_are_process(session: UserSession) -> None:
    """Kill the process behind ``session`` together with all its descendants.

    The descendants of ``session.pid`` are killed first, then the main process;
    each is given up to 5 seconds to exit. Afterwards ``session.port`` is
    returned to ``FREE_PORTS_POOL`` (also when the main process no longer
    exists). If the kill fails with a permission or OS error, the failure is
    logged and the port is NOT returned to the pool.

    Args:
        session: Session whose ``pid``, ``port``, ``sid`` and ``user``
            attributes identify the process to kill and are used for logging.
    """

    def release_port() -> None:
        # Guard against duplicate entries when the same session is killed twice.
        if session.port not in FREE_PORTS_POOL:
            FREE_PORTS_POOL.append(session.port)

    try:
        main_process = psutil.Process(session.pid)
        # Snapshot the whole process tree before anything is killed.
        children = main_process.children(recursive=True)

        # Kill all child processes first.
        for child in children:
            try:
                child.kill()
                logger.info(f"Killed child process PID {child.pid}")
            except psutil.NoSuchProcess:
                logger.info(f"Child process PID {child.pid} already terminated")
            except (psutil.AccessDenied, OSError):
                # psutil.AccessDenied is not an OSError; catch both so that one
                # stubborn child does not abort the rest of the cleanup.
                logger.warning(
                    f"Child process PID {child.pid} could not be killed",
                    exc_info=True,
                )

        # Wait for child processes to terminate.
        for child in children:
            try:
                child.wait(timeout=5)
            except psutil.TimeoutExpired:
                logger.warning(
                    f"Child process PID {child.pid} did not terminate within timeout"
                )
            except psutil.NoSuchProcess:
                pass

        # Kill the main process.
        main_process.kill()
        logger.info(f"Sent SIGKILL to main PID {session.pid}")

        # Wait for the main process to terminate.
        try:
            main_process.wait(timeout=5)
        except psutil.TimeoutExpired:
            logger.warning(
                f"Main process PID {session.pid} did not terminate within timeout"
            )
        except psutil.NoSuchProcess:
            pass

        release_port()
        logger.info(
            f"Killed session {session.sid} PID {session.pid} on port {session.port} for user {session.user}"
        )
    except psutil.NoSuchProcess:
        logger.info(f"Process PID {session.pid} not found - may already be terminated")
        release_port()
    except (psutil.AccessDenied, OSError):
        # The port is deliberately left out of the pool here: the process may
        # still be alive and bound to it.
        logger.error(
            f"COULD NOT KILL ARE on port {session.port} for user {session.user}",
            exc_info=True,
        )
def get_are_url(session: UserSession, server: str) -> str:
    """Build the URL under which a running ARE session can be reached.

    Args:
        session: Session providing ``port``, ``sid`` and ``sign``; the latter
            two are sent as the ``sid`` and ``__sign`` query parameters.
        server: Path segment placed right after the host (the original
            docstring says it must be either "are" or "graphql").

    Returns:
        str: A ``http://localhost:<port>/...`` URL when the ``FLASK_ENV``
        environment variable equals "development"; otherwise (including when
        it is unset) the ``https://<org>-<space>--<port>.hf.space/...`` URL.
    """
    # Path and query string are the same in both modes.
    suffix = f"{server}?sid={session.sid}&__sign={session.sign}"

    in_development = os.environ.get("FLASK_ENV", "production") == "development"
    if in_development:
        # Development: talk to the ARE port on the local machine directly.
        return f"http://localhost:{session.port}/{suffix}"

    # Production: host name is derived from the lower-cased ORG and SPACE.
    host = f"{ORG.lower()}-{SPACE.lower()}--{session.port}.hf.space"
    return f"https://{host}/{suffix}"
|