|
|
import os |
|
|
import gradio as gr |
|
|
from huggingface_hub import InferenceClient |
|
|
from datasets import load_dataset |
|
|
import random |
|
|
import re |
|
|
import sympy as sp |
|
|
|
|
|
|
|
|
# Module-level cache of sample problems; stays None until the first load attempt.
math_samples = None

# Substrings that mark a Fineweb-edu document as math-flavoured.
_FINEWEB_KEYWORDS = (
    'math', 'calculate', 'solve', 'derivative', 'integral',
    'triangle', 'equation', 'area', 'volume', 'probability',
)

# Substrings that mark an Ultrachat opening message as math-flavoured.
_ULTRACHAT_KEYWORDS = (
    'math', 'calculate', 'solve', 'problem', 'equation', 'derivative', 'integral',
)

# Used only when no dataset yields a single sample.
_FALLBACK_PROBLEMS = (
    "What is the derivative of f(x) = 3x² + 2x - 1?",
    "A triangle has sides of length 5, 12, and 13. What is its area?",
    "If log₂(x) + log₂(x+6) = 4, find the value of x.",
    "Find the limit: lim(x->0) (sin(x)/x)",
    "Solve the system: x + 2y = 7, 3x - y = 4",
    "Calculate the integral of sin(x) from 0 to pi.",
    "What is the probability of rolling a 6 on a die 3 times in a row?",
)


def _gsm8k_question(item):
    """Return the question text of a GSM8K record."""
    return item["question"]


def _fineweb_question(item):
    """Return a short math prompt cut from a Fineweb-edu record, or None."""
    text = item['text']
    lowered = text.lower()
    if not any(word in lowered for word in _FINEWEB_KEYWORDS):
        return None
    question = text[:150].strip()
    if len(question) <= 20:
        return None
    return question + " (Solve this math problem.)"


def _ultrachat_question(item):
    """Return the first message of an Ultrachat record if it looks like math, else None."""
    messages = item['messages']
    if not messages:
        return None
    user_msg = messages[0]['content']
    lowered = user_msg.lower()
    if len(user_msg) > 10 and any(word in lowered for word in _ULTRACHAT_KEYWORDS):
        return user_msg
    return None


def _take_matching(records, extract, limit, max_scanned=5000):
    """Collect up to `limit` non-None results of `extract(record)`.

    Scanning stops after `max_scanned` records so that a stream with few
    matches cannot be read indefinitely.
    """
    found = []
    for scanned, record in enumerate(records, start=1):
        candidate = extract(record)
        if candidate is not None:
            found.append(candidate)
            if len(found) >= limit:
                break
        if scanned >= max_scanned:
            break
    return found


def load_sample_problems():
    """Load sample problems from GSM8K, Fineweb-edu and Ultrachat (cached).

    Each dataset is streamed and loaded independently, so a failure in one
    source no longer throws away the samples gathered from the others. The
    built-in fallback problems are used only when nothing could be loaded.

    Returns:
        list[str]: the sample problems; the same list object is returned on
        every later call because it is cached in the `math_samples` global.
    """
    global math_samples
    if math_samples is not None:
        return math_samples

    # (label, stream opener, per-record extractor, number of samples wanted)
    sources = (
        (
            "GSM8K",
            lambda: load_dataset("openai/gsm8k", "main", streaming=True)["train"],
            _gsm8k_question,
            50,
        ),
        (
            "Fineweb-edu",
            lambda: load_dataset(
                "HuggingFaceFW/fineweb-edu",
                name="sample-10BT",
                split="train",
                streaming=True,
            ),
            _fineweb_question,
            20,
        ),
        (
            "Ultrachat",
            # An explicit split is required: without it the loader returns a
            # dict of splits, and iterating that yields split names, not records.
            lambda: load_dataset(
                "HuggingFaceH4/ultrachat_200k",
                split="train_sft",
                streaming=True,
            ),
            _ultrachat_question,
            20,
        ),
    )

    samples = []
    counts = {}
    for label, open_stream, extract, limit in sources:
        print(f"🔄 Loading {label}...")
        try:
            picked = _take_matching(open_stream(), extract, limit)
        except Exception as e:
            # Best-effort boundary around network / third-party loading:
            # skip this source and keep whatever the others provide.
            print(f"⚠️ Dataset error: {e}, skipping {label}")
            picked = []
        counts[label] = len(picked)
        samples.extend(picked)

    if samples:
        print(
            f"✅ Loaded {len(samples)} samples: GSM8K ({counts['GSM8K']}), "
            f"Fineweb-edu ({counts['Fineweb-edu']}), Ultrachat ({counts['Ultrachat']})"
        )
        math_samples = samples
    else:
        print("⚠️ No dataset samples available, using fallback")
        math_samples = list(_FALLBACK_PROBLEMS)
    return math_samples
