File size: 4,656 Bytes
570b60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d1b8d4
570b60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d1b8d4
570b60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b68ef9
570b60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b68ef9
570b60c
 
 
 
 
 
 
5b68ef9
 
570b60c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from dotenv import load_dotenv


import gradio as gr
from gradio import ChatMessage

import json
from openai import OpenAI
from datetime import datetime
import os
import re

from termcolor import cprint
import logging
logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s')
log = logging.getLogger(__name__)



from omegaconf import OmegaConf
from src.tools import tools, oitools


# Load the environment and the configuration file, then build the API client
# ===========================================================================
# Environment variables
load_dotenv(".env", override=True)
HF_TOKEN = os.environ.get("HF_TOKEN")
LLM_BASE_URL = os.environ.get("LLM_BASE_URL")

# os.environ.get returns None for an unset variable; slicing None would raise
# a TypeError here, before any clearer error about the missing setting could
# surface. Only a short excerpt of each value is logged.
if HF_TOKEN:
    log.info(f"Using HF_TOKEN: {HF_TOKEN[:4]}...{HF_TOKEN[-4:]}")
else:
    log.warning("HF_TOKEN is not set")
if LLM_BASE_URL:
    log.info(f"Using LLM_BASE_URL: {LLM_BASE_URL[:15]}...")
else:
    log.warning("LLM_BASE_URL is not set")

# Configuration file
config_file = "config.yaml"
cfg = OmegaConf.load(config_file)

# OpenAI API parameters
chat_params = cfg.openai.chat_params
client = OpenAI(
    base_url=f"{LLM_BASE_URL}",
    api_key=HF_TOKEN
)
# Use the module logger, like the other messages in this file.
log.info(f"Client initialized: {client}")
# ===========================================================================


def today_date():
    """Return the current local date and time as display text.

    The layout is weekday name, month name, day, year and 12-hour clock
    time with an AM/PM marker.
    """
    layout = '%A, %B %d, %Y, %I:%M %p'
    current = datetime.today()
    return current.strftime(layout)


def clean_json_string(json_str):
    """Repair the tail of a streamed JSON object string.

    Trailing spaces, commas, closing braces and other whitespace are
    stripped, then enough closing braces are appended to balance every
    ``{`` that is still open. Braces inside JSON string literals are
    ignored when counting.

    Args:
        json_str: Raw (possibly truncated or over-terminated) JSON text.

    Returns:
        The trimmed text followed by at least one ``}``.
    """
    trimmed = re.sub(r'[ ,}\s]+$', '', json_str)

    # Count object nesting that remains open after trimming. Appending a
    # single brace unconditionally would break nested objects, whose closing
    # braces were all removed by the trim above.
    depth = 0
    in_string = False
    escaped = False
    for ch in trimmed:
        if in_string:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1

    # Always append at least one brace, as before, so inputs with at most
    # one open object produce the same output they always did.
    return trimmed + '}' * max(depth, 1)


def completion(history, model, system_prompt: str, tools=None, chat_params=chat_params):
    """Build the chat request from ``history`` and send it to the client.

    The system prompt is formatted with the current date and placed first.
    Each history entry (a dict is converted to ``ChatMessage``) is then
    translated: an assistant message with non-empty metadata becomes an
    assistant tool-call message followed by a tool message holding its
    content; every other entry is forwarded with its role and content.

    Args:
        history: Iterable of ``ChatMessage`` objects or equivalent dicts.
        model: Model identifier passed through to the API.
        system_prompt: Template containing a ``{date}`` placeholder.
        tools: Optional tool definitions; when truthy, sent with
            ``tool_choice="auto"``.
        chat_params: Extra request parameters merged into the call.

    Returns:
        Whatever ``client.chat.completions.create`` returns.
    """
    system_turn = {"role": "system", "content": system_prompt.format(date=today_date())}
    messages = [system_turn]

    for entry in history:
        turn = ChatMessage(**entry) if isinstance(entry, dict) else entry
        call_info = getattr(turn, "metadata", None) if turn.role == "assistant" else None

        if not call_info:
            messages.append({"role": turn.role, "content": turn.content})
            continue

        # The metadata title stores the JSON-encoded list of tool calls.
        tools_calls = json.loads(call_info.get("title", "[]"))
        messages.append({"role": "assistant", "tool_calls": tools_calls, "content": ""})
        messages.append({"role": "tool", "content": turn.content})

    request_params = {
        "model": model,
        "messages": messages,
        **chat_params
    }
    if tools:
        request_params["tool_choice"] = "auto"
        request_params["tools"] = tools

    return client.chat.completions.create(**request_params)


