Spaces:
Running
on
Zero
Running
on
Zero
Update src/mdx.py
Browse files- src/mdx.py +3 -1
src/mdx.py
CHANGED
|
@@ -166,6 +166,8 @@ class MDX:
|
|
| 166 |
waves = np.array(wave_p[:, i:i + self.model.chunk_size])
|
| 167 |
mix_waves.append(waves)
|
| 168 |
|
|
|
|
|
|
|
| 169 |
mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
|
| 170 |
|
| 171 |
return mix_waves, pad, trim
|
|
@@ -240,7 +242,7 @@ def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False,
|
|
| 240 |
|
| 241 |
#device_properties = torch.cuda.get_device_properties(device)
|
| 242 |
print("Device", device)
|
| 243 |
-
vram_gb =
|
| 244 |
m_threads = 1 if vram_gb < 8 else 2
|
| 245 |
|
| 246 |
model_hash = MDX.get_hash(model_path)
|
|
|
|
| 166 |
waves = np.array(wave_p[:, i:i + self.model.chunk_size])
|
| 167 |
mix_waves.append(waves)
|
| 168 |
|
| 169 |
+
print(self.device)
|
| 170 |
+
|
| 171 |
mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
|
| 172 |
|
| 173 |
return mix_waves, pad, trim
|
|
|
|
| 242 |
|
| 243 |
#device_properties = torch.cuda.get_device_properties(device)
|
| 244 |
print("Device", device)
|
| 245 |
+
vram_gb = 12 #device_properties.total_memory / 1024**3
|
| 246 |
m_threads = 1 if vram_gb < 8 else 2
|
| 247 |
|
| 248 |
model_hash = MDX.get_hash(model_path)
|