# (Removed Hugging Face Spaces page artifact: "Spaces / Sleeping" status header)
| # app.py – Gradio chatbot for FractalAIResearch/Fathom-R1-14B | |
| # --------------------------------------------------------------------- | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from threading import Thread | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| TextIteratorStreamer, | |
| ) | |
# ---------------------------------------------------------------------
# 1. Model & tokenizer
# ---------------------------------------------------------------------
MODEL_NAME = "FractalAIResearch/Fathom-R1-14B"

print("⏳ Loading model … (this may take a couple of minutes)")

# Both loaders need trust_remote_code: Fathom ships custom modelling code.
_hub_kwargs = {"trust_remote_code": True}

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",       # dispatch across any available device(s)
    low_cpu_mem_usage=True,  # stream weights in instead of a full in-RAM copy
    **_hub_kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **_hub_kwargs)

print("✅ Model loaded")
# ---------------------------------------------------------------------
# 2. Helper: build a prompt with the tokenizer’s chat_template
# ---------------------------------------------------------------------
def build_chat_prompt(history, user_message, system_message):
    """Assemble the full conversation into one prompt string (not tokenised).

    Args:
        history: prior turns as a list of {"role": ..., "content": ...} dicts.
        user_message: the new user turn to append last.
        system_message: system prompt prepended first; skipped when falsy.

    Returns:
        The text produced by the tokenizer's chat template, ending with the
        generation prompt so the model continues as the assistant.
    """
    conversation = (
        [{"role": "system", "content": system_message}] if system_message else []
    )
    conversation += list(history)
    conversation.append({"role": "user", "content": user_message})
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,              # return pure text
        add_generation_prompt=True,
    )
# ---------------------------------------------------------------------
# 3. Generation endpoint
# ---------------------------------------------------------------------
@spaces.GPU  # short GPU reservation if available (no-op off ZeroGPU hardware)
def generate_response(
    user_message,
    max_tokens,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    history_state,
):
    """Stream an assistant reply for *user_message*.

    Generator used as the Gradio click handler: yields
    ``(chatbot_messages, history_state)`` pairs (the same messages-format
    list twice) so the chat window updates as tokens arrive.

    Args:
        user_message: raw text from the input box.
        max_tokens, temperature, top_k, top_p, repetition_penalty:
            sampling controls taken directly from the UI sliders.
        history_state: prior conversation as a list of role/content dicts.
    """
    # Empty input → nothing to do. NOTE: this function is a generator, so a
    # bare `return value` would be swallowed as StopIteration and Gradio's
    # outputs would never update — yield the unchanged state, then stop.
    if not user_message.strip():
        yield history_state, history_state
        return
    # System prompt (kept from your Phi-4 version)
    system_message = (
        "Your role as an assistant involves thoroughly exploring questions through a "
        "systematic thinking process before providing the final precise and accurate "
        "solutions. Please structure your response into two main sections: "
        "<think> … </think> and Solution."
    )
    prompt = build_chat_prompt(history_state, user_message, system_message)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Stream tokens as they come; skip_prompt keeps the echoed input out.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        streamer=streamer,
    )
    # Run generate in a background thread so the UI stays responsive while
    # this generator drains the streamer.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]
    for token in streamer:
        # Strip any stray special tokens the model may output.
        cleaned = (
            token.replace("<|im_start|>", "")
            .replace("<|im_end|>", "")
            .replace("<|im_sep|>", "")
        )
        assistant_response += cleaned
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history
    # Final yield commits the finished turn even if the streamer was empty.
    yield new_history, new_history
# ---------------------------------------------------------------------
# 4. Example questions (unchanged)
# ---------------------------------------------------------------------
# (button label, prompt text) pairs; the buttons below look entries up by label.
_EXAMPLES = [
    (
        "Math reasoning",
        "If a rectangular prism has a length of 6 cm, a width of 4 cm, and a height of 5 cm, what is the length of the longest line segment that can be drawn from one vertex to another?",
    ),
    (
        "Logic puzzle",
        "Four people (Alex, Blake, Casey, and Dana) each have a different favorite color (red, blue, green, yellow) and a different favorite fruit (apple, banana, cherry, date). Given the following clues: 1) The person who likes red doesn't like dates. 2) Alex likes yellow. 3) The person who likes blue likes cherries. 4) Blake doesn't like apples or bananas. 5) Casey doesn't like yellow or green. Who likes what color and what fruit?",
    ),
    (
        "Physics problem",
        "A ball is thrown upward with an initial velocity of 15 m/s from a height of 2 meters above the ground. Assuming the acceleration due to gravity is 9.8 m/s², determine: 1) The maximum height the ball reaches. 2) The total time the ball is in the air before hitting the ground. 3) The velocity with which the ball hits the ground.",
    ),
]
example_messages = dict(_EXAMPLES)
# ---------------------------------------------------------------------
# 5. Gradio UI (identical to the original, just lower default max_tokens)
# ---------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Fathom-R1-14B Chatbot
The model excels at multi-step reasoning in mathematics, logic, and science.
It returns two sections:\n
1. **<think>** – detailed chain-of-thought (reasoning)\n
2. **Solution** – concise, final answer
"""
    )
    # Conversation history in messages format: list of {"role", "content"} dicts.
    history_state = gr.State([])
    with gr.Row():
        # Settings panel
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(
                minimum=64, maximum=4096, step=256, value=1024, label="Max Tokens"
            )
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.8, label="Temperature"
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=100, step=1, value=50, label="Top-k"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, label="Top-p"
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.0, label="Repetition Penalty"
                )
        # Chat area
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(
                    label="Your message", placeholder="Type your message here…", scale=3
                )
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1_button = gr.Button("Math reasoning")
                example2_button = gr.Button("Logic puzzle")
                example3_button = gr.Button("Physics problem")
    # Button wiring
    # generate_response is a generator, so the chat streams as it yields;
    # the chained .then(...) clears the textbox once generation finishes.
    submit_button.click(
        fn=generate_response,
        inputs=[
            user_input,
            max_tokens_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
            repetition_penalty_slider,
            history_state,
        ],
        outputs=[chatbot, history_state],
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input,
    )
    # Reset both the visible chat window and the stored history.
    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state],
    )
    # Example buttons only pre-fill the textbox; the user still presses Send.
    example1_button.click(
        fn=lambda: gr.update(value=example_messages["Math reasoning"]),
        inputs=None,
        outputs=user_input,
    )
    example2_button.click(
        fn=lambda: gr.update(value=example_messages["Logic puzzle"]),
        inputs=None,
        outputs=user_input,
    )
    example3_button.click(
        fn=lambda: gr.update(value=example_messages["Physics problem"]),
        inputs=None,
        outputs=user_input,
    )
# ---------------------------------------------------------------------
# 6. Launch
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # NOTE(review): ssr_mode=False presumably disables Gradio's server-side
    # rendering to avoid issues on Spaces — confirm against the Gradio version
    # pinned in requirements.
    demo.launch(ssr_mode=False)