Spaces:
Running
Running
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import time | |
| import gradio as gr | |
| from gradio import deploy | |
def generate_prompt(instruction, input=""):
    """Build the text prompt fed to the model.

    Both fields are stripped, Windows line endings (``\\r\\n``) are converted
    to ``\\n``, and every run of consecutive newlines is collapsed to a single
    ``\\n`` so that no blank lines remain in user-supplied text.

    Args:
        instruction: The user's request.
        input: Optional extra context. ``None`` is treated like ``""``.
            (The name shadows the builtin but is kept for keyword callers.)

    Returns:
        An "Instruction / Input / Response" prompt when ``input`` is non-empty
        after cleaning; otherwise a short "User / Lover" chat prompt that ends
        with ``"Lover:"`` for the model to continue.
    """
    def _clean(text):
        # Normalize line endings, then remove blank lines. A single
        # replace('\n\n', '\n') pass still leaves '\n\n' behind for runs of
        # three or more newlines, so repeat until none remain.
        text = text.strip().replace('\r\n', '\n')
        while '\n\n' in text:
            text = text.replace('\n\n', '\n')
        return text

    instruction = _clean(instruction)
    input = _clean(input or "")
    if input:
        return f"""Instruction: {instruction}
Input: {input}
Response:"""
    else:
        return f"""User: hi
Lover: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
User: {instruction}
Lover:"""
# Local directory containing the checkpoint (presumably RWKV-6 World 1.6B,
# judging by the path) plus any custom modeling/tokenizer code it ships.
model_path = "models/rwkv-6-world-1b6/"  # Path to your local model directory

# trust_remote_code allows transformers to execute the model code bundled
# with the checkpoint; flash-attention 2 is explicitly not requested.
model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, use_flash_attention_2=False
)
# Cast all weights to float32 (nn.Module.to returns the same module).
model = model.to(torch.float32)

# Special-token strings handed to the tokenizer, in the original keyword order.
# NOTE(review): eos_token is "</ s>" (with a space) while bos_token is "</s>";
# this looks like a typo — confirm which token string the model actually uses.
_special_tokens = {
    "bos_token": "</s>",
    "eos_token": "</ s>",
    "unk_token": "<unk>",
    "pad_token": "<pad>",
}
# The tokenizer files must be present in model_path. Padding is applied on
# the left, and decode() leaves tokenization spaces untouched.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    **_special_tokens,
    trust_remote_code=True,
    padding_side='left',
    clean_up_tokenization_spaces=False,  # Or set to True if you prefer
)
# Generate a reply for the Gradio UI and echo it to the server console.
def generate_text(input_text, max_new_tokens=333):
    """Sample a model reply to ``input_text`` and return it as a string.

    The prompt is built with ``generate_prompt`` and the reply is sampled in a
    single ``model.generate`` call (temperature 1.0, top_p 0.3, top_k=0, i.e.
    top-k filtering disabled). No custom stop sequence is applied: generation
    ends after ``max_new_tokens`` tokens, or earlier if ``generate`` stops on
    an EOS token id configured for the model.

    Args:
        input_text: The user's message from the Gradio text box.
        max_new_tokens: Maximum number of tokens to sample. Defaults to 333,
            the fixed length the previous token-by-token loop produced.

    Returns:
        The newly generated text (prompt excluded), decoded with special
        tokens skipped. The same text is also printed to stdout.
    """
    prompt = generate_prompt(input_text)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    prompt_len = input_ids.shape[-1]
    # One generate() call replaces 333 separate max_new_tokens=1 calls, each
    # of which re-ran generation over the entire, ever-growing sequence.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=1.0,
        top_p=0.3,
        top_k=0,
    )
    # Decode all new tokens together: decoding one token at a time can split
    # a multi-byte UTF-8 character across tokens and garble non-ASCII text.
    generated_text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    print(generated_text, flush=True)
    return generated_text
# Build the Gradio UI: one text box in, one text box out, backed by
# generate_text. Flagged examples are written to gradio_flagged/.
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="RWKV Chatbot",
    description="Enter your prompt below:",
    flagging_dir="gradio_flagged/",
)

# NOTE(review): gradio's deploy() appears to be the helper behind the
# `gradio deploy` CLI (it uploads the app to a Hugging Face Space); it does not
# serve the interface. Kept to preserve existing behavior — confirm it is
# still wanted here.
deploy()

# Serve the interface when this file is executed as a script (e.g.
# `python app.py`, which is presumably how the Space runs it). Nothing else in
# this file calls launch(), so without this the UI is built but never served.
# For a temporary public link while testing locally, use
# iface.launch(share=True) instead.
if __name__ == "__main__":
    iface.launch()