File size: 5,568 Bytes
f52d137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ac5bb0
f52d137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import datetime
import logging
import os
import random
import shlex
import subprocess
import uuid
from urllib.parse import urlencode

import psutil

from backend.globals import FREE_PORTS_POOL, ORG, SPACE, STORAGE_PATH
from backend.session import UserSession

# Logger for this module, keyed by its dotted import path.
logger = logging.getLogger(name=__name__)


def start_are_process_and_session_lite(
    model: str,
    provider: str,
    username: str,
    bearer_token: str | None,
    user_token: str | None,
    app_path: str | None,
) -> UserSession:
    """Spawn an ARE GUI server on a free port and describe it as a session.

    A port is drawn at random from ``FREE_PORTS_POOL``. The ARE CLI
    (``python -m are.simulation.gui.cli``) is started through ``/bin/bash``,
    with its combined stdout/stderr piped through ``tee`` into
    ``{STORAGE_PATH}/log_{port}.log``. The chosen port is then removed from
    the pool.

    Args:
        model: Model name forwarded to the CLI (``-m``).
        provider: Provider name forwarded to the CLI (``--provider``).
        username: Owner recorded on the returned session.
        bearer_token: Stored as the session's ``sign``.
        user_token: The user's Hugging Face token. Required. It is used for
            ``HF_TOKEN`` / ``HF_INFERENCE_TOKEN`` unless the
            ``HF_DATASET_TOKEN`` / ``HF_INFERENCE_TOKEN`` environment
            variables provide them.
        app_path: When given, exported to the child as ``MCP_APPS_JSON_PATH``.

    Returns:
        A ``UserSession`` for the spawned process. Its ``pid`` is the PID of
        the bash process running the pipeline.

    Raises:
        ValueError: If ``user_token`` is missing, or no free port is left.
    """
    if not user_token:
        error_msg = (
            f"HF_TOKEN (user_token) is None for user {username}. "
            "Cannot start ARE session without Hugging Face token."
        )
        raise ValueError(error_msg)

    if not FREE_PORTS_POOL:
        raise ValueError(
            f"No free port available to start an ARE session for user {username}."
        )
    port = random.choice(FREE_PORTS_POOL)

    log_path = f"{STORAGE_PATH}/log_{port}.log"

    # Start from this server's environment. That copy already carries
    # HF_BILL_TO / LLAMA_API_KEY when they are set. Then override the
    # ARE-specific settings.
    env_vars = dict(os.environ)
    env_vars["ARE_SERVER_HOSTNAME"] = "0.0.0.0"
    env_vars["ARE_SIMULATION_SERVER_HOSTNAME"] = "0.0.0.0"
    env_vars["ARE_SERVER_PORT"] = str(port)
    env_vars["ARE_SIMULATION_SERVER_PORT"] = str(port)
    env_vars["HF_TOKEN"] = os.environ.get("HF_DATASET_TOKEN", user_token)
    env_vars["HF_INFERENCE_TOKEN"] = os.environ.get("HF_INFERENCE_TOKEN", user_token)
    env_vars["HF_DEMO_UNIVERSE"] = "universe_hf_0"
    env_vars["INTERACTIVE_SCENARIOS_TREE"] = "/app/mcp_demo_prompts.json"
    if app_path:
        env_vars["MCP_APPS_JSON_PATH"] = app_path

    cli_args = [
        "python", "-u", "-m", "are.simulation.gui.cli",
        "-a", "default",
        "-m", model,
        "--provider", provider,
        "-s", "scenario_hf_demo_mcp",
        "--ui_view", "playground",
    ]
    # bash is needed for the `2>&1 | tee` pipeline. Every argument is
    # shell-quoted so a model/provider name (or the log path) cannot
    # inject additional shell syntax.
    command = " ".join(shlex.quote(arg) for arg in cli_args)
    command += f" 2>&1 | tee {shlex.quote(log_path)}"

    process = subprocess.Popen(
        command,
        env=env_vars,
        shell=True,
        executable="/bin/bash",
    )

    # Update the shared pool in place rather than rebinding the module name.
    # Every holder of the list imported from backend.globals then sees that
    # the port is taken. All occurrences are dropped, guarding against
    # duplicates.
    FREE_PORTS_POOL[:] = [free for free in FREE_PORTS_POOL if free != port]

    user_session = UserSession(
        port=int(port),
        pid=process.pid,
        sid=str(uuid.uuid4()),
        model=model,
        provider=provider,
        log_path=log_path,
        start_time=str(datetime.datetime.now()),
        user=username,
        sign=bearer_token,
    )

    return user_session


