MADtoBAD commited on
Commit
de90e17
·
verified ·
1 Parent(s): d29a348

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -70
app.py CHANGED
@@ -1,93 +1,105 @@
1
  import gradio as gr
2
- from smolagents import CodeAgent, DuckDuckGoSearchTool
3
- from smolagents.models import TransformersModel
4
 
5
- class SimpleAIAgent:
6
  def __init__(self):
7
- print("Initializing AI Agent...")
8
- self.model = TransformersModel("microsoft/DialoGPT-small")
9
-
10
- self.search_tool = DuckDuckGoSearchTool()
11
-
12
- self.agent = CodeAgent(
13
- tools=[self.search_tool],
14
- model=self.model,
15
- max_steps=4
16
- )
17
-
18
- print("AI Agent ready!")
19
-
20
- def chat(self, message, history):
21
- """
22
- Основная функция для общения с агентом
23
- """
24
- print(f"User asked: {message}")
25
-
26
- prompt = f"""
27
- The user asked: {message}
28
 
29
- Please provide a helpful and accurate answer.
30
- If you need current information, use the search tool to find it online.
31
- Keep your response clear and conversational.
32
- """
33
 
 
 
 
 
 
 
 
34
  try:
35
- response = self.agent.run(prompt)
36
-
37
- clean_response = self.clean_answer(response)
38
- print(f"Agent replied: {clean_response[:100]}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- return clean_response
 
 
 
 
 
 
41
 
42
  except Exception as e:
43
- error_msg = f"Sorry, I encountered an error: {str(e)}"
44
- print(f"Error: {e}")
45
- return error_msg
46
-
47
- def clean_answer(self, answer):
48
- """
49
- Убираем техническую информацию из ответа агента
50
- """
51
- lines = answer.split('\n')
52
- clean_lines = []
53
 
54
- for line in lines:
55
- lower_line = line.lower()
56
- if any(word in lower_line for word in ['tool:', 'searching', 'step', 'using tool']):
57
- continue
58
-
59
- if line.strip():
60
- clean_lines.append(line)
61
-
62
- result = '\n'.join(clean_lines).strip()
63
-
64
- if len(result) > 1500:
65
- result = result[:1497] + "..."
66
 
67
- return result if result else "I couldn't find a good answer to that question."
68
-
69
- ai_agent = SimpleAIAgent()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- with gr.Blocks(title="My AI Assistant") as chat_interface:
72
- gr.Markdown("# My AI Assistant")
73
- gr.Markdown("Ask me anything! I can search the internet for current information.")
74
 
75
- chatbot = gr.Chatbot(height=400)
 
 
 
 
 
76
  msg = gr.Textbox(
77
- label="Your question",
78
- placeholder="Ask me anything...",
79
  lines=2
80
  )
81
- clear_btn = gr.Button("Clear Chat")
82
 
83
  def respond(message, chat_history):
84
- bot_response = ai_agent.chat(message, chat_history)
85
- chat_history.append((message, bot_response))
86
  return "", chat_history
87
 
88
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
89
- clear_btn.click(lambda: None, None, chatbot, queue=False)
90
 
91
  if __name__ == "__main__":
92
- print("Starting AI Chat Assistant...")
93
- chat_interface.launch(share=True)
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
 
5
+ class FixedAIAgent:
6
  def __init__(self):
7
+ print("Loading DialoGPT-small...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Загружаем модель и токенизатор
10
+ self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
11
+ self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
 
12
 
13
+ # Добавляем pad token если его нет
14
+ if self.tokenizer.pad_token is None:
15
+ self.tokenizer.pad_token = self.tokenizer.eos_token
16
+
17
+ print("DialoGPT-small loaded successfully!")
18
+
19
+ def chat(self, message, history):
20
  try:
21
+ # Форматируем историю для DialoGPT
22
+ input_text = self.format_conversation(message, history)
23
+
24
+ # Токенизируем входной текст
25
+ inputs = self.tokenizer.encode(input_text + self.tokenizer.eos_token, return_tensors='pt')
26
+
27
+ # Генерируем ответ
28
+ with torch.no_grad():
29
+ outputs = self.model.generate(
30
+ inputs,
31
+ max_length=1000,
32
+ pad_token_id=self.tokenizer.eos_token_id,
33
+ do_sample=True,
34
+ temperature=0.7,
35
+ top_k=50,
36
+ top_p=0.95,
37
+ repetition_penalty=1.2
38
+ )
39
 
40
+ # Декодируем ответ
41
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
42
+
43
+ # Извлекаем только новый ответ (убираем историю)
44
+ bot_response = self.extract_new_response(input_text, response)
45
+
46
+ return bot_response
47
 
48
  except Exception as e:
49
+ return f"Error: {str(e)}"
50
+
51
+ def format_conversation(self, message, history):
52
+ """Форматирует историю чата для DialoGPT"""
53
+ # Начинаем с нового сообщения
54
+ conversation = f"User: {message}"
 
 
 
 
55
 
56
+ # Добавляем историю (последние 2-3 сообщения)
57
+ if history:
58
+ # Берем последние 2 обмена
59
+ recent_history = history[-2:] if len(history) > 2 else history
 
 
 
 
 
 
 
 
60
 
61
+ for user_msg, bot_msg in recent_history:
62
+ conversation = f"User: {user_msg}\nBot: {bot_msg}\n" + conversation
63
+
64
+ return conversation
65
+
66
+ def extract_new_response(self, input_text, full_response):
67
+ """Извлекает только новый ответ из полного ответа модели"""
68
+ if input_text in full_response:
69
+ # Убираем входной текст чтобы оставить только новый ответ
70
+ new_response = full_response[len(input_text):].strip()
71
+ # Убираем возможные префиксы
72
+ if new_response.startswith("Bot:"):
73
+ new_response = new_response[4:].strip()
74
+ return new_response
75
+ else:
76
+ # Если не нашли входной текст, возвращаем как есть
77
+ return full_response
78
 
79
+ # Создаем агента
80
+ agent = FixedAIAgent()
 
81
 
82
+ # Создаем интерфейс
83
+ with gr.Blocks() as app:
84
+ gr.Markdown("# AI Chat Assistant")
85
+ gr.Markdown("Powered by DialoGPT-small")
86
+
87
+ chatbot = gr.Chatbot(height=400, label="Chat History")
88
  msg = gr.Textbox(
89
+ label="Your message",
90
+ placeholder="Type your message here...",
91
  lines=2
92
  )
93
+ clear = gr.Button("Clear Chat")
94
 
95
  def respond(message, chat_history):
96
+ bot_message = agent.chat(message, chat_history)
97
+ chat_history.append((message, bot_message))
98
  return "", chat_history
99
 
100
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
101
+ clear.click(lambda: None, None, chatbot)
102
 
103
  if __name__ == "__main__":
104
+ print("Starting AI Chat...")
105
+ app.lalunch()