Spaces:
Build error
Build error
Fix
Browse files- trainer.py +4 -7
trainer.py
CHANGED
|
@@ -67,13 +67,12 @@ class Trainer:
|
|
| 67 |
gradient_accumulation: int,
|
| 68 |
fp16: bool,
|
| 69 |
use_8bit_adam: bool,
|
| 70 |
-
) -> tuple[dict,
|
| 71 |
if not torch.cuda.is_available():
|
| 72 |
raise gr.Error('CUDA is not available.')
|
| 73 |
|
| 74 |
-
out_path = ''
|
| 75 |
if self.is_running:
|
| 76 |
-
return gr.update(value=self.is_running_message),
|
| 77 |
|
| 78 |
if concept_images is None:
|
| 79 |
raise gr.Error('You need to upload images.')
|
|
@@ -116,9 +115,7 @@ class Trainer:
|
|
| 116 |
|
| 117 |
if res.returncode == 0:
|
| 118 |
result_message = 'Training Completed!'
|
| 119 |
-
weight_path = self.output_dir / 'lora_weight.pt'
|
| 120 |
-
if weight_path.exists():
|
| 121 |
-
out_path = weight_path.as_posix()
|
| 122 |
else:
|
| 123 |
result_message = 'Training Failed!'
|
| 124 |
-
|
|
|
|
|
|
| 67 |
gradient_accumulation: int,
|
| 68 |
fp16: bool,
|
| 69 |
use_8bit_adam: bool,
|
| 70 |
+
) -> tuple[dict, list[pathlib.Path]]:
|
| 71 |
if not torch.cuda.is_available():
|
| 72 |
raise gr.Error('CUDA is not available.')
|
| 73 |
|
|
|
|
| 74 |
if self.is_running:
|
| 75 |
+
return gr.update(value=self.is_running_message), []
|
| 76 |
|
| 77 |
if concept_images is None:
|
| 78 |
raise gr.Error('You need to upload images.')
|
|
|
|
| 115 |
|
| 116 |
if res.returncode == 0:
|
| 117 |
result_message = 'Training Completed!'
|
|
|
|
|
|
|
|
|
|
| 118 |
else:
|
| 119 |
result_message = 'Training Failed!'
|
| 120 |
+
weight_paths = sorted(self.output_dir.glob('*.pt'))
|
| 121 |
+
return gr.update(value=result_message), weight_paths
|