|
|
|
|
|
def create_math_system_message():
    """Return the system prompt that frames the model as a LaTeX-writing math tutor."""
    prompt_lines = (
        r"You are Mathetics AI, an advanced mathematics tutor and problem solver.",
        r"",
        r"🧮 **Your Expertise:**",
        r"- Step-by-step problem solving with clear explanations",
        r"- Multiple solution approaches when applicable",
        r"- Proper mathematical notation and terminology using LaTeX",
        r"- Verification of answers through different methods",
        r"",
        r"📐 **Problem Domains:**",
        r"- Arithmetic, Algebra, and Number Theory",
        r"- Geometry, Trigonometry, and Coordinate Geometry",
        r"- Calculus (Limits, Derivatives, Integrals)",
        r"- Statistics, Probability, and Data Analysis",
        r"- Competition Mathematics (AMC, AIME level)",
        r"",
        r"💡 **Teaching Style:**",
        r"1. **Understand the Problem** - Identify what's being asked",
        r"2. **Plan the Solution** - Choose the appropriate method",
        r"3. **Execute Step-by-Step** - Show all work clearly with LaTeX formatting",
        r"4. **Verify the Answer** - Check if the result makes sense",
        r"5. **Alternative Methods** - Mention other possible approaches",
        r"",
        r"**LaTeX Guidelines:**",
        r"- Use $...$ for inline math: $x^2 + y^2 = z^2$",
        r"- Use $$...$$ for display math",
        r"- Box final answers: \boxed{answer}",
        r"- Fractions: \frac{numerator}{denominator}",
        r"- Limits: \lim_{x \to 0}",
        r"- Derivatives: \frac{d}{dx} or f'(x)",
        r"",
        r"Always be precise, educational, and encourage mathematical thinking.",
    )
    # Raw strings keep the LaTeX backslashes literal; lines are joined with
    # newlines to form the single prompt string.
    return "\n".join(prompt_lines)
|
|
|
|
|
# \boxed{...} with up to two levels of nested braces, e.g. \boxed{\frac{1}{2}}.
_BOXED_RE = re.compile(r'\\boxed\{(?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})+\}')

# Text that is already inside $$...$$ or $...$ math delimiters.
_MATH_SPAN_RE = re.compile(r'(\$\$.*?\$\$|\$[^$\n]+\$)', re.DOTALL)

# An equation* environment, together with any $$ that already surrounds it.
_EQUATION_ENV_RE = re.compile(
    r'(?:\$\$\s*)?\\begin\{equation\*\}(.*?)\\end\{equation\*\}(?:\s*\$\$)?',
    re.DOTALL,
)


def _wrap_bare_boxed(text):
    """Wrap every \\boxed{...} that is not already inside math delimiters in $$...$$."""
    pieces = _MATH_SPAN_RE.split(text)
    # With one capturing group, split() alternates: even indices are the text
    # outside math spans, odd indices are the math spans themselves.
    for idx in range(0, len(pieces), 2):
        # A function replacement is used on purpose: in a template string the
        # sequence "\b" would be turned into a backspace character.
        pieces[idx] = _BOXED_RE.sub(lambda m: '$$' + m.group(0) + '$$', pieces[idx])
    return ''.join(pieces)


