Commit
·
a33f468
1
Parent(s):
6f44efe
Revise to be consistent with the paper
Browse files- viz/renderer.py +3 -3
viz/renderer.py
CHANGED
|
@@ -379,14 +379,14 @@ class Renderer:
|
|
| 379 |
distance < round(r1 / 512 * h))
|
| 380 |
direction = direction / \
|
| 381 |
(torch.linalg.norm(direction) + 1e-7)
|
| 382 |
-
gridh = (relis
|
| 383 |
-
gridw = (reljs
|
| 384 |
grid = torch.stack(
|
| 385 |
[gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
|
| 386 |
target = F.grid_sample(
|
| 387 |
feat_resize.float(), grid, align_corners=True).squeeze(2)
|
| 388 |
loss_motion += F.l1_loss(
|
| 389 |
-
feat_resize[:, :, relis, reljs]
|
| 390 |
|
| 391 |
loss = loss_motion
|
| 392 |
if mask is not None:
|
|
|
|
| 379 |
distance < round(r1 / 512 * h))
|
| 380 |
direction = direction / \
|
| 381 |
(torch.linalg.norm(direction) + 1e-7)
|
| 382 |
+
gridh = (relis+direction[1]) / (h-1) * 2 - 1
|
| 383 |
+
gridw = (reljs+direction[0]) / (w-1) * 2 - 1
|
| 384 |
grid = torch.stack(
|
| 385 |
[gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
|
| 386 |
target = F.grid_sample(
|
| 387 |
feat_resize.float(), grid, align_corners=True).squeeze(2)
|
| 388 |
loss_motion += F.l1_loss(
|
| 389 |
+
feat_resize[:, :, relis, reljs].detach(), target)
|
| 390 |
|
| 391 |
loss = loss_motion
|
| 392 |
if mask is not None:
|