Spaces:
Build error
Build error
Quyet
commited on
Commit
·
f30862a
1
Parent(s):
de337bd
add euc 100 200 to chat loop, fix dialog model
Browse files
README.md
CHANGED
|
@@ -11,3 +11,14 @@ license: gpl-3.0
|
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 14 |
+
|
| 15 |
+
For more information about this product, please visit this notion [page](https://www.notion.so/AI-Consulting-Design-Scheme-0a9c5288820d4fec98ecc7cc1e84be51)) (you need to have permission to access this page)
|
| 16 |
+
|
| 17 |
+
# Notes
|
| 18 |
+
|
| 19 |
+
### 2022/12/20
|
| 20 |
+
|
| 21 |
+
- Chat flow will trigger euc 200 when detect a negative emotion with prob > threshold. Thus, only euc 100 and free chat consist of chat loop, while euc 200 will pop up sometimes. I set the trigger to NOT be regularly (currently one trigger once during the conversation), because trigger to much will bother users
|
| 22 |
+
- Already fix the problem with dialog model. Now it's configured as the same as what it should be. Of course, that does not guarantee of good response
|
| 23 |
+
- TODO is written in the main file already
|
| 24 |
+
- Successfully convert plain euc 100 and 200 to chat flow
|
app.py
CHANGED
|
@@ -1,3 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
import re, time
|
| 3 |
import matplotlib.pyplot as plt
|
|
@@ -5,228 +27,275 @@ from threading import Timer
|
|
| 5 |
import gradio as gr
|
| 6 |
|
| 7 |
import torch
|
| 8 |
-
from transformers import
|
| 9 |
-
GPT2LMHeadModel, GPT2Tokenizer,
|
| 10 |
-
AutoModelForSequenceClassification, AutoTokenizer,
|
| 11 |
-
pipeline
|
| 12 |
-
)
|
| 13 |
-
# reference: https://huggingface.co/spaces/bentrevett/emotion-prediction
|
| 14 |
-
# and https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
|
| 15 |
-
# gradio vs streamlit https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
|
| 16 |
-
# https://gradio.app/interface_state/
|
| 17 |
-
|
| 18 |
-
def euc_100():
|
| 19 |
-
# 1,2,3. asks about the user's emotions and store data
|
| 20 |
-
print('How was your day?')
|
| 21 |
-
print('On the scale 1 to 10, how would you judge your emotion through the following categories:') # ~ Baymax :)
|
| 22 |
-
emotion_types = ['overall'] #, 'happiness', 'surprise', 'sadness', 'depression', 'anger', 'fear', 'anxiety']
|
| 23 |
-
emotion_degree = []
|
| 24 |
-
input_time = []
|
| 25 |
-
|
| 26 |
-
for e in emotion_types:
|
| 27 |
-
while True:
|
| 28 |
-
x = input(f'{e}: ')
|
| 29 |
-
if x.isnumeric() and (0 < int(x) < 11):
|
| 30 |
-
emotion_degree.append(int(x))
|
| 31 |
-
input_time.append(time.gmtime())
|
| 32 |
-
break
|
| 33 |
-
else:
|
| 34 |
-
print('invalid input, my friend :) plz input again')
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
'And what do you think about those feelings or emotions at that time?',
|
| 59 |
'Could you think of any evidence for your above-mentioned thought?',
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
y = 'No' # bad mood
|
| 64 |
-
while True:
|
| 65 |
-
x = input('Your answer (example of answer here): ')
|
| 66 |
-
if x == '': # need to change this part to waiting 10 seconds
|
| 67 |
-
print('Whether your bad mood is over?')
|
| 68 |
-
y = input('Your answer (Yes or No): ')
|
| 69 |
-
if y == 'Yes':
|
| 70 |
-
break
|
| 71 |
-
else:
|
| 72 |
-
break
|
| 73 |
-
if y == 'Yes':
|
| 74 |
-
print('Nice to hear that.')
|
| 75 |
-
break
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
return_all_scores=True, truncation=True)
|
| 89 |
-
return pipe
|
| 90 |
-
|
| 91 |
-
def plot_emotion_distribution(predictions):
|
| 92 |
-
fig, ax = plt.subplots()
|
| 93 |
-
ax.bar(x=[i for i, _ in enumerate(prediction)],
|
| 94 |
-
height=[p['score'] for p in prediction],
|
| 95 |
-
tick_label=[p['label'] for p in prediction])
|
| 96 |
-
ax.tick_params(rotation=90)
|
| 97 |
-
ax.set_ylim(0, 1)
|
| 98 |
-
plt.show()
|
| 99 |
-
|
| 100 |
-
def rulebase(text):
|
| 101 |
-
keywords = {
|
| 102 |
-
'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
|
| 103 |
-
'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
|
| 104 |
-
'manifestation': ['never stop', 'every moment', 'strong', 'very']
|
| 105 |
}
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
threshold = 0.3
|
| 136 |
-
emotion = {'label': 'sadness', 'score': 0.4} if testing else prediction[0]
|
| 137 |
-
# then judge
|
| 138 |
-
if emotion['label'] in ['surprise', 'sadness', 'anger', 'fear'] and emotion['score'] > threshold:
|
| 139 |
-
print(f'It has come to our attention that you may suffer from {emotion["label"]}')
|
| 140 |
-
print('If you want to know more about yourself, '
|
| 141 |
-
'some professional scales are provided to quantify your current status. '
|
| 142 |
-
'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
|
| 143 |
-
'you can fill out these scales again to see if you have improved.')
|
| 144 |
-
x = input('Fill in the form now (Okay or Later): ')
|
| 145 |
-
if x == 'Okay':
|
| 146 |
-
print('Display the form')
|
| 147 |
-
else:
|
| 148 |
-
print('Here are some reference articles about bad emotions. You can take a look :)')
|
| 149 |
-
|
| 150 |
-
# 4. If both of the above are not satisfied. What do u mean by 'satisfied' here?
|
| 151 |
-
questions = [
|
| 152 |
-
'What specific thing is bothering you the most right now?',
|
| 153 |
-
'Oh, I see. So when it is happening, what feelings or emotions have you got?',
|
| 154 |
-
'And what do you think about those feelings or emotions at that time?',
|
| 155 |
-
'Could you think of any evidence for your above-mentioned thought? #',
|
| 156 |
-
]
|
| 157 |
-
for q in questions:
|
| 158 |
-
print(q)
|
| 159 |
-
y = 'No' # bad mood
|
| 160 |
-
while True:
|
| 161 |
-
x = input('Your answer (example of answer here): ')
|
| 162 |
-
if x == '': # need to change this part to waiting 10 seconds
|
| 163 |
-
print('Whether your bad mood is over?')
|
| 164 |
-
y = input('Your answer (Yes or No): ')
|
| 165 |
if y == 'Yes':
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
else:
|
| 168 |
-
|
| 169 |
-
if y == 'Yes':
|
| 170 |
-
print('Nice to hear that.')
|
| 171 |
-
break
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
-
|
| 187 |
-
history['input_ids'] = torch.cat([history['input_ids'], message_ids], dim=-1)
|
| 188 |
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
print((message, response), bot_output_ids[0][-10:])
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
|
|
|
| 216 |
|
| 217 |
title = 'PsyPlus Empathetic Chatbot'
|
| 218 |
description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
|
| 219 |
-
chatbot = gr.Chatbot(value=
|
| 220 |
iface = gr.Interface(
|
| 221 |
-
chat,
|
| 222 |
-
|
| 223 |
-
[chatbot, 'state'],
|
| 224 |
-
# css=".gradio-container {background-color: white}",
|
| 225 |
-
allow_flagging='never',
|
| 226 |
-
title=title,
|
| 227 |
-
description=description,
|
| 228 |
)
|
|
|
|
|
|
|
| 229 |
if args.run_on_own_server == 0:
|
| 230 |
iface.launch(debug=True)
|
| 231 |
else:
|
| 232 |
-
iface.launch(debug=True, server_name='0.0.0.0', server_port=2022
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Dialog System of PsyPlus (dvq)
|
| 3 |
+
|
| 4 |
+
reference:
|
| 5 |
+
https://huggingface.co/spaces/bentrevett/emotion-prediction
|
| 6 |
+
https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
|
| 7 |
+
https://huggingface.co/benjaminbeilharz/t5-empatheticdialogues
|
| 8 |
+
|
| 9 |
+
gradio vs streamlit
|
| 10 |
+
https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
|
| 11 |
+
https://gradio.app/interface_state/
|
| 12 |
+
|
| 13 |
+
TODO
|
| 14 |
+
Add diagram in Gradio Interface showing sentimate analysis
|
| 15 |
+
Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
|
| 16 |
+
Personalize: create database, load and save data
|
| 17 |
+
|
| 18 |
+
Run command
|
| 19 |
+
python app.py --run_on_own_server 1 --initial_chat_state free_chat
|
| 20 |
+
'''
|
| 21 |
+
|
| 22 |
+
|
| 23 |
import argparse
|
| 24 |
import re, time
|
| 25 |
import matplotlib.pyplot as plt
|
|
|
|
| 27 |
import gradio as gr
|
| 28 |
|
| 29 |
import torch
|
| 30 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
|
| 33 |
+
def option():
|
| 34 |
+
parser = argparse.ArgumentParser()
|
| 35 |
+
parser.add_argument('--run_on_own_server', type=int, default=0, help='if test on own server, need to use share mode')
|
| 36 |
+
parser.add_argument('--dialog_model', type=str, default='tareknaous/dialogpt-empathetic-dialogues')
|
| 37 |
+
parser.add_argument('--emotion_model', type=str, default='joeddav/distilbert-base-uncased-go-emotions-student')
|
| 38 |
+
parser.add_argument('--account', type=str, default=None)
|
| 39 |
+
parser.add_argument('--initial_chat_state', type=str, default='euc_100', choices=['euc_100', 'euc_200', 'free_chat'])
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
return args
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ChatHelper: # store the list of messages that are showed in therapies
|
| 45 |
+
invalid_input = 'Invalid input, my friend :) Plz input again'
|
| 46 |
+
good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
|
| 47 |
+
good_case = 'Nice to hear that!'
|
| 48 |
+
bad_mood_over = 'Whether your bad mood is over? (Yes or No)'
|
| 49 |
+
not_answer = "It's okay, maybe you don't want to answer this question."
|
| 50 |
+
fill_form = ('It has come to our attention that you may suffer from {}.\n'
|
| 51 |
+
'If you want to know more about yourself, some professional scales are provided to quantify your current status.\n'
|
| 52 |
+
'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
|
| 53 |
+
'you can fill out these scales again to see if you have improved.\n'
|
| 54 |
+
'Do you want to fill in the form now? (Okay or Later)')
|
| 55 |
+
display_form = '<Display the form>.\n'
|
| 56 |
+
reference = 'Here are some reference articles about bad emotions. You can take a look :) <Display references>\n'
|
| 57 |
+
|
| 58 |
+
emotion_types = ['Overall', 'Happiness', 'Anxiety'] # 'Surprise', 'Sadness', 'Depression', 'Anger', 'Fear',
|
| 59 |
+
euc_100 = {
|
| 60 |
+
'q': emotion_types,
|
| 61 |
+
'good_mood': [
|
| 62 |
+
'You seem to be in a good mood today. Is there anything you could notice that makes you happy?',
|
| 63 |
+
'I am glad that you are willing to share the experience with me. Thanks for letting me know.',
|
| 64 |
+
],
|
| 65 |
+
'bad_mood': [
|
| 66 |
+
'You seem not to be in a good mood. What specific thing is bothering you the most right now?',
|
| 67 |
+
'I see. So when it is happening, what feelings or emotions have you got?',
|
| 68 |
'And what do you think about those feelings or emotions at that time?',
|
| 69 |
'Could you think of any evidence for your above-mentioned thought?',
|
| 70 |
+
'Here are some reference articles about bad emotions. You can take a look :)',
|
| 71 |
+
],
|
| 72 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
+
negative_emotions = ['remorse', 'nervousness', 'annoyance', 'anger', 'grief', 'fear', 'disapproval',
|
| 75 |
+
'confusion', 'embarrassment', 'disgust', 'sadness', 'disappointment']
|
| 76 |
+
euc_200 = 'Now go back to the last chat. You said that "{}".\n'
|
| 77 |
+
|
| 78 |
+
greeting_template = {
|
| 79 |
+
'euc_100': 'How was your day? On the scale 1 to 10, '
|
| 80 |
+
'how would you judge your emotion through the following categories:\nOverall',
|
| 81 |
+
# euc_200 is only trigger when you say smt more negative than a certain threshol
|
| 82 |
+
# thus the greeting here is only for debuging euc_200
|
| 83 |
+
'euc_200': fill_form.format('anxiety'),
|
| 84 |
+
'free_chat': 'Hi you! How is it going?',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
}
|
| 86 |
|
| 87 |
+
def plot_emotion_distribution(predictions):
|
| 88 |
+
fig, ax = plt.subplots()
|
| 89 |
+
ax.bar(x=[i for i, _ in enumerate(prediction)],
|
| 90 |
+
height=[p['score'] for p in prediction],
|
| 91 |
+
tick_label=[p['label'] for p in prediction])
|
| 92 |
+
ax.tick_params(rotation=90)
|
| 93 |
+
ax.set_ylim(0, 1)
|
| 94 |
+
plt.show()
|
| 95 |
+
|
| 96 |
+
def ed_rulebase(text):
|
| 97 |
+
keywords = {
|
| 98 |
+
'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
|
| 99 |
+
'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
|
| 100 |
+
'manifestation': ['never stop', 'every moment', 'strong', 'very']
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
# if found dangerous kw/topics
|
| 104 |
+
if re.search(rf"{'|'.join(keywords['life_safety'])}", text) != None and \
|
| 105 |
+
sum([re.search(rf"{'|'.join(keywords[k])}", text) != None for k in ['immediacy','manifestation']]) >= 1:
|
| 106 |
+
print('We noticed that you may need immediate professional assistance, would you like to make a phone call? '
|
| 107 |
+
'The Hong Kong Lifeline number is (852) 2382 0000')
|
| 108 |
+
x = input('Choose 1. "Dial to the number" or 2. "No dangerous emotion la": ')
|
| 109 |
+
if x == '1':
|
| 110 |
+
print('Let you connect to the office')
|
| 111 |
+
else:
|
| 112 |
+
print('Sorry for our misdetection. We just want to make sure that you could get immediate help when needed. '
|
| 113 |
+
'Would you mind if we send this conversation to the cloud to finetune the model.')
|
| 114 |
+
y = input('Yes or No: ')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
if y == 'Yes':
|
| 116 |
+
pass # do smt here
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TherapyChatBot:
|
| 120 |
+
def __init__(self, args):
|
| 121 |
+
# check state to control the dialog
|
| 122 |
+
self.chat_state = args.initial_chat_state # name of the chat function/therapy segment the model is in
|
| 123 |
+
self.message_prev = None
|
| 124 |
+
self.chat_state_prev = None
|
| 125 |
+
self.run_on_own_server = args.run_on_own_server
|
| 126 |
+
self.account = args.account
|
| 127 |
+
|
| 128 |
+
# additional attribute for euc_100
|
| 129 |
+
self.euc_100_input_time = []
|
| 130 |
+
self.euc_100_emotion_degree = []
|
| 131 |
+
self.already_trigger_euc_200 = False
|
| 132 |
+
|
| 133 |
+
# chat and emotion-detection models
|
| 134 |
+
self.ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
|
| 135 |
+
self.ed_threshold = 0.3
|
| 136 |
+
self.dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
|
| 137 |
+
self.dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
|
| 138 |
+
self.eos = self.dialog_tokenizer.eos_token
|
| 139 |
+
# tokenizer.__call__ -> input_ids, attention_mask
|
| 140 |
+
# tokenizer.encode -> only inputs_ids, which is required by model.generate function
|
| 141 |
+
|
| 142 |
+
# chat history.
|
| 143 |
+
# TODO: if we want to personalize and save the conversation,
|
| 144 |
+
# we can load data from database
|
| 145 |
+
self.greeting = ChatHelper.greeting_template[self.chat_state]
|
| 146 |
+
self.history = {'input_ids': torch.tensor([[self.dialog_tokenizer.bos_token_id]]),
|
| 147 |
+
'text': [('', self.greeting)]} if not self.account else open(f'database/{hash(self.account)}', 'rb')
|
| 148 |
+
if 'euc_100' in self.chat_state:
|
| 149 |
+
self.chat_state = 'euc_100.q.0'
|
| 150 |
+
|
| 151 |
+
def __call__(self, message, prefix=''):
|
| 152 |
+
# if prefix != None, which means this function is called from euc_200, thus already detected the negative emotion
|
| 153 |
+
if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
|
| 154 |
+
prediction = self.ed_pipe(message)[0]
|
| 155 |
+
prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
|
| 156 |
+
if self.run_on_own_server:
|
| 157 |
+
print(prediction)
|
| 158 |
+
# plot_emotion_distribution(prediction)
|
| 159 |
+
emotion = prediction[0]
|
| 160 |
+
|
| 161 |
+
# if message is negative, change state immediately
|
| 162 |
+
if ((not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200) and \
|
| 163 |
+
(emotion['label'] in ChatHelper.negative_emotions and emotion['score'] > self.ed_threshold):
|
| 164 |
+
self.chat_state_prev = self.chat_state
|
| 165 |
+
self.chat_state = 'euc_200'
|
| 166 |
+
self.message_prev = message
|
| 167 |
+
self.already_trigger_euc_200 = True
|
| 168 |
+
response = ChatHelper.fill_form.format(emotion['label'])
|
| 169 |
+
|
| 170 |
+
# set up rule to update state inside each dialog function
|
| 171 |
+
elif self.chat_state.startswith('euc_100'):
|
| 172 |
+
response = self.euc_100(message)
|
| 173 |
+
if self.chat_state == 'free_chat':
|
| 174 |
+
last_two_turns_ids = self.dialog_tokenizer.encode(message + self.eos, return_tensors='pt')
|
| 175 |
+
self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)
|
| 176 |
+
|
| 177 |
+
elif self.chat_state.startswith('euc_200'):
|
| 178 |
+
return self.euc_200(message)
|
| 179 |
+
|
| 180 |
+
else: # free_chat
|
| 181 |
+
response = self.free_chat(message)
|
| 182 |
+
|
| 183 |
+
if prefix:
|
| 184 |
+
response = prefix + response
|
| 185 |
+
self.history['text'].append((self.message_prev, response))
|
| 186 |
+
else:
|
| 187 |
+
self.history['text'].append((message, response))
|
| 188 |
+
return self.history['text']
|
| 189 |
+
|
| 190 |
+
def euc_100(self, x):
|
| 191 |
+
_, subsection, entry = self.chat_state.split('.')
|
| 192 |
+
entry = int(entry)
|
| 193 |
+
|
| 194 |
+
if subsection == 'q':
|
| 195 |
+
if x.isnumeric() and (0 < int(x) < 11):
|
| 196 |
+
self.euc_100_emotion_degree.append(int(x))
|
| 197 |
+
self.euc_100_input_time.append(time.gmtime())
|
| 198 |
+
if entry == len(ChatHelper.euc_100['q']) - 1:
|
| 199 |
+
if self.run_on_own_server:
|
| 200 |
+
print(self.euc_100_emotion_degree)
|
| 201 |
+
mood = 'good_mood' if self.euc_100_emotion_degree[0] > 5 else 'bad_mood'
|
| 202 |
+
self.chat_state = f'euc_100.{mood}.0'
|
| 203 |
+
response = ChatHelper.euc_100[mood][0]
|
| 204 |
+
else:
|
| 205 |
+
self.chat_state = f'euc_100.q.{entry+1}'
|
| 206 |
+
response = ChatHelper.euc_100['q'][entry+1]
|
| 207 |
else:
|
| 208 |
+
response = ChatHelper.invalid_input
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
+
elif subsection == 'good_mood':
|
| 211 |
+
if x == '':
|
| 212 |
+
response = ChatHelper.good_mood_over
|
| 213 |
+
else:
|
| 214 |
+
response = ChatHelper.good_case
|
| 215 |
+
response += '\n' + ChatHelper.euc_100['good_mood'][1]
|
| 216 |
+
self.chat_state = 'free_chat'
|
| 217 |
|
| 218 |
+
elif subsection == 'bad_mood':
|
| 219 |
+
if entry == -1:
|
| 220 |
+
if 'yes' in x.lower() or 'better' in x.lower():
|
| 221 |
+
response = ChatHelper.good_case
|
| 222 |
+
else:
|
| 223 |
+
entry = int(self.chat_state_prev.rsplit('.', 1))
|
| 224 |
+
response = ChatHelper.not_answer + '\n' + ChatHelper.euc_100['bad_mood'][entry+1]
|
| 225 |
+
if entry == len(ChatHelper.euc_100['bad_mood']) - 2:
|
| 226 |
+
self.chat_state = 'free_chat'
|
| 227 |
+
else:
|
| 228 |
+
self.chat_state = f'euc_100.bad_mood.{entry+1}'
|
| 229 |
|
| 230 |
+
if x == '':
|
| 231 |
+
response = ChatHelper.bad_mood_over
|
| 232 |
+
self.chat_state_prev = self.chat_state
|
| 233 |
+
self.chat_state = 'euc_100.bad_mood.-1'
|
| 234 |
+
else:
|
| 235 |
+
response = ChatHelper.euc_100['bad_mood'][entry+1]
|
| 236 |
+
if entry == len(ChatHelper.euc_100['bad_mood']) - 2:
|
| 237 |
+
self.chat_state = 'free_chat'
|
| 238 |
+
else:
|
| 239 |
+
self.chat_state = f'euc_100.bad_mood.{entry+1}'
|
| 240 |
|
| 241 |
+
return response
|
|
|
|
| 242 |
|
| 243 |
+
def euc_200(self, x):
|
| 244 |
+
# don't ask question in euc_200, because they're similar to question in euc_100
|
| 245 |
+
if x.lower() == 'okay':
|
| 246 |
+
response = ChatHelper.display_form
|
| 247 |
+
else:
|
| 248 |
+
response = ChatHelper.reference
|
| 249 |
+
response += ChatHelper.euc_200.format(self.message_prev)
|
|
|
|
| 250 |
|
| 251 |
+
message = self.message_prev
|
| 252 |
+
self.message_prev = x
|
| 253 |
+
self.chat_state = self.chat_state_prev
|
| 254 |
+
return self.__call__(message, response)
|
| 255 |
|
| 256 |
+
def free_chat(self, message):
|
| 257 |
+
message_ids = self.dialog_tokenizer.encode(message + self.eos, return_tensors='pt')
|
| 258 |
+
self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
|
| 259 |
+
input_ids = self.history['input_ids'].clone()
|
| 260 |
|
| 261 |
+
while True:
|
| 262 |
+
bot_output_ids = self.dialog_model.generate(input_ids, max_length=1000,
|
| 263 |
+
do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
|
| 264 |
+
pad_token_id=self.dialog_tokenizer.eos_token_id)
|
| 265 |
+
response = self.dialog_tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:],
|
| 266 |
+
skip_special_tokens=True)
|
| 267 |
+
if response.strip() != '':
|
| 268 |
+
break
|
| 269 |
+
elif input_ids[0].tolist().count(self.dialog_tokenizer.eos_token_id) > 0:
|
| 270 |
+
idx = input_ids[0].tolist().index(self.dialog_tokenizer.eos_token_id)
|
| 271 |
+
input_ids = input_ids[:, (idx+1):]
|
| 272 |
+
else:
|
| 273 |
+
input_ids = message_ids
|
| 274 |
+
|
| 275 |
+
if self.run_on_own_server:
|
| 276 |
+
print(input_ids)
|
| 277 |
+
|
| 278 |
+
self.history['input_ids'] = torch.cat([self.history['input_ids'], bot_output_ids[0:1, input_ids.shape[-1]:]], dim=-1)
|
| 279 |
+
if self.run_on_own_server == 1:
|
| 280 |
+
print((message, response), '\n', self.history['input_ids'])
|
| 281 |
|
| 282 |
+
return response
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
if __name__ == '__main__':
|
| 286 |
+
args = option()
|
| 287 |
+
chat = TherapyChatBot(args)
|
| 288 |
|
| 289 |
title = 'PsyPlus Empathetic Chatbot'
|
| 290 |
description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
|
| 291 |
+
chatbot = gr.Chatbot(value=chat.history['text'])
|
| 292 |
iface = gr.Interface(
|
| 293 |
+
chat, 'text', chatbot,
|
| 294 |
+
allow_flagging='never', title=title, description=description,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
)
|
| 296 |
+
|
| 297 |
+
# iface.queue(concurrency_count=5)
|
| 298 |
if args.run_on_own_server == 0:
|
| 299 |
iface.launch(debug=True)
|
| 300 |
else:
|
| 301 |
+
iface.launch(debug=True, share=True) # server_name='0.0.0.0', server_port=2022
|