evalstate (HF Staff) committed
Commit bb57d5b · verified · 1 parent: d296218

Upload train_demo.py with huggingface_hub

Files changed (1): train_demo.py (+23, -17)
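For context, a script like this is usually pushed with the `upload_file` helper from `huggingface_hub`, which is what produces the default "Upload train_demo.py with huggingface_hub" commit message. A minimal sketch follows; the target `repo_id` and `repo_type` are placeholders, not taken from this commit:

from huggingface_hub import HfApi

api = HfApi()  # authenticates with your configured Hugging Face token
api.upload_file(
    path_or_fileobj="train_demo.py",      # local file to upload
    path_in_repo="train_demo.py",         # destination path inside the repo
    repo_id="evalstate/<target-repo>",    # placeholder: the actual target repo is not shown on this page
    repo_type="model",                    # assumption; could also be "dataset" or "space"
    commit_message="Upload train_demo.py with huggingface_hub",
)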
train_demo.py CHANGED
@@ -16,57 +16,60 @@ from trl import SFTTrainer, SFTConfig
 # Initialize Trackio for real-time monitoring
 trackio.init(
     project="qwen-demo-sft",
-    space_id="evalstate/training-demo-dashboard",
+    space_id="evalstate/trackio-demo",  # Will auto-create if it doesn't exist
     config={
         "model": "Qwen/Qwen2.5-0.5B",
         "dataset": "trl-lib/Capybara",
+        "dataset_size": 50,
         "learning_rate": 2e-5,
         "max_steps": 20,
-        "peft_method": "LoRA",
+        "demo": True,
     }
 )
 
-# Load small subset for quick demo
+# Load dataset (only 50 examples for quick demo)
 dataset = load_dataset("trl-lib/Capybara", split="train[:50]")
 print(f"✅ Dataset loaded: {len(dataset)} examples")
+print(f"📝 Sample: {dataset[0]}")
 
 # Training configuration
 config = SFTConfig(
-    # CRITICAL: Hub settings
+    # Hub settings - CRITICAL for saving results
     output_dir="qwen-demo-sft",
     push_to_hub=True,
     hub_model_id="evalstate/qwen-demo-sft",
+    hub_strategy="end",  # Push only at end for demo
 
-    # Quick demo settings
-    max_steps=20,
+    # Training parameters (minimal for quick demo)
+    max_steps=20,  # Very short training
     per_device_train_batch_size=2,
     gradient_accumulation_steps=2,
     learning_rate=2e-5,
 
     # Logging
     logging_steps=5,
-    save_strategy="steps",
-    save_steps=20,
+    save_strategy="no",  # Don't save checkpoints during training
 
     # Optimization
-    warmup_steps=2,
+    warmup_steps=5,
     lr_scheduler_type="cosine",
 
     # Monitoring
     report_to="trackio",
 )
 
-# LoRA configuration
+# LoRA configuration (reduces memory usage)
 peft_config = LoraConfig(
-    r=16,
-    lora_alpha=32,
+    r=8,  # Small rank for demo
+    lora_alpha=16,
     lora_dropout=0.05,
     bias="none",
     task_type="CAUSAL_LM",
     target_modules=["q_proj", "v_proj"],
 )
 
-# Initialize and train
+# Initialize trainer
+print("🚀 Initializing trainer...")
 trainer = SFTTrainer(
     model="Qwen/Qwen2.5-0.5B",
     train_dataset=dataset,
@@ -74,14 +77,17 @@ trainer = SFTTrainer(
     peft_config=peft_config,
 )
 
-print("🚀 Starting training...")
+# Train
+print("🔥 Starting training (20 steps)...")
 trainer.train()
 
-print("💾 Pushing to Hub...")
+# Push to Hub
+print("💾 Pushing model to Hub...")
 trainer.push_to_hub()
 
 # Finish Trackio tracking
 trackio.finish()
 
-print("✅ Complete! Model at: https://huggingface.co/evalstate/qwen-demo-sft")
-print("📊 View metrics at: https://huggingface.co/spaces/evalstate/training-demo-dashboard")
+print("✅ Training complete!")
+print(f"📦 Model: https://huggingface.co/evalstate/qwen-demo-sft")
+print(f"📊 Metrics: https://huggingface.co/spaces/evalstate/trackio-demo")