def render_latex(text):
    """Normalise LaTeX delimiters in `text` to the $ / $$ forms.

    Steps, in order:
      1. ``\\[...\\]`` becomes ``$$...$$`` and ``\\(...\\)`` becomes ``$...$``.
      2. ``equation*`` environments become ``$$...$$`` (an existing surrounding
         ``$$`` is absorbed instead of being doubled).
      3. Each ``\\boxed{...}`` outside math delimiters is wrapped in ``$$...$$``;
         occurrences already inside math are left alone.
      4. ``\\%`` becomes ``%``.

    Args:
        text: the text to clean up; falsy values are returned unchanged.

    Returns:
        The cleaned text. If a step raises, the error is printed and the text
        as transformed so far is returned.
    """
    if not text:
        return text

    try:
        text = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', text, flags=re.DOTALL)
        text = re.sub(r'\\\((.*?)\\\)', r'$\1$', text, flags=re.DOTALL)
        # Environments are converted before the \boxed pass so a \boxed inside
        # one is recognised as already being in math mode.
        text = _EQUATION_ENV_RE.sub(r'$$\1$$', text)
        text = _wrap_bare_boxed(text)
        text = text.replace('\\%', '%')
    except Exception as e:
        # Display clean-up is best-effort: never let it break the response.
        print(f"⚠️ LaTeX error: {e}")

    return text
|
|
|
|
|
_INTEGRAL_RE = re.compile(r'(?:integral of|∫)\s*(.+?)\s+from\s+(.+?)\s+to\s+(.+)')
_DIFFERENTIAL_RE = re.compile(r'(?<![a-z])d([a-z])\s*$')
_DERIVATIVE_RE = re.compile(r'derivative of (.+)')
_LIMIT_RE = re.compile(
    r'\blim(?:it)?\b(?:\s+of)?\s+(.+?)\s+as\s+([a-z]\w*)\s*'
    r'(?:->|→|to|approaches|goes to|tends to)\s*(.+)'
)
_TRIANGLE_SIDES_RE = re.compile(
    r'(\d+(?:\.\d+)?)[ -](\d+(?:\.\d+)?)[ -](\d+(?:\.\d+)?)'
)
_MATRIX_RE = re.compile(r'matrix\s*\[\[(.+?)\]\]')


def _normalize_expression(raw):
    """Tidy a lower-cased natural-language math fragment for the SymPy parser."""
    cleaned = raw.strip().rstrip('.?!;:,').strip()
    # Drop a leading label such as "f(x) =" or "y =".
    cleaned = re.sub(r'^(?:[a-z]\s*\(\s*[a-z]\s*\)|y)\s*=\s*', '', cleaned)
    cleaned = cleaned.replace('²', '^2').replace('³', '^3')
    return re.sub(r'\binfinity\b|∞', 'oo', cleaned)


def _parse_math_expr(raw):
    """Parse a user-written expression, allowing `^` powers and implicit multiplication."""
    # Function-scope import: the module only imports the top-level sympy package.
    from sympy.parsing.sympy_parser import (
        convert_xor,
        implicit_multiplication_application,
        parse_expr,
        standard_transformations,
    )

    # NOTE(security): parse_expr (like sympify, which this replaces) evaluates
    # the string with eval(). `raw` comes from chat input, so this is untrusted
    # data reaching eval -- flagged for review, not solved here.
    transformations = standard_transformations + (
        convert_xor,
        implicit_multiplication_application,
    )
    return parse_expr(_normalize_expression(raw), transformations=transformations)


def _main_symbol(expr, default='x'):
    """Return the only free symbol of `expr`, or Symbol(default) if there is not exactly one."""
    symbols = expr.free_symbols
    if len(symbols) == 1:
        return next(iter(symbols))
    return sp.Symbol(default)


def _sympy_integral(message_lower):
    """Evaluate "integral of <expr> from <a> to <b>" (also "∫<expr> dx from ...")."""
    match = _INTEGRAL_RE.search(message_lower)
    if not match:
        return None
    expr_str, lower, upper = match.groups()
    try:
        # A trailing differential ("dx") names the variable and is not part
        # of the integrand.
        differential = _DIFFERENTIAL_RE.search(expr_str)
        if differential:
            expr = _parse_math_expr(expr_str[:differential.start()])
            var = sp.Symbol(differential.group(1))
        else:
            expr = _parse_math_expr(expr_str)
            var = _main_symbol(expr)
        result = sp.integrate(expr, (var, _parse_math_expr(lower), _parse_math_expr(upper)))
        return r'\boxed{' + sp.latex(result, mode='plain', fold_frac_powers=True) + r'}'
    except Exception as e:
        print(f"⚠️ SymPy integral error: {e}")
        return None


