rahul7star committed · verified
Commit 829e77a · 1 Parent(s): 817ccbc

Create app_train.py

Files changed (1)
app_train.py +232 -0
app_train.py ADDED
@@ -0,0 +1,232 @@
"""
PromptWizard Qwen Training - Configurable Dataset & Repo

Fine-tunes Qwen using a user-selected dataset and optionally uploads
the trained model to a Hugging Face Hub repo asynchronously with logs.
"""

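# NOTE (added comment): this script assumes gradio, spaces, torch, transformers,
# datasets, peft, and huggingface_hub are installed, typically via the Space's
# requirements.txt.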
import asyncio
import threading
from datetime import datetime

import gradio as gr
import spaces
import torch
from datasets import load_dataset
from huggingface_hub import HfFolder, upload_folder
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# ==== Async upload wrapper ====
def start_async_upload(local_dir, hf_repo, output_log):
    """Starts the async model upload in a background thread."""
    def runner():
        output_log.append(f"[INFO] 🚀 Async upload thread started for repo: {hf_repo}")
        asyncio.run(async_upload_model(local_dir, hf_repo, output_log))
        output_log.append(f"[INFO] 🛑 Async upload thread finished for repo: {hf_repo}")
    threading.Thread(target=runner, daemon=True).start()


async def async_upload_model(local_dir, hf_repo, output_log, max_retries=3):
    """Uploads the model folder to the HF Hub via the HTTP API, retrying on failure."""
    try:
        token = HfFolder.get_token()
        output_log.append(f"[INFO] ☁️ Preparing to upload to repo: {hf_repo}")
        attempt = 0
        while attempt < max_retries:
            try:
                output_log.append(f"[INFO] 🔄 Attempt {attempt + 1} to upload folder via HTTP API...")
                upload_folder(
                    folder_path=local_dir,
                    repo_id=hf_repo,
                    repo_type="model",
                    token=token,
                    ignore_patterns=["*.lock", "*.tmp"],
                    create_pr=False,
                )
                output_log.append("[SUCCESS] ✅ Model successfully uploaded to HF Hub!")
                break
            except Exception as e:
                attempt += 1
                output_log.append(f"[ERROR] Upload attempt {attempt} failed: {e}")
                if attempt < max_retries:
                    output_log.append("[INFO] Retrying in 5 seconds...")
                    await asyncio.sleep(5)
                else:
                    output_log.append("[ERROR] ❌ Max retries reached. Upload failed.")
    except Exception as e:
        output_log.append(f"[ERROR] ❌ Unexpected error during upload: {e}")

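# NOTE (added comment): because the upload runs on a daemon thread, entries
# appended to output_log after upload_click() returns never reach the UI
# textbox; they are only visible in the server logs. Streaming them back to
# Gradio would require a generator callback or a polling component.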

# ==== GPU check ====
def check_gpu_status():
    return "🚀 Zero GPU Ready - GPU will be allocated when training starts"


# ==== Logging helper ====
def log_message(output_log, msg):
    line = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
    print(line)
    output_log.append(line)

# ==== Train model ====
@spaces.GPU(duration=300)
def train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate, hf_repo):
    output_log = []
    test_split = 0.2
    mock_question = "Who is referred to as 'O best of Brahmanas' in the Bhagavad Gita?"
    # Defined up front so the final `return` works even if training fails early.
    output_dir = "./qwen-gita-lora"

    try:
        log_message(output_log, "🔍 Initializing training sequence...")

        # ===== Device =====
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log_message(output_log, f"🎮 Using device: {device}")
        if device == "cuda":
            log_message(output_log, f"✅ GPU: {torch.cuda.get_device_name(0)}")

        # ===== Load dataset =====
        log_message(output_log, f"\n📚 Loading dataset: {dataset_name} ...")
        dataset = load_dataset(dataset_name)
        dataset = dataset["train"].train_test_split(test_size=test_split)
        train_dataset = dataset["train"]
        test_dataset = dataset["test"]

        log_message(output_log, f" Training samples: {len(train_dataset)}")
        log_message(output_log, f" Test samples: {len(test_dataset)}")

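        # NOTE (added comment): this assumes the dataset exposes a single "train"
        # split; train_test_split() then holds out 20% of it for evaluation.
        # Datasets that ship their own validation/test splits would skip this.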
        # ===== Format examples =====
        def format_example(item):
            text = item.get("text") or item.get("content") or " ".join(str(v) for v in item.values())
            prompt = f"""<|system|>
You are a wise teacher interpreting Bhagavad Gita with deep insights.
<|user|>
{text}
<|assistant|>
"""
            return {"text": prompt}

        train_dataset = train_dataset.map(format_example)
        test_dataset = test_dataset.map(format_example)
        log_message(output_log, f"✅ Formatted {len(train_dataset)} train + {len(test_dataset)} test examples")

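        # NOTE (added comment): the <|system|>/<|user|>/<|assistant|> tags above
        # are a hand-rolled template. Qwen chat checkpoints normally build prompts
        # with tokenizer.apply_chat_template(); the plain-text tags used here are
        # an approximation the model adapts to during fine-tuning.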
        # ===== Load model & tokenizer =====
        log_message(output_log, f"\n🤖 Loading model: {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            trust_remote_code=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
        )
        if device == "cuda":
            model = model.to(device)
        log_message(output_log, "✅ Model and tokenizer loaded successfully")

        # ===== LoRA configuration =====
        log_message(output_log, "\n⚙️ Configuring LoRA for efficient fine-tuning...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"],
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        log_message(output_log, f"Trainable params after LoRA: {trainable_params:,}")

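        # NOTE (added comment): r=8 with lora_alpha=16 scales the low-rank update
        # by alpha/r = 2, and restricting target_modules to the q_proj/v_proj
        # attention projections keeps the trainable share to a small fraction of
        # the base model's parameters.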
        # ===== Tokenization + labels =====
        def tokenize_fn(examples):
            tokenized = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=256)
            # Causal LM: labels mirror the input ids; the model shifts them internally.
            tokenized["labels"] = tokenized["input_ids"].copy()
            return tokenized

        train_dataset = train_dataset.map(tokenize_fn, batched=True)
        test_dataset = test_dataset.map(tokenize_fn, batched=True)
        log_message(output_log, "✅ Tokenization + labels done")

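        # NOTE (added comment): copying input_ids into labels also computes loss
        # on pad tokens. Masking them with -100, or using a collator such as
        # transformers' DataCollatorForLanguageModeling(tokenizer, mlm=False),
        # would exclude padding from the loss; kept as-is to preserve behavior.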
        # ===== Training arguments =====
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=2,
            warmup_steps=10,
            logging_steps=5,
            save_strategy="epoch",
            fp16=(device == "cuda"),
            optim="adamw_torch",
            learning_rate=learning_rate,
            max_steps=100,
        )

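        # NOTE (added comment): when max_steps is set (> 0) it overrides
        # num_train_epochs, so training stops after 100 optimizer steps no
        # matter what the Epochs slider requests.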
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            tokenizer=tokenizer,
        )

        # ===== Train =====
        log_message(output_log, "\n🚀 Starting training...")
        trainer.train()
        log_message(output_log, "\n✅ Training finished!")

        # ===== Test with mock question =====
        model.eval()
        inputs = tokenizer(
            f"<|system|>\nYou are a wise teacher interpreting Bhagavad Gita.\n<|user|>\n{mock_question}\n<|assistant|>\n",
            return_tensors="pt",
        ).to(device)
        outputs = model.generate(**inputs, max_new_tokens=100)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        log_message(output_log, f"\n🧪 Mock Question Test:\nQ: {mock_question}\nA: {answer}")

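        # NOTE (added comment): generate() uses the model's default generation
        # settings; wrapping the call in torch.no_grad() and setting do_sample /
        # temperature explicitly are reasonable refinements omitted here.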
        # ===== Save locally (optional upload later) =====
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        log_message(output_log, "\n✅ Model saved locally. You can now review the mock answer before uploading.")

    except Exception as e:
        log_message(output_log, f"\n❌ Error during training: {e}")

    return "\n".join(output_log), output_dir, mock_question

# ==== Gradio Interface ====
def create_interface():
    with gr.Blocks(title="PromptWizard - Qwen Trainer") as demo:
        gr.Markdown("""
        # 🧘 PromptWizard Qwen Fine-tuning
        Fine-tune Qwen on any dataset and optionally upload to HF Hub.
        """)

        with gr.Row():
            with gr.Column():
                gr.Textbox(label="GPU Status", value=check_gpu_status(), interactive=False)
                base_model = gr.Textbox(label="Base Model", value="Qwen/Qwen2.5-0.5B")
                dataset_name = gr.Textbox(label="Dataset Name", value="rahul7star/Gita")
                hf_repo = gr.Textbox(label="HF Repo for Upload", value="rahul7star/Qwen0.5-3B-Gita")
                num_epochs = gr.Slider(1, 3, value=1, step=1, label="Epochs")
                batch_size = gr.Slider(1, 4, value=2, step=1, label="Batch Size")
                learning_rate = gr.Number(value=5e-5, label="Learning Rate")
                train_btn = gr.Button("🚀 Start Fine-tuning", variant="primary")
                upload_btn = gr.Button("☁️ Upload Model to HF Hub", variant="secondary", interactive=False)

            with gr.Column():
                output = gr.Textbox(label="Training Log", lines=25, max_lines=40,
                                    value="Click 'Start Fine-tuning' to train your model.")

        # ==== Train button ====
        def train_click(base_model, dataset_name, num_epochs, batch_size, learning_rate, hf_repo):
            log, output_dir, mock_question = train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate, hf_repo)
            # Enable the upload button via gr.update; a bare True would be taken
            # as the Button's value. Return hf_repo unchanged so the repo textbox
            # is not overwritten with the local output path.
            return log, gr.update(interactive=True), hf_repo

        train_btn.click(
            fn=train_click,
            inputs=[base_model, dataset_name, num_epochs, batch_size, learning_rate, hf_repo],
            outputs=[output, upload_btn, hf_repo],
        )

        # ==== Upload button ====
        def upload_click(hf_repo):
            # Seed the log synchronously; the background thread appends later.
            output_log = [f"[INFO] ☁️ Background upload started for repo: {hf_repo}"]
            start_async_upload("./qwen-gita-lora", hf_repo, output_log)
            return "\n".join(output_log)

        upload_btn.click(
            fn=upload_click,
            inputs=[hf_repo],
            outputs=output,
        )

    return demo

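# NOTE (added comment): 0.0.0.0:7860 is the host/port a Hugging Face Space
# expects a Gradio app to serve on by default.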
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)