rahul7star commited on
Commit
ff77368
·
verified ·
1 Parent(s): 4ba54ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py CHANGED
@@ -18,10 +18,42 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
  def check_gpu_status():
20
  return f"✅ GPU: {torch.cuda.get_device_name(0)}" if device == "cuda" else "⚠️ Using CPU only"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # ---------------------------------------------------------------------
23
  # Training Logic
24
  # ---------------------------------------------------------------------
 
25
  def train_model(model_name, num_epochs, batch_size, learning_rate, progress=gr.Progress(track_tqdm=True)):
26
  output_log = []
27
 
 
18
 
19
  def check_gpu_status():
20
  return f"✅ GPU: {torch.cuda.get_device_name(0)}" if device == "cuda" else "⚠️ Using CPU only"
21
+ # ------------------------------------------------------
22
+ # 🧩 Download Dataset to /tmp/
23
+ # ------------------------------------------------------
24
+ def download_gita_dataset():
25
+ repo_id = "rahul7star/Gita"
26
+ local_dir = "/tmp/gita_data"
27
+
28
+ if os.path.exists(local_dir):
29
+ shutil.rmtree(local_dir)
30
+ os.makedirs(local_dir, exist_ok=True)
31
+
32
+ print(f"📥 Downloading dataset from {repo_id} ...")
33
+ snapshot_download(repo_id=repo_id, local_dir=local_dir, repo_type="dataset")
34
+
35
+ # Try to locate the CSV file
36
+ csv_path = None
37
+ for root, _, files in os.walk(local_dir):
38
+ for f in files:
39
+ if f.lower().endswith(".csv"):
40
+ csv_path = os.path.join(root, f)
41
+ break
42
+ if not csv_path:
43
+ raise FileNotFoundError("No CSV file found in the Gita dataset repository.")
44
+
45
+ print(f"✅ Found CSV: {csv_path}")
46
+ return csv_path
47
+
48
+
49
+ # ------------------------------------------------------
50
+ # 🚀 Training function
51
+ # ------------------------------------------------------
52
 
53
  # ---------------------------------------------------------------------
54
  # Training Logic
55
  # ---------------------------------------------------------------------
56
+ @spaces.GPU(duration=300)
57
  def train_model(model_name, num_epochs, batch_size, learning_rate, progress=gr.Progress(track_tqdm=True)):
58
  output_log = []
59