rahul7star committed (verified)
Commit 519cd24 · Parent: 5f1ec06

Update app_train.py

Files changed (1):
  1. app_train.py +35 -44
app_train.py CHANGED
@@ -60,7 +60,7 @@ def log_message(output_log, msg):
 
 # ==== Train model ====
 @spaces.GPU(duration=300)
-def train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate, hf_repo):
+def train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate):
     output_log = []
     test_split = 0.2
     mock_question = "Who is referred to as 'O best of Brahmanas' in the Bhagavad Gita?"
@@ -81,9 +81,6 @@ def train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate, hf_repo):
         train_dataset = dataset["train"]
         test_dataset = dataset["test"]
 
-        log_message(output_log, f" Training samples: {len(train_dataset)}")
-        log_message(output_log, f" Test samples: {len(test_dataset)}")
-
         # ===== Format examples =====
         def format_example(item):
             text = item.get("text") or item.get("content") or " ".join(str(v) for v in item.values())
@@ -100,7 +97,6 @@ You are a wise teacher interpreting Bhagavad Gita with deep insights.
         log_message(output_log, f"✅ Formatted {len(train_dataset)} train + {len(test_dataset)} test examples")
 
         # ===== Load model & tokenizer =====
-        log_message(output_log, f"\n🤖 Loading model: {base_model}")
         tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
@@ -108,19 +104,15 @@ You are a wise teacher interpreting Bhagavad Gita with deep insights.
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             trust_remote_code=True,
-            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-            low_cpu_mem_usage=True,
+            torch_dtype=torch.float16 if device=="cuda" else torch.float32,
+            low_cpu_mem_usage=True
         )
         if device == "cuda":
             model = model.to(device)
-        log_message(output_log, "✅ Model and tokenizer loaded successfully")
 
         # ===== LoRA configuration =====
-        log_message(output_log, "\n⚙️ Configuring LoRA for efficient fine-tuning...")
         lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["q_proj","v_proj"], bias="none")
         model = get_peft_model(model, lora_config)
-        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-        log_message(output_log, f"Trainable params after LoRA: {trainable_params:,}")
 
         # ===== Tokenization + labels =====
         def tokenize_fn(examples):
@@ -130,9 +122,8 @@ You are a wise teacher interpreting Bhagavad Gita with deep insights.
 
         train_dataset = train_dataset.map(tokenize_fn, batched=True)
         test_dataset = test_dataset.map(tokenize_fn, batched=True)
-        log_message(output_log, "✅ Tokenization + labels done")
 
-        # ===== Training arguments =====
+        # ===== Training =====
        output_dir = "./qwen-gita-lora"
         training_args = TrainingArguments(
             output_dir=output_dir,
@@ -145,45 +136,35 @@ You are a wise teacher interpreting Bhagavad Gita with deep insights.
             fp16=device=="cuda",
             optim="adamw_torch",
             learning_rate=learning_rate,
-            max_steps=100,
-        )
-
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-            eval_dataset=test_dataset,
-            tokenizer=tokenizer,
+            max_steps=100
         )
 
-        # ===== Train =====
+        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=test_dataset, tokenizer=tokenizer)
         log_message(output_log, "\n🚀 Starting training...")
         trainer.train()
-        log_message(output_log, "\n✅ Training finished!")
+        trainer.save_model(output_dir)
+        tokenizer.save_pretrained(output_dir)
+        log_message(output_log, "\n✅ Training finished and model saved locally.")
 
-        # ===== Test with mock question =====
+        # ===== Mock question response =====
         inputs = tokenizer(f"<|system|>\nYou are a wise teacher interpreting Bhagavad Gita.\n<|user|>\n{mock_question}\n<|assistant|>\n", return_tensors="pt").to(device)
         outputs = model.generate(**inputs, max_new_tokens=100)
         answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        log_message(output_log, f"\n🧪 Mock Question Test:\nQ: {mock_question}\nA: {answer}")
-
-        # ===== Save locally (optional upload later) =====
-        trainer.save_model(output_dir)
-        tokenizer.save_pretrained(output_dir)
+        log_message(output_log, f"\n🧪 Mock Question:\nQ: {mock_question}\nA: {answer}")
 
-        log_message(output_log, "\n✅ Model saved locally. You can now review the mock answer before uploading.")
+        # Return model and tokenizer for interactive questions
+        return "\n".join(output_log), model, tokenizer, output_dir
 
     except Exception as e:
         log_message(output_log, f"\n❌ Error during training: {e}")
-
-    return "\n".join(output_log), output_dir, mock_question
+        return "\n".join(output_log), None, None, None
 
 # ==== Gradio Interface ====
 def create_interface():
     with gr.Blocks(title="PromptWizard — Qwen Trainer") as demo:
         gr.Markdown("""
         # 🧘 PromptWizard Qwen Fine-tuning
-        Fine-tune Qwen on any dataset and optionally upload to HF Hub.
+        Fine-tune Qwen and interact with it before optional upload.
         """)
 
         with gr.Row():
@@ -201,29 +182,39 @@ def create_interface():
             with gr.Column():
                 output = gr.Textbox(label="Training Log", lines=25, max_lines=40,
                                     value="Click 'Start Fine-tuning' to train your model.")
+                user_question = gr.Textbox(label="Ask your own question", placeholder="Type a question...")
+                answer_box = gr.Textbox(label="Answer", lines=5, interactive=False)
 
         # ==== Train button ====
-        def train_click(base_model, dataset_name, num_epochs, batch_size, learning_rate, hf_repo):
-            log, output_dir, mock_question = train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate, hf_repo)
-            return log, True, output_dir
+        def train_click(base_model, dataset_name, num_epochs, batch_size, learning_rate):
+            log, model, tokenizer, output_dir = train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate)
+            return log, True, model, tokenizer, output_dir
 
         train_btn.click(
             fn=train_click,
-            inputs=[base_model, dataset_name, num_epochs, batch_size, learning_rate, hf_repo],
-            outputs=[output, upload_btn, hf_repo],
+            inputs=[base_model, dataset_name, num_epochs, batch_size, learning_rate],
+            outputs=[output, upload_btn, gr.State(), gr.State(), gr.State()],
         )
 
+        # ==== User question ====
+        def ask_question(user_input, model, tokenizer):
+            if not model or not tokenizer:
+                return "Model not loaded yet."
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            inputs = tokenizer(f"<|system|>\nYou are a wise teacher interpreting Bhagavad Gita.\n<|user|>\n{user_input}\n<|assistant|>\n", return_tensors="pt").to(device)
+            outputs = model.generate(**inputs, max_new_tokens=100)
+            answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            return answer
+
+        user_question.submit(ask_question, inputs=[user_question, gr.State(), gr.State()], outputs=[answer_box])
+
         # ==== Upload button ====
         def upload_click(hf_repo):
             output_log = []
             start_async_upload("./qwen-gita-lora", hf_repo, output_log)
             return "\n".join(output_log)
 
-        upload_btn.click(
-            fn=upload_click,
-            inputs=[hf_repo],
-            outputs=output,
-        )
+        upload_btn.click(upload_click, inputs=[hf_repo], outputs=[output])
 
     return demo
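
Review note on the new interactive wiring: each `gr.State()` written inline in the `outputs=` list of `train_btn.click` and the `inputs=` list of `user_question.submit` constructs a new, independent component. The model and tokenizer returned by `train_click` therefore land in states that `ask_question` never reads, and the question box will always answer "Model not loaded yet." A State has to be created once inside the Blocks and passed to every event that reads or writes it. A minimal, self-contained sketch of that pattern (a hypothetical store/recall demo, not this app's actual components):

import gradio as gr

# Shared-State pattern: the State is created once and referenced by
# both events. A fresh gr.State() inside an inputs=/outputs= list is
# a different component and never sees the stored value.

def store(value):
    return value  # written into the shared state

def recall(state_value):
    return state_value if state_value else "State is empty."

with gr.Blocks() as demo:
    shared = gr.State(None)  # created once, shared by both events
    text_in = gr.Textbox(label="Value to store")
    save_btn = gr.Button("Store")
    read_btn = gr.Button("Recall")
    text_out = gr.Textbox(label="Recalled value")

    save_btn.click(store, inputs=[text_in], outputs=[shared])
    read_btn.click(recall, inputs=[shared], outputs=[text_out])

if __name__ == "__main__":
    demo.launch()

Applied here, a single `model_state = gr.State()` and `tok_state = gr.State()` created inside the Blocks would replace the inline `gr.State()` entries in both `train_btn.click(outputs=...)` and `user_question.submit(inputs=...)`.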
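Separately, the hand-assembled `<|system|>`/`<|user|>`/`<|assistant|>` prompt string may not match the chat format a given Qwen checkpoint was actually trained with. Where the tokenizer ships a chat template, `tokenizer.apply_chat_template` renders the model's own format instead. A sketch, with "Qwen/Qwen2.5-0.5B-Instruct" as an assumed example checkpoint rather than this Space's `base_model`:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a wise teacher interpreting Bhagavad Gita."},
    {"role": "user", "content": "Who is referred to as 'O best of Brahmanas' in the Bhagavad Gita?"},
]

# Let the tokenizer render the model's own chat format instead of
# hand-written special-token markers.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)
outputs = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))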
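Finally, a behavioral note on the retained `max_steps=100`: in transformers, a positive `max_steps` overrides `num_train_epochs`, so the UI's `num_epochs` slider currently has no effect on run length. One way to keep the cap for smoke tests only is sketched below; the `quick_test` flag is hypothetical and not part of this commit:

from transformers import TrainingArguments

quick_test = False  # hypothetical flag, not in this commit

training_args = TrainingArguments(
    output_dir="./qwen-gita-lora",
    num_train_epochs=3,                   # e.g. the UI's num_epochs value
    max_steps=100 if quick_test else -1,  # -1 (the default) removes the step cap
)
print(training_args.max_steps)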