Spaces:
Running
on
T4
Running
on
T4
Update Modules/ControllabilityGAN/wgan/wgan_qc.py
Browse files
Modules/ControllabilityGAN/wgan/wgan_qc.py
CHANGED
|
@@ -237,9 +237,9 @@ class WassersteinGanQuadraticCost:
|
|
| 237 |
def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
|
| 238 |
self.G.eval()
|
| 239 |
if isinstance(self.G, torch.nn.parallel.DataParallel):
|
| 240 |
-
latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim)
|
| 241 |
else:
|
| 242 |
-
latent_samples = self.G.sample_latent(num_samples, self.G.z_dim)
|
| 243 |
latent_samples = latent_samples.to(self.device)
|
| 244 |
if nograd:
|
| 245 |
with torch.no_grad():
|
|
|
|
| 237 |
def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
|
| 238 |
self.G.eval()
|
| 239 |
if isinstance(self.G, torch.nn.parallel.DataParallel):
|
| 240 |
+
latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim, 1.0)
|
| 241 |
else:
|
| 242 |
+
latent_samples = self.G.sample_latent(num_samples, self.G.z_dim, 1.0)
|
| 243 |
latent_samples = latent_samples.to(self.device)
|
| 244 |
if nograd:
|
| 245 |
with torch.no_grad():
|