Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -66,7 +66,10 @@ is_cuda = torch.cuda.is_available()
|
|
| 66 |
device = torch.device("cuda" if is_cuda else "cpu")
|
| 67 |
print(device)
|
| 68 |
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False, download_root='./') # Must set jit=False for training
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
clip_model.eval()
|
| 71 |
for p in clip_model.parameters():
|
| 72 |
p.requires_grad = False
|
|
@@ -101,15 +104,16 @@ print('loading transformer checkpoint from {}'.format(args.resume_trans))
|
|
| 101 |
ckpt = torch.load(args.resume_trans, map_location='cpu')
|
| 102 |
trans_encoder.load_state_dict(ckpt['trans'], strict=True)
|
| 103 |
trans_encoder.eval()
|
|
|
|
| 104 |
mean = torch.from_numpy(np.load('./checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta/mean.npy'))
|
| 105 |
std = torch.from_numpy(np.load('./checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta/std.npy'))
|
|
|
|
| 106 |
if is_cuda:
|
| 107 |
net.cuda()
|
| 108 |
trans_encoder.cuda()
|
| 109 |
mean = mean.cuda()
|
| 110 |
std = std.cuda()
|
| 111 |
|
| 112 |
-
|
| 113 |
def render(motions, device_id=0, name='test_vis'):
|
| 114 |
frames, njoints, nfeats = motions.shape
|
| 115 |
MINS = motions.min(axis=0).min(axis=0)
|
|
|
|
| 66 |
device = torch.device("cuda" if is_cuda else "cpu")
|
| 67 |
print(device)
|
| 68 |
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False, download_root='./') # Must set jit=False for training
|
| 69 |
+
|
| 70 |
+
if is_cuda:
|
| 71 |
+
clip.model.convert_weights(clip_model)
|
| 72 |
+
|
| 73 |
clip_model.eval()
|
| 74 |
for p in clip_model.parameters():
|
| 75 |
p.requires_grad = False
|
|
|
|
| 104 |
ckpt = torch.load(args.resume_trans, map_location='cpu')
|
| 105 |
trans_encoder.load_state_dict(ckpt['trans'], strict=True)
|
| 106 |
trans_encoder.eval()
|
| 107 |
+
|
| 108 |
mean = torch.from_numpy(np.load('./checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta/mean.npy'))
|
| 109 |
std = torch.from_numpy(np.load('./checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta/std.npy'))
|
| 110 |
+
|
| 111 |
if is_cuda:
|
| 112 |
net.cuda()
|
| 113 |
trans_encoder.cuda()
|
| 114 |
mean = mean.cuda()
|
| 115 |
std = std.cuda()
|
| 116 |
|
|
|
|
| 117 |
def render(motions, device_id=0, name='test_vis'):
|
| 118 |
frames, njoints, nfeats = motions.shape
|
| 119 |
MINS = motions.min(axis=0).min(axis=0)
|