Spaces:
Running
Running
Niki Zhang
commited on
Update chatbox.py
Browse files- chatbox.py +10 -4
chatbox.py
CHANGED
|
@@ -12,7 +12,7 @@ import inspect
|
|
| 12 |
|
| 13 |
from langchain.agents.initialize import initialize_agent
|
| 14 |
from langchain.agents.tools import Tool
|
| 15 |
-
from langchain.memory import ConversationBufferMemory
|
| 16 |
from langchain_community.chat_models import ChatOpenAI
|
| 17 |
import torch
|
| 18 |
from PIL import Image, ImageDraw, ImageOps
|
|
@@ -141,7 +141,7 @@ class ConversationBot:
|
|
| 141 |
def __init__(self, tools, api_key=""):
|
| 142 |
# load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
|
| 143 |
print("chatbot api",api_key)
|
| 144 |
-
llm = ChatOpenAI(model_name="gpt-4o", temperature=0.7, openai_api_key=api_key
|
| 145 |
self.llm = llm
|
| 146 |
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
|
| 147 |
self.tools = tools
|
|
@@ -172,11 +172,17 @@ class ConversationBot:
|
|
| 172 |
return ans
|
| 173 |
|
| 174 |
def run_text(self, text, state, aux_state):
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
if self.point_prompt != "":
|
| 177 |
Human_prompt = f'\nHuman: {self.point_prompt}\n'
|
| 178 |
AI_prompt = 'Ok'
|
| 179 |
-
self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
|
|
|
|
| 180 |
self.point_prompt = ""
|
| 181 |
res = self.agent({"input": text})
|
| 182 |
res['output'] = res['output'].replace("\\", "/")
|
|
|
|
| 12 |
|
| 13 |
from langchain.agents.initialize import initialize_agent
|
| 14 |
from langchain.agents.tools import Tool
|
| 15 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
| 16 |
from langchain_community.chat_models import ChatOpenAI
|
| 17 |
import torch
|
| 18 |
from PIL import Image, ImageDraw, ImageOps
|
|
|
|
| 141 |
def __init__(self, tools, api_key=""):
|
| 142 |
# load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
|
| 143 |
print("chatbot api",api_key)
|
| 144 |
+
llm = ChatOpenAI(model_name="gpt-4o", temperature=0.7, openai_api_key=api_key)
|
| 145 |
self.llm = llm
|
| 146 |
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
|
| 147 |
self.tools = tools
|
|
|
|
| 172 |
return ans
|
| 173 |
|
| 174 |
def run_text(self, text, state, aux_state):
|
| 175 |
+
memory_str = self.agent.memory.buffer_as_str
|
| 176 |
+
trimmed_memory_str = cut_dialogue_history(memory_str, keep_last_n_words=500)
|
| 177 |
+
trimmed_messages = self.memory.buffer_as_messages[:len(trimmed_memory_str.split())]
|
| 178 |
+
# self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
|
| 179 |
+
self.memory.chat_memory.messages = trimmed_messages
|
| 180 |
+
print("done")
|
| 181 |
if self.point_prompt != "":
|
| 182 |
Human_prompt = f'\nHuman: {self.point_prompt}\n'
|
| 183 |
AI_prompt = 'Ok'
|
| 184 |
+
# self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
|
| 185 |
+
self.agent.memory.save_context({'input': Human_prompt}, {'output': AI_prompt})
|
| 186 |
self.point_prompt = ""
|
| 187 |
res = self.agent({"input": text})
|
| 188 |
res['output'] = res['output'].replace("\\", "/")
|