def _release_port(port: int) -> None:
    """Put ``port`` back into ``FREE_PORTS_POOL`` unless it is already there."""
    # Guards against a double release (e.g. the same session killed twice).
    # Otherwise two sessions could later draw the same port.
    if port not in FREE_PORTS_POOL:
        FREE_PORTS_POOL.append(port)


def kill_are_process(session: UserSession) -> None:
    """Kill a session's process tree and return its port to the pool.

    The process at ``session.pid`` and all of its descendants are sent
    SIGKILL. Each is then waited on for up to 5 seconds.

    The port goes back into ``FREE_PORTS_POOL`` in two cases: after the
    kill, and when the process no longer exists. If the kill fails for any
    other reason (e.g. permission denied), the error is logged and the port
    is kept out of the pool, since the process may still hold it.

    Args:
        session: The session whose processes should be terminated.
    """
    try:
        main_process = psutil.Process(session.pid)

        # Collect the descendants while the parent is still alive, so they
        # can be found through it.
        children = main_process.children(recursive=True)

        # Kill all child processes first.
        for child in children:
            try:
                child.kill()
                logger.info("Killed child process PID %s", child.pid)
            except psutil.NoSuchProcess:
                logger.info("Child process PID %s already terminated", child.pid)
            except (psutil.AccessDenied, OSError):
                # psutil.AccessDenied is not an OSError subclass, so it must
                # be listed explicitly.
                logger.warning(
                    "Could not kill child process PID %s", child.pid, exc_info=True
                )

        # Wait for all children at once. The total wait is bounded by a
        # single 5 s timeout instead of 5 s per child.
        _, still_alive = psutil.wait_procs(children, timeout=5)
        for child in still_alive:
            logger.warning(
                "Child process PID %s did not terminate within timeout", child.pid
            )

        main_process.kill()
        logger.info("Sent SIGKILL to main PID %s", session.pid)

        try:
            main_process.wait(timeout=5)
        except psutil.TimeoutExpired:
            logger.warning(
                "Main process PID %s did not terminate within timeout", session.pid
            )
        except psutil.NoSuchProcess:
            pass

        _release_port(session.port)
        logger.info(
            "Killed session %s PID %s on port %s for user %s",
            session.sid,
            session.pid,
            session.port,
            session.user,
        )

    except psutil.NoSuchProcess:
        logger.info(
            "Process PID %s not found - may already be terminated", session.pid
        )
        _release_port(session.port)
    except (psutil.AccessDenied, OSError):
        logger.error(
            "COULD NOT KILL ARE on port %s for user %s",
            session.port,
            session.user,
            exc_info=True,
        )


def get_are_url(session: UserSession, server: str) -> str:
    """Build the URL at which a session's ARE server can be reached.

    When the ``FLASK_ENV`` environment variable is ``"development"``, the URL
    targets ``localhost`` on the session's port. Otherwise (the default is
    ``"production"``) it targets the Hugging Face Space proxy host
    ``{org}-{space}--{port}.hf.space``.

    Args:
        session: Session supplying ``port``, ``sid`` and ``sign``.
        server: Path segment appended to the host. Expected to be either
            ``"are"`` or ``"graphql"``.

    Returns:
        The full URL, carrying ``sid`` and ``__sign`` as query parameters.
    """
    # URL-encode the query so a sign/sid containing reserved characters
    # (``&``, ``+``, ``=`` ...) cannot corrupt the query string. None values
    # are rendered as "None", exactly as the plain f-string did.
    query = urlencode({"sid": session.sid, "__sign": session.sign})

    if os.environ.get("FLASK_ENV", "production") == "development":
        # Development: talk to the ARE port directly on this machine.
        base_url = f"http://localhost:{session.port}"
    else:
        # Production: go through the Hugging Face Space per-port proxy.
        base_url = f"https://{ORG.lower()}-{SPACE.lower()}--{session.port}.hf.space"

    return f"{base_url}/{server}?{query}"