Spaces:
Build error
Build error
Quyet
commited on
Commit
·
93152cc
1
Parent(s):
f30862a
fix global vs local chatbot
Browse files
README.md
CHANGED
|
@@ -18,6 +18,7 @@ For more information about this product, please visit this notion [page](https:/
|
|
| 18 |
|
| 19 |
### 2022/12/20
|
| 20 |
|
|
|
|
| 21 |
- Chat flow will trigger euc 200 when detect a negative emotion with prob > threshold. Thus, only euc 100 and free chat consist of chat loop, while euc 200 will pop up sometimes. I set the trigger to NOT be regularly (currently one trigger once during the conversation), because trigger to much will bother users
|
| 22 |
- Already fix the problem with dialog model. Now it's configured as the same as what it should be. Of course, that does not guarantee of good response
|
| 23 |
- TODO is written in the main file already
|
|
|
|
| 18 |
|
| 19 |
### 2022/12/20
|
| 20 |
|
| 21 |
+
- DONE turning the chatbot to session varible so that different sessions will show different conversation
|
| 22 |
- Chat flow will trigger euc 200 when detect a negative emotion with prob > threshold. Thus, only euc 100 and free chat consist of chat loop, while euc 200 will pop up sometimes. I set the trigger to NOT be regularly (currently one trigger once during the conversation), because trigger to much will bother users
|
| 23 |
- Already fix the problem with dialog model. Now it's configured as the same as what it should be. Of course, that does not guarantee of good response
|
| 24 |
- TODO is written in the main file already
|
app.py
CHANGED
|
@@ -8,9 +8,10 @@ reference:
|
|
| 8 |
|
| 9 |
gradio vs streamlit
|
| 10 |
https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
|
| 11 |
-
https://gradio.app/interface_state/
|
| 12 |
|
| 13 |
TODO
|
|
|
|
| 14 |
Add diagram in Gradio Interface showing sentimate analysis
|
| 15 |
Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
|
| 16 |
Personalize: create database, load and save data
|
|
@@ -40,8 +41,21 @@ def option():
|
|
| 40 |
args = parser.parse_args()
|
| 41 |
return args
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
class ChatHelper: # store the list of messages that are showed in therapies
|
| 45 |
invalid_input = 'Invalid input, my friend :) Plz input again'
|
| 46 |
good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
|
| 47 |
good_case = 'Nice to hear that!'
|
|
@@ -130,28 +144,19 @@ class TherapyChatBot:
|
|
| 130 |
self.euc_100_emotion_degree = []
|
| 131 |
self.already_trigger_euc_200 = False
|
| 132 |
|
| 133 |
-
# chat and emotion-detection models
|
| 134 |
-
self.ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
|
| 135 |
-
self.ed_threshold = 0.3
|
| 136 |
-
self.dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
|
| 137 |
-
self.dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
|
| 138 |
-
self.eos = self.dialog_tokenizer.eos_token
|
| 139 |
-
# tokenizer.__call__ -> input_ids, attention_mask
|
| 140 |
-
# tokenizer.encode -> only inputs_ids, which is required by model.generate function
|
| 141 |
-
|
| 142 |
# chat history.
|
| 143 |
# TODO: if we want to personalize and save the conversation,
|
| 144 |
# we can load data from database
|
| 145 |
-
self.greeting = ChatHelper.greeting_template[self.chat_state]
|
| 146 |
-
self.history = {'input_ids': torch.tensor([[
|
| 147 |
-
'text':
|
| 148 |
if 'euc_100' in self.chat_state:
|
| 149 |
self.chat_state = 'euc_100.q.0'
|
| 150 |
|
| 151 |
def __call__(self, message, prefix=''):
|
| 152 |
# if prefix != None, which means this function is called from euc_200, thus already detected the negative emotion
|
| 153 |
if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
|
| 154 |
-
prediction =
|
| 155 |
prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
|
| 156 |
if self.run_on_own_server:
|
| 157 |
print(prediction)
|
|
@@ -160,7 +165,7 @@ class TherapyChatBot:
|
|
| 160 |
|
| 161 |
# if message is negative, change state immediately
|
| 162 |
if ((not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200) and \
|
| 163 |
-
(emotion['label'] in ChatHelper.negative_emotions and emotion['score'] >
|
| 164 |
self.chat_state_prev = self.chat_state
|
| 165 |
self.chat_state = 'euc_200'
|
| 166 |
self.message_prev = message
|
|
@@ -171,7 +176,7 @@ class TherapyChatBot:
|
|
| 171 |
elif self.chat_state.startswith('euc_100'):
|
| 172 |
response = self.euc_100(message)
|
| 173 |
if self.chat_state == 'free_chat':
|
| 174 |
-
last_two_turns_ids =
|
| 175 |
self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)
|
| 176 |
|
| 177 |
elif self.chat_state.startswith('euc_200'):
|
|
@@ -185,7 +190,6 @@ class TherapyChatBot:
|
|
| 185 |
self.history['text'].append((self.message_prev, response))
|
| 186 |
else:
|
| 187 |
self.history['text'].append((message, response))
|
| 188 |
-
return self.history['text']
|
| 189 |
|
| 190 |
def euc_100(self, x):
|
| 191 |
_, subsection, entry = self.chat_state.split('.')
|
|
@@ -251,23 +255,23 @@ class TherapyChatBot:
|
|
| 251 |
message = self.message_prev
|
| 252 |
self.message_prev = x
|
| 253 |
self.chat_state = self.chat_state_prev
|
| 254 |
-
return self.__call__(message, response)
|
| 255 |
|
| 256 |
def free_chat(self, message):
|
| 257 |
-
message_ids =
|
| 258 |
self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
|
| 259 |
input_ids = self.history['input_ids'].clone()
|
| 260 |
|
| 261 |
while True:
|
| 262 |
-
bot_output_ids =
|
| 263 |
do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
|
| 264 |
-
pad_token_id=
|
| 265 |
-
response =
|
| 266 |
skip_special_tokens=True)
|
| 267 |
if response.strip() != '':
|
| 268 |
break
|
| 269 |
-
elif input_ids[0].tolist().count(
|
| 270 |
-
idx = input_ids[0].tolist().index(
|
| 271 |
input_ids = input_ids[:, (idx+1):]
|
| 272 |
else:
|
| 273 |
input_ids = message_ids
|
|
@@ -282,20 +286,22 @@ class TherapyChatBot:
|
|
| 282 |
return response
|
| 283 |
|
| 284 |
|
| 285 |
-
if __name__ == '__main__':
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
| 288 |
|
| 289 |
title = 'PsyPlus Empathetic Chatbot'
|
| 290 |
description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
|
| 291 |
-
|
|
|
|
| 292 |
iface = gr.Interface(
|
| 293 |
-
chat, 'text', chatbot,
|
| 294 |
allow_flagging='never', title=title, description=description,
|
| 295 |
)
|
| 296 |
|
| 297 |
-
# iface.queue(concurrency_count=5)
|
| 298 |
if args.run_on_own_server == 0:
|
| 299 |
iface.launch(debug=True)
|
| 300 |
else:
|
| 301 |
-
iface.launch(debug=True, share=True)
|
|
|
|
| 8 |
|
| 9 |
gradio vs streamlit
|
| 10 |
https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
|
| 11 |
+
https://gradio.app/interface_state/ -> global and local varible affect the separation of sessions
|
| 12 |
|
| 13 |
TODO
|
| 14 |
+
Add command to reset/jump to a function, e.g >reset, >euc_100
|
| 15 |
Add diagram in Gradio Interface showing sentimate analysis
|
| 16 |
Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
|
| 17 |
Personalize: create database, load and save data
|
|
|
|
| 41 |
args = parser.parse_args()
|
| 42 |
return args
|
| 43 |
|
| 44 |
+
args = option()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# store the list of messages that are showed in therapies and models as global variables
|
| 48 |
+
# let all chat-session-wise variables placed in TherapyChatBot
|
| 49 |
+
class ChatHelper:
|
| 50 |
+
# chat and emotion-detection models
|
| 51 |
+
ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
|
| 52 |
+
ed_threshold = 0.3
|
| 53 |
+
dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
|
| 54 |
+
dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
|
| 55 |
+
eos = dialog_tokenizer.eos_token
|
| 56 |
+
# tokenizer.__call__ -> input_ids, attention_mask
|
| 57 |
+
# tokenizer.encode -> only inputs_ids, which is required by model.generate function
|
| 58 |
|
|
|
|
| 59 |
invalid_input = 'Invalid input, my friend :) Plz input again'
|
| 60 |
good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
|
| 61 |
good_case = 'Nice to hear that!'
|
|
|
|
| 144 |
self.euc_100_emotion_degree = []
|
| 145 |
self.already_trigger_euc_200 = False
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
# chat history.
|
| 148 |
# TODO: if we want to personalize and save the conversation,
|
| 149 |
# we can load data from database
|
| 150 |
+
self.greeting = [('', ChatHelper.greeting_template[self.chat_state])]
|
| 151 |
+
self.history = {'input_ids': torch.tensor([[ChatHelper.dialog_tokenizer.bos_token_id]]),
|
| 152 |
+
'text': self.greeting} if not self.account else open(f'database/{hash(self.account)}', 'rb')
|
| 153 |
if 'euc_100' in self.chat_state:
|
| 154 |
self.chat_state = 'euc_100.q.0'
|
| 155 |
|
| 156 |
def __call__(self, message, prefix=''):
|
| 157 |
# if prefix != None, which means this function is called from euc_200, thus already detected the negative emotion
|
| 158 |
if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
|
| 159 |
+
prediction = ChatHelper.ed_pipe(message)[0]
|
| 160 |
prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
|
| 161 |
if self.run_on_own_server:
|
| 162 |
print(prediction)
|
|
|
|
| 165 |
|
| 166 |
# if message is negative, change state immediately
|
| 167 |
if ((not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200) and \
|
| 168 |
+
(emotion['label'] in ChatHelper.negative_emotions and emotion['score'] > ChatHelper.ed_threshold):
|
| 169 |
self.chat_state_prev = self.chat_state
|
| 170 |
self.chat_state = 'euc_200'
|
| 171 |
self.message_prev = message
|
|
|
|
| 176 |
elif self.chat_state.startswith('euc_100'):
|
| 177 |
response = self.euc_100(message)
|
| 178 |
if self.chat_state == 'free_chat':
|
| 179 |
+
last_two_turns_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
|
| 180 |
self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)
|
| 181 |
|
| 182 |
elif self.chat_state.startswith('euc_200'):
|
|
|
|
| 190 |
self.history['text'].append((self.message_prev, response))
|
| 191 |
else:
|
| 192 |
self.history['text'].append((message, response))
|
|
|
|
| 193 |
|
| 194 |
def euc_100(self, x):
|
| 195 |
_, subsection, entry = self.chat_state.split('.')
|
|
|
|
| 255 |
message = self.message_prev
|
| 256 |
self.message_prev = x
|
| 257 |
self.chat_state = self.chat_state_prev
|
| 258 |
+
return self.__call__(message, prefix=response)
|
| 259 |
|
| 260 |
def free_chat(self, message):
|
| 261 |
+
message_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
|
| 262 |
self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
|
| 263 |
input_ids = self.history['input_ids'].clone()
|
| 264 |
|
| 265 |
while True:
|
| 266 |
+
bot_output_ids = ChatHelper.dialog_model.generate(input_ids, max_length=1000,
|
| 267 |
do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
|
| 268 |
+
pad_token_id=ChatHelper.dialog_tokenizer.eos_token_id)
|
| 269 |
+
response = ChatHelper.dialog_tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:],
|
| 270 |
skip_special_tokens=True)
|
| 271 |
if response.strip() != '':
|
| 272 |
break
|
| 273 |
+
elif input_ids[0].tolist().count(ChatHelper.dialog_tokenizer.eos_token_id) > 0:
|
| 274 |
+
idx = input_ids[0].tolist().index(ChatHelper.dialog_tokenizer.eos_token_id)
|
| 275 |
input_ids = input_ids[:, (idx+1):]
|
| 276 |
else:
|
| 277 |
input_ids = message_ids
|
|
|
|
| 286 |
return response
|
| 287 |
|
| 288 |
|
| 289 |
+
if __name__ == '__main__':
|
| 290 |
+
def chat(message, bot):
|
| 291 |
+
bot = bot or TherapyChatBot(args)
|
| 292 |
+
bot(message)
|
| 293 |
+
return bot.history['text'], bot
|
| 294 |
|
| 295 |
title = 'PsyPlus Empathetic Chatbot'
|
| 296 |
description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
|
| 297 |
+
greeting = [('', ChatHelper.greeting_template[args.initial_chat_state])]
|
| 298 |
+
chatbot = gr.Chatbot(value=greeting)
|
| 299 |
iface = gr.Interface(
|
| 300 |
+
chat, ['text', 'state'], [chatbot, 'state'],
|
| 301 |
allow_flagging='never', title=title, description=description,
|
| 302 |
)
|
| 303 |
|
|
|
|
| 304 |
if args.run_on_own_server == 0:
|
| 305 |
iface.launch(debug=True)
|
| 306 |
else:
|
| 307 |
+
iface.launch(debug=True, share=True)
|