Commit
·
a795b9b
1
Parent(s):
6ded867
add log to check whether chunking is working
Browse files
cosmos_transfer1/diffusion/model/model_v2w.py
CHANGED
|
@@ -249,6 +249,7 @@ class DiffusionV2WModel(DiffusionT2WModel):
|
|
| 249 |
assert condition_latent is not None, "condition_latent should be provided"
|
| 250 |
|
| 251 |
# try to add chunking here !!!
|
|
|
|
| 252 |
x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
|
| 253 |
data_batch,
|
| 254 |
guidance,
|
|
@@ -312,6 +313,8 @@ class DiffusionV2WModel(DiffusionT2WModel):
|
|
| 312 |
Function that takes noisy input and noise level and returns denoised prediction
|
| 313 |
"""
|
| 314 |
if chunking is None:
|
|
|
|
|
|
|
| 315 |
if is_negative_prompt:
|
| 316 |
condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
|
| 317 |
else:
|
|
@@ -347,6 +350,8 @@ class DiffusionV2WModel(DiffusionT2WModel):
|
|
| 347 |
|
| 348 |
return x0_fn
|
| 349 |
else:
|
|
|
|
|
|
|
| 350 |
def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
| 351 |
if is_negative_prompt:
|
| 352 |
condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
|
|
|
|
| 249 |
assert condition_latent is not None, "condition_latent should be provided"
|
| 250 |
|
| 251 |
# try to add chunking here !!!
|
| 252 |
+
log.info("x0_fn")
|
| 253 |
x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
|
| 254 |
data_batch,
|
| 255 |
guidance,
|
|
|
|
| 313 |
Function that takes noisy input and noise level and returns denoised prediction
|
| 314 |
"""
|
| 315 |
if chunking is None:
|
| 316 |
+
log.info("no chunking")
|
| 317 |
+
|
| 318 |
if is_negative_prompt:
|
| 319 |
condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
|
| 320 |
else:
|
|
|
|
| 350 |
|
| 351 |
return x0_fn
|
| 352 |
else:
|
| 353 |
+
log.info("chunking !!!")
|
| 354 |
+
|
| 355 |
def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
| 356 |
if is_negative_prompt:
|
| 357 |
condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
|