def llm_in_loop(history, system_prompt, recursive):
    """Stream one assistant turn, run a requested tool, and recurse.

    The first model listed by the client is used. The completion is iterated
    chunk by chunk (assumes ``chat_params`` enables streaming - TODO confirm):
    text deltas are accumulated into a new assistant message appended to
    ``history``; tool-call deltas accumulate the tool name and its argument
    string. When a tool name was received, the tool is invoked, its result
    (or an error text) is appended as an assistant message whose
    ``metadata["title"]`` holds the JSON-encoded call, and the function calls
    itself so the model can continue from the tool result.

    Args:
        history: Mutable list of chat messages; extended in place.
        system_prompt: Prompt template forwarded to ``completion``.
        recursive: Negative slice start; ``history[recursive:]`` is the part
            of the history that is yielded. It is decreased by one for every
            message this turn appends.

    Yields:
        ``history[recursive:]`` after each text chunk and after a tool result.

    Raises:
        Exception: Re-raised from the model listing after a Gradio warning
            has been issued.
    """
    try:
        model = client.models.list().data[0].id
    except Exception:
        gr.Warning("The model is initializing. Please wait; this may take 5 to 10 minutes ⏳.", duration=20)
        raise

    arguments = ""
    name = ""
    chat_completion = completion(history=history, tools=oitools, model=model, system_prompt=system_prompt)
    appended = False

    for chunk in chat_completion:
        # A chunk with an empty ``choices`` list carries no delta; indexing
        # it would raise IndexError, so skip it.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta

        if delta.tool_calls:
            # Only the first tool call of a chunk is considered.
            call = delta.tool_calls[0]
            if getattr(call.function, "name", None):
                name = call.function.name
            if getattr(call.function, "arguments", None):
                arguments += call.function.arguments

        elif delta.content:
            if not appended:
                history.append(ChatMessage(role="assistant", content=""))
                appended = True
            history[-1].content += delta.content
            yield history[recursive:]

    # Convert arguments to a valid JSON
    arguments = clean_json_string(arguments) if arguments else "{}"
    arguments = json.loads(arguments)

    # One more message now belongs to this response, so widen the slice.
    if appended:
        recursive -= 1
    if name:
        try:
            result = str(tools[name].invoke(input=arguments))
        except Exception as err:
            # Best effort: surface the failure to the model and the user
            # instead of aborting the conversation.
            result = f"💥 Error: {err}"

        history.append(ChatMessage(
            role="assistant",
            content=result,
            metadata={"title": json.dumps([{"id": "call_id", "function": {"arguments": json.dumps(arguments, ensure_ascii=False), "name": name}, "type": "function"}], ensure_ascii=False)}))

        yield history[recursive:]
        yield from llm_in_loop(history, system_prompt, recursive - 1)


def respond(message, history, additional_inputs):
    """Chat handler: record the user message and stream the model's reply.

    Args:
        message: Text of the new user message.
        history: Mutable list of prior chat messages; extended in place.
        additional_inputs: Passed to ``llm_in_loop`` as the system prompt.

    Yields:
        Every value produced by ``llm_in_loop`` started with a slice
        start of ``-1``.
    """
    user_turn = ChatMessage(role="user", content=message)
    history.append(user_turn)
    yield from llm_in_loop(history, additional_inputs, -1)



if __name__ == "__main__":

    # Hidden textbox holding the system prompt template from the config; it
    # is handed to `respond` as its additional input.
    system_prompt = gr.Textbox(
        label="System prompt",
        value=cfg.system_prompt_template,
        lines=10,
        visible=False,
    )
    demo = gr.ChatInterface(
        respond,
        type="messages",
        additional_inputs=[system_prompt],
    )
    demo.launch()