Correct the forward call
Browse files
model.py
CHANGED
|
@@ -13,7 +13,7 @@ class CSDModel(PreTrainedModel):
|
|
| 13 |
|
| 14 |
@torch.inference_mode()
|
| 15 |
def forward(self, pixel_values):
|
| 16 |
-
features = self.backbone(pixel_values)
|
| 17 |
style_embeds = self.out_style(features)
|
| 18 |
content_embeds = self.out_content(features)
|
| 19 |
return features, style_embeds, content_embeds
|
|
|
|
| 13 |
|
| 14 |
@torch.inference_mode()
|
| 15 |
def forward(self, pixel_values):
|
| 16 |
+
features = self.backbone(pixel_values, return_dict=False)[1]
|
| 17 |
style_embeds = self.out_style(features)
|
| 18 |
content_embeds = self.out_content(features)
|
| 19 |
return features, style_embeds, content_embeds
|