Spaces:
Build error
Build error
| ''' | |
| Dialog System of PsyPlus (dvq) | |
| reference: | |
| https://huggingface.co/spaces/bentrevett/emotion-prediction | |
| https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT | |
| https://huggingface.co/benjaminbeilharz/t5-empatheticdialogues | |
| gradio vs streamlit | |
| https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/ | |
| https://gradio.app/interface_state/ -> global and local variables affect the separation of sessions | |
| TODO | |
| Add a command to reset/jump to a function, e.g. >reset, >euc_100 | |
| Add a diagram in the Gradio interface showing sentiment analysis | |
| Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement | |
| Personalize: create database, load and save data | |
| Run command | |
| python app.py --run_on_own_server 1 --initial_chat_state free_chat | |
| ''' | |
| import argparse | |
| import re, time | |
| import matplotlib.pyplot as plt | |
| from threading import Timer | |
| import gradio as gr | |
| import torch | |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline | |
def option():
    """Parse the command-line options of the demo.

    Returns:
        argparse.Namespace: carries ``run_on_own_server``, ``dialog_model``,
        ``emotion_model``, ``account`` and ``initial_chat_state``.
    """
    # One (flag, add_argument keyword arguments) pair per supported option.
    flag_specs = (
        ('--run_on_own_server',
         dict(type=int, default=0, help='if test on own server, need to use share mode')),
        ('--dialog_model',
         dict(type=str, default='tareknaous/dialogpt-empathetic-dialogues')),
        ('--emotion_model',
         dict(type=str, default='joeddav/distilbert-base-uncased-go-emotions-student')),
        ('--account',
         dict(type=str, default=None)),
        ('--initial_chat_state',
         dict(type=str, default='euc_100', choices=['euc_100', 'euc_200', 'free_chat'])),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in flag_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
| args = option() | |
| # the messages shown in the therapy scripts and the models are stored as class-level (global) variables in ChatHelper | |
| # all variables that belong to a single chat session are kept in TherapyChatBot | |
class ChatHelper:
    """Namespace of resources shared by every chat session.

    Holds the emotion-detection pipeline, the dialog model with its tokenizer,
    and the canned messages used by the scripted dialogs. Everything is a
    class attribute, so it is created once when the class body is executed.
    """

    # --- models -----------------------------------------------------------
    # Emotion detection: the five best labels per message are requested.
    ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
    # Minimum score of the top label for a negative emotion to count.
    ed_threshold = 0.3

    # Dialog generation.
    # tokenizer.__call__ -> input_ids, attention_mask
    # tokenizer.encode   -> only input_ids, which is what model.generate needs
    dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
    dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
    eos = dialog_tokenizer.eos_token

    # --- canned replies ---------------------------------------------------
    invalid_input = 'Invalid input, my friend :) Plz input again'
    good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
    good_case = 'Nice to hear that!'
    bad_mood_over = 'Whether your bad mood is over? (Yes or No)'
    not_answer = "It's okay, maybe you don't want to answer this question."
    # The placeholder receives the name of the detected condition/emotion.
    fill_form = '\n'.join([
        'It has come to our attention that you may suffer from {}.',
        'If you want to know more about yourself, some professional scales are provided '
        'to quantify your current status.',
        'After a period of time (maybe a week/two months/a month) trying to follow the solutions we suggested, '
        'you can fill out these scales again to see if you have improved.',
        'Do you want to fill in the form now? (Okay or Later)',
    ])
    display_form = '<Display the form>.\n'
    reference = 'Here are some reference articles about bad emotions. You can take a look :) <Display references>\n'

    # --- euc_100 script ---------------------------------------------------
    # Not used yet: 'Surprise', 'Sadness', 'Depression', 'Anger', 'Fear'
    emotion_types = ['Overall', 'Happiness', 'Anxiety']
    euc_100 = {
        # categories the user rates one after another
        'q': emotion_types,
        'good_mood': [
            'You seem to be in a good mood today. Is there anything you could notice that makes you happy?',
            'I am glad that you are willing to share the experience with me. Thanks for letting me know.',
        ],
        'bad_mood': [
            'You seem not to be in a good mood. What specific thing is bothering you the most right now?',
            'I see. So when it is happening, what feelings or emotions have you got?',
            'And what do you think about those feelings or emotions at that time?',
            'Could you think of any evidence for your above-mentioned thought?',
            'Here are some reference articles about bad emotions. You can take a look :)',
        ],
    }

    # --- euc_200 script ---------------------------------------------------
    # Labels of the emotion model that are treated as negative.
    negative_emotions = [
        'remorse', 'nervousness', 'annoyance', 'anger', 'grief', 'fear',
        'disapproval', 'confusion', 'embarrassment', 'disgust', 'sadness',
        'disappointment',
    ]
    # The placeholder receives the message that triggered euc_200.
    euc_200 = 'Now go back to the last chat. You said that "{}".\n'

    # --- first bot message, keyed by the initial chat state -----------------
    greeting_template = {
        'euc_100': ('How was your day? On the scale 1 to 10, '
                    'how would you judge your emotion through the following categories:\nOverall'),
        # euc_200 is only triggered when the user says something more negative
        # than a certain threshold, so this greeting only exists for debugging euc_200
        'euc_200': fill_form.format('anxiety'),
        'free_chat': 'Hi you! How is it going?',
    }
def plot_emotion_distribution(predictions):
    """Show a bar chart with one bar per predicted emotion.

    Args:
        predictions: iterable of dicts that each provide a ``'label'`` and a
            ``'score'`` entry.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Materialise once so that any iterable (including a generator) can be
    # read for both the labels and the scores.
    predictions = list(predictions)
    labels = [p['label'] for p in predictions]
    scores = [p['score'] for p in predictions]

    _fig, ax = plt.subplots()
    ax.bar(x=list(range(len(scores))), height=scores, tick_label=labels)
    ax.tick_params(rotation=90)  # labels are words, so draw them vertically
    ax.set_ylim(0, 1)
    plt.show()
def ed_rulebase(text):
    """Rule-based check for messages that may call for immediate help.

    A message is flagged when it contains a life-safety keyword together with
    at least one immediacy or manifestation keyword. Matching is a
    case-insensitive substring search. When a message is flagged, the user is
    prompted on the console through ``print``/``input``.

    Args:
        text: the user message to scan.

    Returns:
        None. Messages that are not flagged return without any console I/O.
    """
    keywords = {
        'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
        'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
        'manifestation': ['never stop', 'every moment', 'strong', 'very'],
    }

    def mentions(group):
        # Keywords are escaped so they are always matched literally, and case
        # is ignored so that e.g. "Suicide" or "NOW" is not missed.
        pattern = '|'.join(re.escape(keyword) for keyword in keywords[group])
        return re.search(pattern, text, flags=re.IGNORECASE) is not None

    # Only react to dangerous keywords/topics that come with urgency or intensity.
    if not (mentions('life_safety') and any(mentions(g) for g in ('immediacy', 'manifestation'))):
        return

    print('We noticed that you may need immediate professional assistance, would you like to make a phone call? '
          'The Hong Kong Lifeline number is (852) 2382 0000')
    x = input('Choose 1. "Dial to the number" or 2. "No dangerous emotion la": ')
    if x == '1':
        print('Let you connect to the office')
        return

    print('Sorry for our misdetection. We just want to make sure that you could get immediate help when needed. '
          'Would you mind if we send this conversation to the cloud to finetune the model.')
    y = input('Yes or No: ')
    if y == 'Yes':
        pass  # TODO: upload the conversation; not implemented yet
class TherapyChatBot:
    """Per-session dialog controller that routes each user message by chat state.

    ``chat_state`` names the dialog function currently in charge:

    * ``euc_100.<subsection>.<entry>`` -- scripted check-in handled by
      :meth:`euc_100`; ``subsection`` is ``q``, ``good_mood`` or ``bad_mood``.
    * ``euc_200`` -- follow-up entered once when a negative emotion is detected,
      handled by :meth:`euc_200`.
    * ``free_chat`` -- open conversation generated by the dialog model in
      :meth:`free_chat`.

    Calling the instance with a message appends ``(user message, bot response)``
    to ``history['text']``.
    """

    # free_chat re-samples at most this many times once the context cannot be
    # shortened any further, so an always-empty generation cannot hang a request.
    max_generation_attempts = 5
    # Reply used when the dialog model produced no text within the retry budget.
    empty_generation_fallback = 'Sorry, I am not sure what to say. Could you tell me more?'

    def __init__(self, args):
        """Create a fresh session.

        Args:
            args: namespace providing ``initial_chat_state``,
                ``run_on_own_server`` and ``account``.
        """
        # name of the chat function/therapy segment the bot is in
        self.chat_state = args.initial_chat_state
        # message and state saved when euc_200 interrupts another dialog
        self.message_prev = None
        self.chat_state_prev = None
        self.run_on_own_server = args.run_on_own_server
        self.account = args.account

        # additional attributes for euc_100
        self.euc_100_input_time = []
        self.euc_100_emotion_degree = []
        # index of the bad-mood prompt that was last shown before the bot asked
        # whether the bad mood is over; kept apart from chat_state_prev because
        # the negative-emotion trigger overwrites that attribute
        self.euc_100_resume_entry = None
        self.already_trigger_euc_200 = False

        # chat history
        self.greeting = [('', ChatHelper.greeting_template[self.chat_state])]
        # TODO: personalisation. Loading a saved conversation for self.account
        # is not implemented yet; every session starts from a fresh history.
        # (Previously an account turned `history` into an open file object,
        # which cannot be indexed, and the file name was built from hash(),
        # which is salted per process for strings.)
        self.history = {
            'input_ids': torch.tensor([[ChatHelper.dialog_tokenizer.bos_token_id]]),
            'text': list(self.greeting),
        }

        if 'euc_100' in self.chat_state:
            self.chat_state = 'euc_100.q.0'

    def __call__(self, message, prefix=''):
        """Handle one user message and record the exchange in the history.

        Args:
            message: the text to process.
            prefix: text put in front of the response. A non-empty prefix means
                the call comes from :meth:`euc_200`, i.e. the negative emotion
                was already detected, so detection is skipped.

        Returns:
            The response string that was appended to ``history['text']``.
        """
        emotion = None
        if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
            prediction = ChatHelper.ed_pipe(message)[0]
            prediction = sorted(prediction, key=lambda p: p['score'], reverse=True)
            if self.run_on_own_server:
                print(prediction)
            # plot_emotion_distribution(prediction)
            emotion = prediction[0]

        is_negative = (emotion is not None
                       and emotion['label'] in ChatHelper.negative_emotions
                       and emotion['score'] > ChatHelper.ed_threshold)

        if is_negative:
            # a negative message changes the state immediately; remember where
            # we were so that euc_200 can hand the message back afterwards
            self.chat_state_prev = self.chat_state
            self.chat_state = 'euc_200'
            self.message_prev = message
            self.already_trigger_euc_200 = True
            response = ChatHelper.fill_form.format(emotion['label'])
        elif self.chat_state.startswith('euc_100'):
            # each dialog function updates chat_state itself
            response = self.euc_100(message)
            if self.chat_state == 'free_chat':
                # the script just ended: feed this message to the dialog context
                last_turn_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos,
                                                                   return_tensors='pt')
                self.history['input_ids'] = torch.cat([self.history['input_ids'], last_turn_ids], dim=-1)
        elif self.chat_state.startswith('euc_200'):
            # euc_200 calls back into __call__, which records the history
            return self.euc_200(message)
        else:  # free_chat
            response = self.free_chat(message)

        if prefix:
            response = prefix + response
            # euc_200 stored the user's latest input in message_prev
            self.history['text'].append((self.message_prev, response))
        else:
            self.history['text'].append((message, response))
        return response

    def _advance_bad_mood(self, entry):
        """Move past bad-mood prompt ``entry`` and return the next prompt text."""
        prompts = ChatHelper.euc_100['bad_mood']
        following = entry + 1
        if following == len(prompts) - 1:
            # the last prompt has no follow-up question
            self.chat_state = 'free_chat'
        else:
            self.chat_state = f'euc_100.bad_mood.{following}'
        return prompts[following]

    def euc_100(self, x):
        """Run one step of the scripted check-in.

        Args:
            x: the user input for the current step.

        Returns:
            The scripted response. ``chat_state`` is updated as a side effect.

        Raises:
            ValueError: if ``chat_state`` names an unknown subsection.
        """
        _, subsection, entry = self.chat_state.split('.')
        entry = int(entry)

        if subsection == 'q':
            rating = x.strip()
            # isdecimal() only accepts characters that int() can convert
            if rating.isdecimal() and 0 < int(rating) < 11:
                self.euc_100_emotion_degree.append(int(rating))
                self.euc_100_input_time.append(time.gmtime())
                if entry == len(ChatHelper.euc_100['q']) - 1:
                    if self.run_on_own_server:
                        print(self.euc_100_emotion_degree)
                    # the first rating decides which script follows
                    mood = 'good_mood' if self.euc_100_emotion_degree[0] > 5 else 'bad_mood'
                    self.chat_state = f'euc_100.{mood}.0'
                    response = ChatHelper.euc_100[mood][0]
                else:
                    self.chat_state = f'euc_100.q.{entry + 1}'
                    response = ChatHelper.euc_100['q'][entry + 1]
            else:
                response = ChatHelper.invalid_input

        elif subsection == 'good_mood':
            if x == '':
                response = ChatHelper.good_mood_over
            else:
                response = ChatHelper.good_case + '\n' + ChatHelper.euc_100['good_mood'][1]
                self.chat_state = 'free_chat'

        elif subsection == 'bad_mood':
            if entry == -1:
                # the user is answering "is your bad mood over?"
                answer = x.lower()
                if 'yes' in answer or 'better' in answer:
                    response = ChatHelper.good_case
                    # nothing left to ask, same hand-over as the good-mood script
                    self.chat_state = 'free_chat'
                else:
                    # continue with the prompt after the one that got no answer
                    response = (ChatHelper.not_answer + '\n'
                                + self._advance_bad_mood(self.euc_100_resume_entry))
            elif x == '':
                # no answer: ask whether the bad mood is over and remember
                # which prompt to resume from
                response = ChatHelper.bad_mood_over
                self.euc_100_resume_entry = entry
                self.chat_state_prev = self.chat_state
                self.chat_state = 'euc_100.bad_mood.-1'
            else:
                response = self._advance_bad_mood(entry)

        else:
            raise ValueError(f'unknown euc_100 subsection in chat state {self.chat_state!r}')

        return response

    def euc_200(self, x):
        """Answer the fill-in-the-form question, then resume the interrupted dialog.

        No further questions are asked here because they would repeat the
        questions of euc_100.

        Args:
            x: the user's reply to the form question; ``okay`` (any case,
                surrounding whitespace ignored) displays the form, anything
                else displays the references.

        Returns:
            The value returned by re-processing the triggering message with
            the euc_200 text as prefix.
        """
        if x.strip().lower() == 'okay':
            response = ChatHelper.display_form
        else:
            response = ChatHelper.reference

        # When the session was started directly in euc_200 (debugging) there is
        # no triggering message or previous state; fall back to the current
        # input and to free chat instead of failing on None.
        trigger_message = self.message_prev if self.message_prev is not None else x
        response += ChatHelper.euc_200.format(trigger_message)

        self.message_prev = x
        self.chat_state = self.chat_state_prev or 'free_chat'
        return self(trigger_message, prefix=response)

    def free_chat(self, message):
        """Generate a reply with the dialog model and extend the token history.

        When the model returns an empty reply, the oldest turn (up to and
        including its end-of-sequence token) is dropped from the context and
        generation is retried; once nothing more can be dropped, the message
        alone is re-sampled. After the retry budget is used up,
        ``empty_generation_fallback`` is returned.

        Args:
            message: the user message.

        Returns:
            The reply text.
        """
        tokenizer = ChatHelper.dialog_tokenizer
        eos_id = tokenizer.eos_token_id

        message_ids = tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
        self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
        input_ids = self.history['input_ids'].clone()

        # Every truncation removes one end-of-sequence token, so this bound
        # allows all possible truncations plus the re-sampling budget.
        budget = input_ids[0].tolist().count(eos_id) + self.max_generation_attempts
        for _ in range(budget):
            bot_output_ids = ChatHelper.dialog_model.generate(
                input_ids, max_length=1000,
                do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
                pad_token_id=eos_id)
            reply_ids = bot_output_ids[0:1, input_ids.shape[-1]:]
            response = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
            if response.strip() != '':
                break

            context = input_ids[0].tolist()
            if eos_id in context[:-1]:
                # drop the oldest turn; at least one token always remains
                input_ids = input_ids[:, context.index(eos_id) + 1:]
            else:
                input_ids = message_ids
            if self.run_on_own_server:
                print(input_ids)
        else:
            response = self.empty_generation_fallback
            reply_ids = tokenizer.encode(response + ChatHelper.eos, return_tensors='pt')

        self.history['input_ids'] = torch.cat([self.history['input_ids'], reply_ids], dim=-1)
        if self.run_on_own_server:
            print((message, response), '\n', self.history['input_ids'])
        return response
if __name__ == '__main__':
    def chat(message, bot):
        """Gradio handler: feed ``message`` to the session's bot.

        ``bot`` is the per-session state; a new TherapyChatBot is created when
        the state is still empty. Returns the chat transcript and the bot so
        that it is passed back in on the next call.
        """
        if not bot:
            bot = TherapyChatBot(args)
        bot(message)
        return bot.history['text'], bot

    title = 'PsyPlus Empathetic Chatbot'
    description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'

    # The chat window starts with the greeting of the selected initial state.
    greeting = [('', ChatHelper.greeting_template[args.initial_chat_state])]
    chatbot = gr.Chatbot(value=greeting)

    iface = gr.Interface(
        fn=chat,
        inputs=['text', 'state'],
        outputs=[chatbot, 'state'],
        allow_flagging='never',
        title=title,
        description=description,
    )

    # Share mode is only requested when running on an own server.
    launch_options = {'debug': True}
    if args.run_on_own_server != 0:
        launch_options['share'] = True
    iface.launch(**launch_options)