def _sympy_derivative(message_lower):
    """Differentiate the expression following "derivative of"."""
    match = _DERIVATIVE_RE.search(message_lower)
    if not match:
        return None
    try:
        expr = _parse_math_expr(match.group(1))
        result = sp.diff(expr, _main_symbol(expr))
        return r'\boxed{' + sp.latex(result, inv_trig_style='power', fold_short_frac=True) + r'}'
    except Exception as e:
        print(f"⚠️ SymPy derivative error: {e}")
        return None


def _sympy_limit(message_lower):
    """Evaluate "limit <expr> as <var> to <value>" (also "->", "approaches", ...)."""
    match = _LIMIT_RE.search(message_lower)
    if not match:
        return None
    expr_str, var, to_val = match.groups()
    try:
        expr = _parse_math_expr(expr_str)
        result = sp.limit(expr, sp.Symbol(var), _parse_math_expr(to_val))
        # Plain LaTeX, boxed like the other branches: the caller adds the $$
        # delimiters, so an equation* environment here would be double-wrapped.
        return r'\boxed{' + sp.latex(result) + r'}'
    except Exception as e:
        print(f"⚠️ SymPy limit error: {e}")
        return None


def _sympy_triangle_area(message_lower):
    """Compute a triangle's area from three side lengths via Heron's formula."""
    match = _TRIANGLE_SIDES_RE.search(message_lower)
    if not match:
        return None
    try:
        # Exact rationals rather than floats keep the result symbolic.
        a, b, c = (sp.Rational(side) for side in match.groups())
        s = (a + b + c) / 2
        radicand = s * (s - a) * (s - b) * (s - c)
        if radicand < 0:
            # The sides violate the triangle inequality; there is no real area.
            return None
        area = sp.sqrt(radicand)
        # Default (plain) mode: 'inline' would nest $...$ inside \boxed{}.
        return r'\boxed{' + sp.latex(area, fold_frac_powers=True) + r'}'
    except Exception as e:
        print(f"⚠️ SymPy area error: {e}")
        return None


def _sympy_matrix(message_lower):
    """Render "matrix [[a,b],[c,d]]" as a LaTeX bmatrix."""
    match = _MATRIX_RE.search(message_lower)
    if not match:
        return None
    try:
        # Tolerate whitespace around the row separator: "],[" or "], [".
        rows = re.split(r'\]\s*,\s*\[', match.group(1))
        elements = [[_parse_math_expr(cell) for cell in row.split(',')] for row in rows]
        m = sp.Matrix(elements)
        # bmatrix already draws brackets, so no extra \left[ ... \right] delimiter.
        return sp.latex(m, mat_delim='', mat_str='bmatrix')
    except Exception as e:
        print(f"⚠️ SymPy matrix error: {e}")
        return None


def try_sympy_compute(message):
    """Try to compute an exact result for `message` with SymPy.

    The message is matched, in this order, against: definite integrals,
    derivatives, limits, triangle area from three sides, and matrix literals.
    Only the first matching category is attempted.

    Args:
        message: the user's question; None or empty input yields None.

    Returns:
        A LaTeX string without surrounding ``$`` delimiters, or None when the
        message fits no supported pattern or SymPy cannot parse/compute it.
        Errors are printed, never raised.
    """
    if not message:
        return None
    message_lower = message.lower()

    if 'integral' in message_lower or '∫' in message:
        return _sympy_integral(message_lower)
    if 'derivative' in message_lower:
        return _sympy_derivative(message_lower)
    if 'lim' in message_lower:  # also covers "limit"
        return _sympy_limit(message_lower)
    if 'area of triangle' in message_lower:
        return _sympy_triangle_area(message_lower)
    if 'matrix' in message_lower:
        return _sympy_matrix(message_lower)
    return None
