import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Specify the model ID
model_id = "MBZUAI-Paris/Atlas-Chat-2B"
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",           # Automatically place weights on the available device(s)
    torch_dtype=torch.bfloat16   # Use bfloat16 to reduce memory footprint
)
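# Note: bfloat16 needs hardware support (recent NVIDIA GPUs; on CPU it can be
# slow or unsupported). A minimal fallback sketch, assuming a CUDA build of
# torch (torch.cuda.is_bf16_supported() is a standard torch API):
#
#     dtype = (torch.bfloat16
#              if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
#              else torch.float32)
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id, device_map="auto", torch_dtype=dtype
#     )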
# Define the text generation function
def generate_text(prompt, max_new_tokens=100, temperature=0.7):
    # Prepare the input message in chat format
    messages = [{"role": "user", "content": prompt}]

    # Tokenize the input with the model's chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True
    ).to(model.device)

    # Generate the response
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=1
    )

    # Decode only the newly generated tokens, skipping the prompt echo
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
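# Quick local sanity check (a minimal sketch; uncomment to try before
# launching the UI -- the Darija prompt below is just an illustrative example):
#
#     print(generate_text("Salam! Chno howa Atlas-Chat?", max_new_tokens=64))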
# Create the Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=4, label="Enter Prompt"),
        gr.Slider(minimum=50, maximum=300, step=10, value=100, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.1, value=0.7, label="Temperature")
    ],
    outputs="text",
    title="Atlas-Chat-2B Text Generator",
    description="Powered by the MBZUAI-Paris/Atlas-Chat-2B model."
)
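# On Hugging Face Spaces, long generations can hit request timeouts; Gradio's
# built-in request queue mitigates this. A minimal sketch (queue() is a
# standard gr.Interface method that returns the interface, so it chains):
#
#     interface.queue().launch()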
# Launch the interface
interface.launch()