import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# --- Configuration ---
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"

# Use 'cuda' on a GPU Space; fall back to 'cpu' otherwise (generation will be slow).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
# --- Load Model and Tokenizer ---
# Note: Loading may require trust_remote_code=True or other flags depending on
# the model implementation. Check the model card on Hugging Face. You may also
# need a specific quantization config if it is not handled automatically.
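# Assumption based on the model card: full BitNet support may require a dedicated
# transformers build; a stock install can typically load the weights (via
# trust_remote_code) but without the optimized 1.58-bit kernels. See the model
# card for the exact install instructions.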
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Adjust loading parameters as needed (e.g., torch_dtype, device_map).
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # or float16; adjust to hardware/model requirements
        device_map="auto",           # place weights across available devices (GPU/CPU)
        trust_remote_code=True,      # may be required for custom model code
    )
    # model.to(DEVICE)  # usually unnecessary; device_map="auto" handles placement
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # Exit if loading fails; the app cannot run without the model.
    raise SystemExit("Failed to load model/tokenizer.") from e
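# Note (assumption): bfloat16 support on CPU-only Spaces varies by hardware. If
# loading or generation misbehaves on CPU, float32 is a safer (if slower) choice:
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID, torch_dtype=torch.float32, trust_remote_code=True
# )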
# --- Chat Processing Function ---
def predict(message, history):
    """Stream a reply to the user's message, conditioning on the chat history."""
    # Convert Gradio's default tuple-style history ([user, assistant] pairs) into
    # the messages format expected by tokenizer.apply_chat_template.
    history_transformer_format = []
    for human, assistant in history:
        history_transformer_format.append({"role": "user", "content": human})
        history_transformer_format.append({"role": "assistant", "content": assistant})
    # Add the current user message.
    history_transformer_format.append({"role": "user", "content": message})
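    # After one prior exchange plus the new message, the list looks like:
    # [{"role": "user", "content": "Hello!"},
    #  {"role": "assistant", "content": "Hi! How can I help?"},
    #  {"role": "user", "content": message}]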
    # Use the tokenizer's chat template if available; otherwise fall back to
    # manual formatting. Base models may not ship a chat template.
    try:
        prompt = tokenizer.apply_chat_template(
            history_transformer_format,
            tokenize=False,
            add_generation_prompt=True,  # append the assistant prefix so the model replies as the assistant
        )
    except Exception:
        # Manual fallback formatting (an example only; adjust to the model's expectations).
        print("Warning: Using basic manual prompt formatting.")
        prompt_parts = ["Chat History:"]
        for turn in history_transformer_format:
            prompt_parts.append(f"{turn['role'].capitalize()}: {turn['content']}")
        prompt = "\n".join(prompt_parts) + "\nAssistant:"  # end ready for generation
| print(f"\n--- Prompt Sent to Model ---\n{prompt}\n---------------------------\n") | |
    # Stream tokens back to the UI as they are generated.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # dict(inputs, ...) unpacks input_ids/attention_mask alongside the generation options.
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        # Add other generation parameters as needed, e.g.:
        # eos_token_id=tokenizer.eos_token_id,  # important if the model needs an explicit stop token
        pad_token_id=tokenizer.eos_token_id,  # commonly set for open-ended generation
    )

    # Run generation in a background thread so we can consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated reply as tokens arrive.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
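# For reference, a minimal non-streaming sketch (a hypothetical helper, not wired
# into the UI below) that performs the same call without the streamer/thread machinery:
def predict_once(message: str) -> str:
    """Generate one complete reply to `message` (no history, no streaming)."""
    messages = [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)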
# --- Gradio Interface ---
# gr.ChatInterface handles history management automatically.
# Note: the retry_btn, undo_btn, and clear_btn arguments were removed in Gradio 5
# (passing them raises a TypeError there); the built-in retry/undo/clear buttons
# appear by default, so they are omitted here.
chatbot_interface = gr.ChatInterface(
    fn=predict,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    title="Chat with microsoft/bitnet-b1.58-2B-4T",
    description="A basic chat interface for the BitNet 1.58-bit 2B parameter model. Remember it's a base model, so prompting matters!",
    theme="soft",
    examples=[["Hello!"], ["Explain the concept of 1.58-bit quantization."]],
    cache_examples=False,  # set to True to cache example results
)
# --- Launch the Interface ---
if __name__ == "__main__":
    chatbot_interface.launch()  # pass share=True for a public link when running locally
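    # Queuing is enabled by default on Spaces; to tune concurrency explicitly
    # (Gradio 4+), something like:
    # chatbot_interface.queue(default_concurrency_limit=1).launch()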