Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -171,25 +171,60 @@ ASPECT_RATIOS = {
|
|
| 171 |
|
| 172 |
def get_vae_cache_for_aspect_ratio(aspect_ratio, device, dtype):
|
| 173 |
"""
|
| 174 |
-
|
| 175 |
-
|
| 176 |
"""
|
| 177 |
ar_config = ASPECT_RATIOS[aspect_ratio]
|
| 178 |
-
latent_h = ar_config["latent_h"]
|
| 179 |
latent_w = ar_config["latent_w"]
|
|
|
|
| 180 |
|
| 181 |
-
#
|
| 182 |
-
# These need to be 5D tensors: (batch, channels, time, height, width)
|
| 183 |
cache = []
|
| 184 |
|
| 185 |
-
#
|
| 186 |
-
cache.append(torch.zeros(
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
return cache
|
| 192 |
|
|
|
|
| 193 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
| 194 |
"""
|
| 195 |
Convert frames directly to .ts file using PyAV.
|
|
|
|
| 171 |
|
| 172 |
def get_vae_cache_for_aspect_ratio(aspect_ratio, device, dtype):
    """
    Build a zero-initialized VAE decoder cache for the given aspect ratio.

    The latent height and width are looked up in ``ASPECT_RATIOS`` and used to
    size four zero tensors, one per feature level of the cache.

    Per the original author's note, each tensor's layout must stay consistent
    with ``ZERO_VAE_CACHE``: [batch, time, channels, height, width].
    NOTE(review): ``ZERO_VAE_CACHE`` is not visible from this function --
    confirm the axis order against its definition.

    Args:
        aspect_ratio: Key into ``ASPECT_RATIOS``; the entry must provide
            ``"latent_h"`` and ``"latent_w"``.
        device: Device passed through to ``torch.zeros``.
        dtype: Dtype passed through to ``torch.zeros``.

    Returns:
        A list of four zero tensors, ordered from the most-downsampled level
        (1/8 resolution) to the full latent resolution.
    """
    ar_config = ASPECT_RATIOS[aspect_ratio]
    width = ar_config["latent_w"]
    height = ar_config["latent_h"]

    def _zero_level(channels, level_h, level_w):
        # Batch and time dimensions are both initialized to 1; ``channels``
        # is the channel count of the corresponding feature level.
        return torch.zeros(
            1,  # batch size
            1,  # time frames
            channels,
            level_h,
            level_w,
            device=device,
            dtype=dtype,
        )

    return [
        # Level 1 features: 512 channels, spatially downsampled by 8.
        _zero_level(512, height // 8, width // 8),
        # Level 2 features: 512 channels, spatially downsampled by 4.
        _zero_level(512, height // 4, width // 4),
        # Level 3 features: 256 channels, spatially downsampled by 2.
        _zero_level(256, height // 2, width // 2),
        # Level 4 features: 128 channels, no downsampling.
        _zero_level(128, height, width),
    ]
|
| 226 |
|
| 227 |
+
|
| 228 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
| 229 |
"""
|
| 230 |
Convert frames directly to .ts file using PyAV.
|