# Fathom / app.py — Hugging Face Space "FractalAIR" (commit e30ac3f, 8.62 kB)
# NOTE: the lines above/below were scraped from the HF web viewer; this header
# preserves that provenance as comments so the file remains valid Python.
# app.py – Gradio chatbot for FractalAIResearch/Fathom-R1-14B
# ---------------------------------------------------------------------
import gradio as gr
import spaces
import torch
from threading import Thread
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
# ---------------------------------------------------------------------
# 1. Model & tokenizer
# ---------------------------------------------------------------------
# Hugging Face model id for the 14B reasoning model served by this Space.
MODEL_NAME = "FractalAIResearch/Fathom-R1-14B"

print("⏳ Loading model … (this may take a couple of minutes)")

# Fathom ships custom modelling code, hence trust_remote_code=True.
# device_map="auto" lets accelerate place shards on whatever devices exist;
# low_cpu_mem_usage avoids materialising the full fp32 weights on the host.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="auto",
)

# Tokenizer must come from the same repo so the chat template matches.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

print("✅ Model loaded")
# ---------------------------------------------------------------------
# 2. Helper: build a prompt with the tokenizer’s chat_template
# ---------------------------------------------------------------------
def build_chat_prompt(history, user_message, system_message):
    """Assemble a chat-template prompt string (untokenised).

    Args:
        history: list of ``{"role": ..., "content": ...}`` message dicts.
        user_message: the new user turn to append.
        system_message: optional system prompt; skipped when falsy.

    Returns:
        A single prompt string rendered with the tokenizer's chat template,
        ending with the generation prompt for the assistant turn.
    """
    messages = (
        [{"role": "system", "content": system_message}] if system_message else []
    )
    messages += list(history)
    messages += [{"role": "user", "content": user_message}]
    # tokenize=False → plain text; the caller tokenises separately.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
# ---------------------------------------------------------------------
# 3. Generation endpoint
# ---------------------------------------------------------------------
@spaces.GPU(duration=60)  # short GPU reservation if available
def generate_response(
    user_message,
    max_tokens,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    history_state,
):
    """Stream an assistant reply for *user_message*.

    Generator yielding ``(chatbot_messages, history_state)`` pairs as tokens
    arrive, so the Gradio UI updates incrementally. The final yield carries
    the completed turn.
    """
    # Empty input → nothing to generate; re-emit current history unchanged.
    # BUG FIX: this function is a generator, so a bare `return value` is
    # swallowed as the StopIteration value and the UI outputs would never
    # update. We must yield the unchanged state, then stop.
    if not user_message.strip():
        yield history_state, history_state
        return

    # System prompt instructing the model to emit <think> + Solution sections.
    system_message = (
        "Your role as an assistant involves thoroughly exploring questions through a "
        "systematic thinking process before providing the final precise and accurate "
        "solutions. Please structure your response into two main sections: "
        "<think> … </think> and Solution."
    )

    prompt = build_chat_prompt(history_state, user_message, system_message)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Stream decoded text as it is generated; skip_prompt hides the echo of
    # the input prompt.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        streamer=streamer,
    )

    # Run generate in a background thread so iteration over the streamer
    # (and the Gradio event loop) stays responsive.
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]
    for token in streamer:
        # Strip any stray chat special tokens the model may emit verbatim.
        cleaned = (
            token.replace("<|im_start|>", "")
            .replace("<|im_end|>", "")
            .replace("<|im_sep|>", "")
        )
        assistant_response += cleaned
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history
    # Final emission with the complete answer.
    yield new_history, new_history
# ---------------------------------------------------------------------
# 4. Example questions (unchanged)
# ---------------------------------------------------------------------
# Canned example prompts, keyed by the label shown on the example buttons
# in the UI below; clicking a button copies the text into the input box.
example_messages = {
"Math reasoning": "If a rectangular prism has a length of 6 cm, a width of 4 cm, and a height of 5 cm, what is the length of the longest line segment that can be drawn from one vertex to another?",
"Logic puzzle": "Four people (Alex, Blake, Casey, and Dana) each have a different favorite color (red, blue, green, yellow) and a different favorite fruit (apple, banana, cherry, date). Given the following clues: 1) The person who likes red doesn't like dates. 2) Alex likes yellow. 3) The person who likes blue likes cherries. 4) Blake doesn't like apples or bananas. 5) Casey doesn't like yellow or green. Who likes what color and what fruit?",
"Physics problem": "A ball is thrown upward with an initial velocity of 15 m/s from a height of 2 meters above the ground. Assuming the acceleration due to gravity is 9.8 m/s², determine: 1) The maximum height the ball reaches. 2) The total time the ball is in the air before hitting the ground. 3) The velocity with which the ball hits the ground.",
}
# ---------------------------------------------------------------------
# 5. Gradio UI (identical to the original, just lower default max_tokens)
# ---------------------------------------------------------------------
# Layout: settings column (scale 1) beside the chat column (scale 4), with
# example buttons underneath the input row. Nesting reconstructed from the
# original's section comments; behavior is unchanged.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Fathom-R1-14B Chatbot
The model excels at multi-step reasoning in mathematics, logic, and science.
It returns two sections:\n
1. **<think>** – detailed chain-of-thought (reasoning)\n
2. **Solution** – concise, final answer
"""
    )

    # Conversation state shared between the streaming callback and the UI.
    history_state = gr.State([])

    with gr.Row():
        # Settings panel
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(
                minimum=64, maximum=4096, step=256, value=1024, label="Max Tokens"
            )
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.8, label="Temperature"
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=100, step=1, value=50, label="Top-k"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, label="Top-p"
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.0, label="Repetition Penalty"
                )

        # Chat area
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(
                    label="Your message", placeholder="Type your message here…", scale=3
                )
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1_button = gr.Button("Math reasoning")
                example2_button = gr.Button("Logic puzzle")
                example3_button = gr.Button("Physics problem")

    # Button wiring: send → stream generation, then clear the textbox.
    submit_button.click(
        fn=generate_response,
        inputs=[
            user_input,
            max_tokens_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
            repetition_penalty_slider,
            history_state,
        ],
        outputs=[chatbot, history_state],
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input,
    )

    # Clear wipes both the visible chat and the stored history.
    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state],
    )

    # Example buttons just populate the input box; they do not submit.
    example1_button.click(
        fn=lambda: gr.update(value=example_messages["Math reasoning"]),
        inputs=None,
        outputs=user_input,
    )
    example2_button.click(
        fn=lambda: gr.update(value=example_messages["Logic puzzle"]),
        inputs=None,
        outputs=user_input,
    )
    example3_button.click(
        fn=lambda: gr.update(value=example_messages["Physics problem"]),
        inputs=None,
        outputs=user_input,
    )
# ---------------------------------------------------------------------
# 6. Launch
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # ssr_mode=False: disable Gradio server-side rendering (avoids SSR issues
    # on Spaces with streaming generators).
    demo.launch(ssr_mode=False)