import os
import shutil

import torch
import gradio as gr
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from huggingface_hub import HfFolder, snapshot_download

# ---------------------------------------------------------------------
# GPU check
# ---------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"


def check_gpu_status():
    return f"✅ GPU: {torch.cuda.get_device_name(0)}" if device == "cuda" else "⚠️ Using CPU only"
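
# Note: on CPU-only Spaces hardware, check_gpu_status() reports "Using CPU only"
# and the float16/fp16 branches below fall back to float32, so training still
# runs (slowly) without a GPU.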

# ------------------------------------------------------
# 🧩 Download Dataset to /tmp/
# ------------------------------------------------------
def download_gita_dataset():
    repo_id = "rahul7star/Gita"
    local_dir = "/tmp/gita_data"
    if os.path.exists(local_dir):
        shutil.rmtree(local_dir)
    os.makedirs(local_dir, exist_ok=True)
    print(f"📥 Downloading dataset from {repo_id} ...")
    snapshot_download(repo_id=repo_id, local_dir=local_dir, repo_type="dataset")

    # Locate the first CSV file in the snapshot (break out of both loops once found)
    csv_path = None
    for root, _, files in os.walk(local_dir):
        for f in files:
            if f.lower().endswith(".csv"):
                csv_path = os.path.join(root, f)
                break
        if csv_path:
            break
    if not csv_path:
        raise FileNotFoundError("No CSV file found in the Gita dataset repository.")
    print(f"✅ Found CSV: {csv_path}")
    return csv_path
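
# Usage sketch for the helper above (illustrative only; the training path below
# loads the dataset with load_dataset instead of reading the CSV directly, and
# the pandas import is an assumption, not part of this Space):
#
#   import pandas as pd
#   csv_path = download_gita_dataset()
#   df = pd.read_csv(csv_path)
#   print(df.head())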

# ---------------------------------------------------------------------
# 🚀 Training Logic
# ---------------------------------------------------------------------
def train_model(model_name, num_epochs, batch_size, learning_rate, progress=gr.Progress(track_tqdm=True)):
    output_log = []

    # ==== Load dataset ====
    progress(0.1, desc="Loading rahul7star/Gita dataset...")
    output_log.append("\n📚 Loading dataset from rahul7star/Gita...")
    dataset = load_dataset("rahul7star/Gita", split="train")
    output_log.append(f" Loaded {len(dataset)} samples")
    output_log.append(f" Columns: {dataset.column_names}")

    # ==== Format dataset ====
    def format_example(item):
        text = (
            item.get("text")
            or item.get("content")
            or item.get("verse")
            or " ".join(str(v) for v in item.values())
        )
        prompt = f"""<|system|>
You are a wise teacher interpreting the Bhagavad Gita with deep insights.
<|user|>
{text}
<|assistant|>
"""
        return {"text": prompt}

    dataset = dataset.map(format_example)
    output_log.append(f" ✅ Formatted {len(dataset)} examples")

    # ==== Load tokenizer & model ====
    progress(0.3, desc="Loading model and tokenizer...")
    output_log.append("\n🤖 Loading Qwen model and tokenizer...")
    base_model = model_name or "Qwen/Qwen2.5-0.5B"
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Fix missing pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    # ==== Tokenize dataset ====
    progress(0.4, desc="Tokenizing dataset...")
    output_log.append("\n✏️ Tokenizing dataset...")

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
    output_log.append(f" ✅ Tokenized {len(tokenized_dataset)} samples")

    # ==== Training setup ====
    progress(0.5, desc="Starting training...")
    output_log.append("\n⚙️ Preparing Trainer...")
    output_dir = "./Qwen-Gita-Checkpoints"
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=batch_size,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        fp16=device == "cuda",
        save_steps=100,
        logging_steps=10,
        save_total_limit=1,
    )
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
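    # With mlm=False the collator above prepares causal-LM batches: labels are
    # a copy of input_ids with pad positions set to -100 so padding tokens are
    # ignored by the loss.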
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # ==== Train ====
    output_log.append("\n🚀 Training started ...")
    trainer.train()
    output_log.append("✅ Training complete!")

    # ==== Push to Hugging Face Hub ====
    repo_id = "rahul7star/Qwen0.5-3B-Gita"
    output_log.append(f"\n☁️ Uploading to Hugging Face Hub: {repo_id}")
    token = HfFolder.get_token()
    model.push_to_hub(repo_id, token=token)
    tokenizer.push_to_hub(repo_id, token=token)
    output_log.append(f"✅ Model uploaded successfully to {repo_id}")

    return "\n".join(output_log)

# ---------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------
def create_interface():
    with gr.Blocks(title="🧘 Qwen Gita Trainer") as demo:
        gr.Markdown("""
# 🧘 Fine-tune Qwen 0.5B on Bhagavad Gita
This app downloads `rahul7star/Gita`, trains the model to become a Gita teacher,
and uploads results to `rahul7star/Qwen0.5-3B-Gita`.
""")
        gpu_status = gr.Textbox(value=check_gpu_status(), label="GPU Status", interactive=False)
        model_name = gr.Textbox(value="Qwen/Qwen2.5-0.5B", label="Base Model", visible=False)
        num_epochs = gr.Slider(1, 3, value=1, step=1, label="Epochs")
        batch_size = gr.Slider(1, 4, value=2, step=1, label="Batch Size")
        learning_rate = gr.Number(value=5e-5, label="Learning Rate")
        train_btn = gr.Button("🚀 Start Fine-tuning", variant="primary")
        output = gr.Textbox(label="Training Log", lines=30)
        train_btn.click(
            fn=train_model,
            inputs=[model_name, num_epochs, batch_size, learning_rate],
            outputs=output,
        )
    return demo

demo = create_interface()

if __name__ == "__main__":
    demo.launch()
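
# To run outside Spaces (assumptions: this file is saved as app.py and a Hub
# token is available via `huggingface-cli login` so push_to_hub can authenticate):
#   pip install torch transformers datasets gradio huggingface_hub
#   python app.py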