Update app_train.py

app_train.py  CHANGED  (+35 -44)
@@ -60,7 +60,7 @@ def log_message(output_log, msg):
 
 # ==== Train model ====
 @spaces.GPU(duration=300)
-def train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate
+def train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate):
     output_log = []
     test_split = 0.2
     mock_question = "Who is referred to as 'O best of Brahmanas' in the Bhagavad Gita?"
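The code that produces `dataset["train"]` and `dataset["test"]` (old lines 67-80) falls between hunks and is not part of this diff. A minimal sketch of that step, assuming the standard `datasets` API and the `test_split = 0.2` defined above (the helper name `load_and_split` is hypothetical):

```python
from datasets import load_dataset

def load_and_split(dataset_name: str, test_split: float = 0.2):
    # Load a single split, then carve out a held-out test set;
    # the result indexes as dataset["train"] / dataset["test"].
    raw = load_dataset(dataset_name, split="train")
    return raw.train_test_split(test_size=test_split)
```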
@@ -81,9 +81,6 @@ def train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate,
     train_dataset = dataset["train"]
     test_dataset = dataset["test"]
 
-    log_message(output_log, f" Training samples: {len(train_dataset)}")
-    log_message(output_log, f" Test samples: {len(test_dataset)}")
-
     # ===== Format examples =====
     def format_example(item):
         text = item.get("text") or item.get("content") or " ".join(str(v) for v in item.values())
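The body of `format_example` continues past this hunk, so its full template is not in the diff. Given the system prompt quoted in the hunk headers and the `<|system|>/<|user|>/<|assistant|>` markup used at generation time, a plausible sketch (an assumption, not the file's actual code):

```python
def format_example(item):
    # Fallback chain from the diff: prefer "text", then "content",
    # else join all field values into one string.
    text = item.get("text") or item.get("content") or " ".join(str(v) for v in item.values())
    return {
        "text": "<|system|>\nYou are a wise teacher interpreting Bhagavad Gita with deep insights.\n"
                f"<|user|>\n{text}\n<|assistant|>\n"
    }
```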
@@ -100,7 +97,6 @@ You are a wise teacher interpreting Bhagavad Gita with deep insights.
     log_message(output_log, f"✅ Formatted {len(train_dataset)} train + {len(test_dataset)} test examples")
 
     # ===== Load model & tokenizer =====
-    log_message(output_log, f"\n🤖 Loading model: {base_model}")
     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
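One consequence of the `pad_token = eos_token` fallback above: `model.generate` warns that no pad token is set unless the id is passed explicitly. A small sketch reusing the same names that appear later in this diff:

```python
# Passing pad_token_id silences the "Setting pad_token_id to eos_token_id" warning
# that generate() emits when the tokenizer reuses EOS as its pad token.
outputs = model.generate(**inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
```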
@@ -108,19 +104,15 @@ You are a wise teacher interpreting Bhagavad Gita with deep insights.
     model = AutoModelForCausalLM.from_pretrained(
         base_model,
         trust_remote_code=True,
-        torch_dtype=torch.float16 if device
-        low_cpu_mem_usage=True
+        torch_dtype=torch.float16 if device=="cuda" else torch.float32,
+        low_cpu_mem_usage=True
     )
     if device == "cuda":
         model = model.to(device)
-    log_message(output_log, "✅ Model and tokenizer loaded successfully")
 
     # ===== LoRA configuration =====
-    log_message(output_log, "\n⚙️ Configuring LoRA for efficient fine-tuning...")
     lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["q_proj","v_proj"], bias="none")
     model = get_peft_model(model, lora_config)
-    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    log_message(output_log, f"Trainable params after LoRA: {trainable_params:,}")
 
     # ===== Tokenization + labels =====
     def tokenize_fn(examples):
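The commit drops the hand-rolled trainable-parameter count. peft exposes the same information on any `PeftModel`, so a one-line replacement (assuming `model` is the object returned by `get_peft_model` above) would be:

```python
# Prints e.g. "trainable params: ... || all params: ... || trainable%: ..."
model.print_trainable_parameters()
```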
@@ -130,9 +122,8 @@ You are a wise teacher interpreting Bhagavad Gita with deep insights.
 
     train_dataset = train_dataset.map(tokenize_fn, batched=True)
     test_dataset = test_dataset.map(tokenize_fn, batched=True)
-    log_message(output_log, "✅ Tokenization + labels done")
 
-    # ===== Training
+    # ===== Training =====
     output_dir = "./qwen-gita-lora"
     training_args = TrainingArguments(
         output_dir=output_dir,
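The body of `tokenize_fn` (old lines 127-129) is outside the diff. For causal-LM fine-tuning with `Trainer`, the usual shape is to copy `input_ids` into `labels`; a sketch offered as an assumption, with `max_length=512` illustrative rather than from the file:

```python
def tokenize_fn(examples):
    tokens = tokenizer(examples["text"], truncation=True,
                       padding="max_length", max_length=512)
    # Trainer computes the LM loss from "labels"; for causal LM they are
    # simply a copy of the inputs (the shift happens inside the model).
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens
```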
@@ -145,45 +136,35 @@ You are a wise teacher interpreting Bhagavad Gita with deep insights.
         fp16=device=="cuda",
         optim="adamw_torch",
         learning_rate=learning_rate,
-        max_steps=100
-    )
-
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=train_dataset,
-        eval_dataset=test_dataset,
-        tokenizer=tokenizer,
+        max_steps=100
     )
 
-
+    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=test_dataset, tokenizer=tokenizer)
     log_message(output_log, "\n🚀 Starting training...")
     trainer.train()
-
+    trainer.save_model(output_dir)
+    tokenizer.save_pretrained(output_dir)
+    log_message(output_log, "\n✅ Training finished and model saved locally.")
 
-    # =====
+    # ===== Mock question response =====
     inputs = tokenizer(f"<|system|>\nYou are a wise teacher interpreting Bhagavad Gita.\n<|user|>\n{mock_question}\n<|assistant|>\n", return_tensors="pt").to(device)
     outputs = model.generate(**inputs, max_new_tokens=100)
     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    log_message(output_log, f"\n🧪 Mock Question
-
-    # ===== Save locally (optional upload later) =====
-    trainer.save_model(output_dir)
-    tokenizer.save_pretrained(output_dir)
+    log_message(output_log, f"\n🧪 Mock Question:\nQ: {mock_question}\nA: {answer}")
 
-
+    # Return model and tokenizer for interactive questions
+    return "\n".join(output_log), model, tokenizer, output_dir
 
     except Exception as e:
         log_message(output_log, f"\n❌ Error during training: {e}")
-
-    return "\n".join(output_log), output_dir, mock_question
+        return "\n".join(output_log), None, None, None
 
 # ==== Gradio Interface ====
 def create_interface():
     with gr.Blocks(title="PromptWizard — Qwen Trainer") as demo:
         gr.Markdown("""
         # 🧘 PromptWizard Qwen Fine-tuning
-        Fine-tune Qwen
+        Fine-tune Qwen and interact with it before optional upload.
         """)
 
     with gr.Row():
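A version note on the collapsed `Trainer(...)` call above: newer transformers releases deprecate the `tokenizer=` keyword in favor of `processing_class=`. If this Space ever unpins its transformers dependency, the equivalent construction would be (an assumption about versions, not a change the commit makes):

```python
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    processing_class=tokenizer,  # was tokenizer=tokenizer on older releases
)
```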
@@ -201,29 +182,39 @@ def create_interface():
             with gr.Column():
                 output = gr.Textbox(label="Training Log", lines=25, max_lines=40,
                                     value="Click 'Start Fine-tuning' to train your model.")
+                user_question = gr.Textbox(label="Ask your own question", placeholder="Type a question...")
+                answer_box = gr.Textbox(label="Answer", lines=5, interactive=False)
 
         # ==== Train button ====
-        def train_click(base_model, dataset_name, num_epochs, batch_size, learning_rate
-            log,
-            return log, True, output_dir
+        def train_click(base_model, dataset_name, num_epochs, batch_size, learning_rate):
+            log, model, tokenizer, output_dir = train_model(base_model, dataset_name, num_epochs, batch_size, learning_rate)
+            return log, True, model, tokenizer, output_dir
 
         train_btn.click(
             fn=train_click,
-            inputs=[base_model, dataset_name, num_epochs, batch_size, learning_rate
-            outputs=[output, upload_btn,
+            inputs=[base_model, dataset_name, num_epochs, batch_size, learning_rate],
+            outputs=[output, upload_btn, gr.State(), gr.State(), gr.State()],
         )
 
+        # ==== User question ====
+        def ask_question(user_input, model, tokenizer):
+            if not model or not tokenizer:
+                return "Model not loaded yet."
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            inputs = tokenizer(f"<|system|>\nYou are a wise teacher interpreting Bhagavad Gita.\n<|user|>\n{user_input}\n<|assistant|>\n", return_tensors="pt").to(device)
+            outputs = model.generate(**inputs, max_new_tokens=100)
+            answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            return answer
+
+        user_question.submit(ask_question, inputs=[user_question, gr.State(), gr.State()], outputs=[answer_box])
+
         # ==== Upload button ====
         def upload_click(hf_repo):
             output_log = []
             start_async_upload("./qwen-gita-lora", hf_repo, output_log)
             return "\n".join(output_log)
 
-        upload_btn.click(
-            fn=upload_click,
-            inputs=[hf_repo],
-            outputs=output,
-        )
+        upload_btn.click(upload_click, inputs=[hf_repo], outputs=[output])
 
     return demo
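One wiring caveat in the new interface code: every inline `gr.State()` call creates a new component, so the three states in `train_btn.click(outputs=...)` are not the two in `user_question.submit(inputs=...)`, and `ask_question` will always receive `None`. The usual fix is to create the states once inside the Blocks and reuse them; a sketch with hypothetical names `model_state`/`tokenizer_state`/`dir_state`:

```python
with gr.Blocks(title="PromptWizard — Qwen Trainer") as demo:
    # Shared holders, created once and referenced by both events.
    model_state = gr.State(None)
    tokenizer_state = gr.State(None)
    dir_state = gr.State(None)
    # ... same components as in the diff ...

    train_btn.click(
        fn=train_click,
        inputs=[base_model, dataset_name, num_epochs, batch_size, learning_rate],
        outputs=[output, upload_btn, model_state, tokenizer_state, dir_state],
    )
    user_question.submit(
        ask_question,
        inputs=[user_question, model_state, tokenizer_state],
        outputs=[answer_box],
    )
```

Relatedly, returning `True` into `upload_btn` only sets the button's value; returning `gr.update(interactive=True)` is the usual way to enable it.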