|
|
|
|
|
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Ask the hosted Qwen math model for an answer (single, non-streaming call).

    The system prompt, the user/assistant turns from `history`, and the new
    `message` are sent as one chat completion. When `try_sympy_compute` returns
    a result for the message, it is appended as a "Verified with SymPy" line.
    The text is passed through `render_latex` before being returned.

    Returns:
        The rendered answer, or a short "❌ Error" string if the completion or
        post-processing raises.
    """
    client = InferenceClient(model="Qwen/Qwen2.5-Math-7B-Instruct")

    conversation = [{"role": "system", "content": system_message}]
    # Only user/assistant turns are forwarded; any other role is dropped.
    conversation.extend(
        {"role": turn["role"], "content": turn["content"]}
        for turn in history
        if turn["role"] in ("user", "assistant")
    )
    conversation.append({"role": "user", "content": message})

    try:
        completion = client.chat_completion(
            conversation,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        answer = completion.choices[0].message.content

        verified = try_sympy_compute(message)
        if verified:
            answer = (
                answer
                + "\n\n**Verified with SymPy (for exact symbolic computation):** $$"
                + verified
                + "$$"
            )

        return render_latex(answer)
    except Exception as err:
        return f"❌ Error: {str(err)[:100]}... Try a simpler problem."
|
|
|
|
|
def get_random_sample():
    """Pick one sample problem at random, loading the samples on first use."""
    global math_samples
    pool = math_samples
    if pool is None:
        # First call: populate the module-level cache.
        pool = math_samples = load_sample_problems()
    return random.choice(pool)
|
|
|
|
|
def insert_sample_to_chat(difficulty=None):
    """Return a random sample problem to place in the chat input box.

    Args:
        difficulty: accepted for compatibility but not used. It defaults to
            None so the function can be called with no arguments, which is
            how the "Random Problem" button invokes it (no inputs wired).

    Returns:
        str: a randomly chosen sample problem.
    """
    return get_random_sample()
|
|
|
|
|
def show_help():
    """Return the Markdown text shown by the Help button."""
    tips = (
        '1. Be Specific: "Find the derivative of x² + 3x" instead of "help with calculus"',
        '2. Request Steps: "Show me step-by-step how to solve..."',
        '3. Ask for Verification: "Check if my answer x=5 is correct"',
        "4. Alternative Methods: \"What's another way to solve this integral?\"",
        '5. Use Clear Notation: "lim(x->0)" for limits',
    )
    sections = (
        "**🧮 Math Help Tips:**",
        "",
        *tips,
        "",
        "Pro Tip: Crank tokens to 1500+ for competition problems!",
    )
    return "\n".join(sections)
|
|
|
|
|
|
|
|
# Gradio UI: a chat window, an input box, four action buttons and clickable examples.
with gr.Blocks(title="🧮 Mathetics AI") as demo:
    gr.Markdown("# 🧮 **Mathetics AI** - Math Tutor\nPowered by Qwen 2.5-Math")

    chatbot = gr.Chatbot(height=500, label="Conversation", type='messages')
    # Hidden until the Help button is pressed.
    help_text = gr.Markdown(visible=False)

    msg = gr.Textbox(placeholder="Ask a math problem...", show_label=False)

    with gr.Row():
        submit = gr.Button("Solve", variant="primary")
        clear = gr.Button("Clear", variant="secondary")
        sample = gr.Button("Random Problem", variant="secondary")
        help_btn = gr.Button("Help", variant="secondary")

    gr.Examples(
        examples=[
            ["derivative of x^2 sin(x)"],
            ["area of triangle 5-12-13"],
            ["∫x^2 dx from 0 to 2"],
            ["limit sin(x)/x as x to 0"],
            ["matrix [[1,2],[3,4]]"],
        ],
        inputs=msg,
    )

    def chat_response(message, history):
        """Answer `message` and return (updated history, cleared textbox).

        `history` is a list of {"role", "content"} dicts, as used by
        ``type='messages'``. Blank input is ignored instead of being sent
        to the model.
        """
        history = list(history or [])
        if not message or not message.strip():
            return history, ""

        bot_response = respond(message, history, create_math_system_message(), 1024, 0.3, 0.85)

        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": bot_response})
        return history, ""

    def clear_chat():
        """Clear the chat history and textbox."""
        return [], ""

    def reveal_help():
        """Fill the help panel and make it visible."""
        return gr.update(value=show_help(), visible=True)

    msg.submit(chat_response, [msg, chatbot], [chatbot, msg])
    submit.click(chat_response, [msg, chatbot], [chatbot, msg])
    clear.click(clear_chat, outputs=[chatbot, msg])
    # The button has no input components, so the argument is supplied here.
    sample.click(lambda: insert_sample_to_chat(None), outputs=msg)
    # One update sets both the text and the visibility. The panel stays open:
    # hiding it in a chained step would remove it as soon as it appeared.
    help_btn.click(reveal_help, outputs=help_text)

demo.launch()