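Below is a small Gradio demo that wraps MBZUAI-Paris/Atlas-Chat-2B in a chat-style text generator. It assumes the `gradio`, `torch`, `transformers`, and `accelerate` packages are installed (`accelerate` is required for `device_map="auto"`).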
```python
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Specify the model ID
model_id = "MBZUAI-Paris/Atlas-Chat-2B"

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",           # Automatically selects the device
    torch_dtype=torch.bfloat16   # Use bfloat16 for efficiency
)

# Define the text generation function
def generate_text(prompt, max_length=100, temperature=0.7):
    # Prepare the input message in chat format
    messages = [{"role": "user", "content": prompt}]

    # Tokenize the input with the chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True
    ).to(model.device)

    # Generate the response
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        temperature=temperature,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=1
    )

    # Decode only the newly generated tokens, skipping the prompt echo
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Create the Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=4, label="Enter Prompt"),
        gr.Slider(minimum=50, maximum=300, step=10, value=100, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.1, value=0.7, label="Temperature")
    ],
    outputs="text",
    title="Atlas-Chat-2B Text Generator",
    description="Powered by the MBZUAI-Paris/Atlas-Chat-2B model."
)

# Launch the interface
interface.launch()
```
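Once the app is running, it can also be queried programmatically. Here is a minimal sketch using the official `gradio_client` package; the local URL and the `/predict` endpoint name are assumptions based on Gradio's defaults for a single `gr.Interface` (check the URL printed by `launch()` and the app's "Use via API" page for the actual values).

```python
from gradio_client import Client

# Assumed default local URL printed by interface.launch()
client = Client("http://127.0.0.1:7860")

# Positional arguments match the Interface inputs: prompt, max length, temperature
result = client.predict(
    "Write a short greeting.",  # Enter Prompt
    100,                        # Max Length
    0.7,                        # Temperature
    api_name="/predict"         # Gradio's default endpoint name for an Interface
)
print(result)
```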