from unsloth import FastLanguageModel
import torch
import gradio as gr
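# NOTE: Unsloth's FastLanguageModel generally requires a CUDA-capable GPU,
# so this Space is assumed to run on GPU hardware.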
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Load the fine-tuned model and tokenizer from the local checkpoint directory.
model, tokenizer = FastLanguageModel.from_pretrained("./unified_model")
# Switch the model into Unsloth's optimized inference mode once at startup
# (for_inference patches the model in place, so its return value is not needed).
FastLanguageModel.for_inference(model)
def generate_response(instruction, chat_history):
    """Generate a response with the fine-tuned model.

    `chat_history` is accepted for future context injection but is not yet
    folded into the prompt below.
    """
prompt = f"""### Instruction:
Answer the following question.
### Question:
{instruction}
Provide a unique, concise, and non-repetitive answer.
### Answer:"""
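    # NOTE: This prompt template should match the instruction format used during
    # fine-tuning; changing the section headers can degrade output quality.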
    # Tokenize the prompt and move it to the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Sampling-based decoding. The beam-search-only flags from the original
        # call (early_stopping, length_penalty) are dropped because they have no
        # effect when do_sample=True.
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            min_length=50,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            temperature=0.7,
            repetition_penalty=1.2,
            num_return_sequences=1,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the answer marker and trim surrounding whitespace.
    response = response.split("### Answer:")[-1].strip()
    return response
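# A minimal usage sketch (assumes the "./unified_model" checkpoint above and a
# CUDA-capable GPU; the empty history dict mirrors the structure built in chatbot()):
#   generate_response("What does the AILA project do?", {"user": [], "bot": []})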
def update_chat_history(chat_history, user_message, bot_message):
"""Update chat history to maintain relevance and avoid excessive growth."""
chat_history['user'].append(user_message)
chat_history['bot'].append(bot_message)
# Keep only the last N interactions
if len(chat_history['user']) > 5:
chat_history['user'] = chat_history['user'][-5:]
chat_history['bot'] = chat_history['bot'][-5:]
return chat_history
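# NOTE: update_chat_history() caps the rolling store at the five most recent
# exchanges; generate_response() does not read it yet, so the cap only bounds
# memory until the history is wired into the prompt.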
def chatbot(input_text, chat_history):
    """Gradio callback: generate a reply and append it to the chat history."""
    # Rebuild the rolling message store from the Chatbot component's history.
    messages = {
        "user": [],
        "bot": [],
    }
    for user_msg, bot_msg in chat_history:
        messages["user"].append(user_msg)
        messages["bot"].append(bot_msg)
    bot_response = generate_response(input_text, messages)
    # gr.Chatbot already labels the two sides, so no "User:" prefix is needed.
    chat_history.append((input_text, bot_response))
    update_chat_history(messages, input_text, bot_response)
    # Clear the input box and return the updated history.
    return "", chat_history
with gr.Blocks() as demo:
    gr.Markdown("## AILA INTERFACE DEMO")
    with gr.Row():
        gr.Image(value="up_2017_logo_en.png", interactive=False, label="Upatras Logo", width=150, height=100)
        gr.Image(value="aila_new.png", interactive=False, label="AILA project Logo", width=150, height=100)
        gr.Image(value="banner-horizontal-default-en.png", interactive=False, label="AUTH Logo", width=150, height=100)
    with gr.Row():
        user_input = gr.Textbox(
            placeholder="Type your message here...",
            label="Your Message",
            lines=1,
        )
        submit_button = gr.Button("Submit")
    chat_history = gr.Chatbot()
    submit_button.click(
        chatbot,
        inputs=[user_input, chat_history],
        outputs=[user_input, chat_history],
    )

demo.launch()