# Hugging Face Spaces page status when this file was captured: Runtime error
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
# No longer called directly. The import is kept so a missing `accelerate`
# install fails immediately: transformers needs accelerate to honour
# device_map="auto" below.
from accelerate import infer_auto_device_map  # noqa: F401

# Hugging Face Hub id of the checkpoint served by this app.
model_name = "ai4bharat/Airavata"

# Tokenizer matching the checkpoint; used to encode prompts and decode outputs.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Quantize the weights to 8 bits with bitsandbytes to cut memory use.
# NOTE(review): bitsandbytes 8-bit loading generally requires a CUDA GPU —
# confirm the Space hardware provides one.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

# Let transformers/accelerate decide layer placement while loading.
# A bitsandbytes-quantized model is placed on its device(s) at load time and
# must not be moved afterwards. The previous `model.to(device_map)` passed a
# dict (not a device) to `.to()`, and transformers also rejects `.to()` on
# 8-bit bitsandbytes models in many versions — so placement happens here,
# once, via `device_map`.
device_map = "auto"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
)
| # Define the inference function | |
| def generate_text(prompt): | |
| inputs = tokenizer(prompt, return_tensors="pt") | |
| outputs = model.generate(**inputs) | |
| return tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# --- Web UI ----------------------------------------------------------------
# Plain text in, plain text out: Gradio's "text" shortcut renders a textbox on
# each side, and every submission is passed to generate_text().
APP_TITLE = "Airavata Text Generation Model"
APP_DESCRIPTION = (
    "This is the AI4Bharat Airavata model for text generation "
    "in Indic languages."
)

interface = gr.Interface(
    generate_text,
    "text",
    "text",
    title=APP_TITLE,
    description=APP_DESCRIPTION,
)

# Start the Gradio server (runs at module top level, as before).
interface.launch()