# Spaces: Running on Zero
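The example below is a complete `app.py` for a chat Space running on ZeroGPU: a streaming chatbot built with `gr.ChatInterface` on top of `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`. Because ZeroGPU attaches a GPU only while a decorated function is executing, the GPU-bound `generate_response` method is marked with `@spaces.GPU`: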
```python
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import spaces


class ChatInterface:
    def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    def format_chat_prompt(self, message, history, system_message):
        messages = [{"role": "system", "content": system_message}]
        # history arrives from gr.ChatInterface as (user, assistant) pairs
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})
        # Format messages according to the model's expected chat template
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt

    @spaces.GPU
    def generate_response(
        self,
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
    ):
        prompt = self.format_chat_prompt(message, history, system_message)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Set up a streamer that yields decoded text as it is produced
        streamer = TextIteratorStreamer(
            self.tokenizer,
            timeout=10.0,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # Generate in a separate thread so this one can stream the output
        generation_kwargs = dict(
            **inputs,  # unpack input_ids and attention_mask
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Accumulate partial output and yield it for live display in the UI
        response = ""
        for new_text in streamer:
            response += new_text
            yield response
```
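On ZeroGPU, a GPU is attached only while a `@spaces.GPU`-decorated function is executing, which is why `generate_response` carries the decorator. If a single generation can outlast the default allocation window, the decorator also accepts a `duration` argument in seconds; a minimal sketch of that variant (the 120-second value is an arbitrary example, not a recommendation):

```python
    @spaces.GPU(duration=120)  # request up to 120 s of GPU time per call
    def generate_response(self, message, history, system_message,
                          max_tokens, temperature, top_p):
        ...
```

The rest of the script builds the UI: `create_demo` wires `generate_response` into `gr.ChatInterface` and exposes the system message and sampling parameters as additional inputs.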
```python
def create_demo():
    chat_interface = ChatInterface()

    demo = gr.ChatInterface(
        chat_interface.generate_response,
        additional_inputs=[
            gr.Textbox(
                value="You are a friendly Chatbot.",
                label="System message",
            ),
            gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Max new tokens",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
    )
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
```
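To deploy this as a Space, a `requirements.txt` next to `app.py` must list the third-party libraries the script needs; `accelerate` is required for `device_map="auto"`. A minimal sketch (unpinned versions are an assumption; pin them in a real Space):

```text
torch
transformers
accelerate
```

On a Gradio-SDK Space, `gradio` itself is installed according to the `sdk_version` declared in the README front matter, and the `spaces` package is preinstalled on ZeroGPU Spaces. A sketch of that front matter, with illustrative values for the fields not fixed by this example (`title`, `sdk_version`, and `suggested_hardware` are assumptions):

```yaml
---
title: Streaming Chat Demo      # assumed name
sdk: gradio
sdk_version: 4.44.0             # illustrative; use a current release
app_file: app.py
suggested_hardware: zero-a10g   # assumption: hints at ZeroGPU; hardware is set in Space settings
---
```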