Spaces:
Build error
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# Conversational model: DialoGPT-large produces the chatbot replies.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")

# Grammar model: a T5 checkpoint fine-tuned to rewrite a sentence correctly.
grammar_tokenizer = T5Tokenizer.from_pretrained("deep-learning-analytics/GrammarCorrector")
grammar_model = T5ForConditionalGeneration.from_pretrained("deep-learning-analytics/GrammarCorrector")
def chat(message, history=None):
    """Generate a DialoGPT reply to *message* and grammar feedback on it.

    Parameters
    ----------
    message : str
        The user's new utterance.
    history : list | None
        Token-id conversation state carried over from the previous turn
        (gradio "state"); None/empty on the first turn.

    Returns
    -------
    tuple
        (response text, updated token-id history, grammar feedback) — three
        values, matching the three outputs declared on the gradio Interface.
    """
    # Avoid the mutable-default-argument trap: a fresh list per call.
    history = history if history is not None else []
    # Encode the user's turn (not the `input` builtin, which the original
    # erroneously concatenated), terminated with DialoGPT's EOS token.
    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
    # Append the new turn to the accumulated conversation ids.
    # NOTE(review): concatenating an empty 1-D LongTensor with a 2-D tensor
    # relies on torch's empty-tensor special case — verify on the first turn.
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    # generate() returns the whole conversation so far; keep it as the state.
    history = model.generate(
        bot_input_ids, max_length=500, pad_token_id=tokenizer.eos_token_id
    ).tolist()
    response = tokenizer.decode(history[0]).replace("<|endoftext|>", "")
    return response, history, feedback(message)
def feedback(text):
    """Run *text* through the T5 grammar corrector and phrase a suggestion.

    Parameters
    ----------
    text : str
        The sentence to check.

    Returns
    -------
    str
        Praise when the corrector leaves the sentence unchanged, otherwise
        the corrected sentence offered as a hint.
    """
    num_return_sequences = 1
    batch = grammar_tokenizer(
        [text], truncation=True, padding="max_length", max_length=64, return_tensors="pt"
    )
    corrections = grammar_model.generate(
        **batch,
        max_length=64,
        num_beams=2,
        num_return_sequences=num_return_sequences,
        temperature=1.5,
    )
    print("The corrections are: ", corrections)
    # Decode the best sequence as a whole. The original called
    # batch_decode(corrections[0]), which decodes each token id separately and
    # joins the pieces with spaces, corrupting SentencePiece subwords.
    suggestion = grammar_tokenizer.decode(corrections[0], skip_special_tokens=True).strip()
    # generate() always returns num_return_sequences sequences, so the old
    # `len(corrections) == 0` praise branch was dead; the meaningful test is
    # whether the corrector reproduced the input unchanged.
    if suggestion == text.strip():
        return "Looks good! Keep up the good work"
    return f"'{suggestion}' might be a little better"
# Wire the chat function into a gradio UI: a text box plus the hidden
# conversation state go in; the chatbot widget, updated state, and the
# grammar-feedback text come out.
demo_inputs = ["text", "state"]
demo_outputs = ["chatbot", "state", "text"]
iface = gr.Interface(
    chat,
    demo_inputs,
    demo_outputs,
    allow_screenshot=False,
    allow_flagging="never",
)
iface.launch()