Spaces:
Paused
Paused
Update hymm_sp/modules/models_audio.py
Browse files
hymm_sp/modules/models_audio.py
CHANGED
|
@@ -166,39 +166,7 @@ class DoubleStreamBlock(nn.Module):
|
|
| 166 |
v = torch.cat((img_v, txt_v), dim=1)
|
| 167 |
|
| 168 |
# Compute attention.
|
| 169 |
-
|
| 170 |
-
assert cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
|
| 171 |
-
|
| 172 |
-
q, k, v = [
|
| 173 |
-
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
|
| 174 |
-
for x in [q, k, v]
|
| 175 |
-
]
|
| 176 |
-
attn = None
|
| 177 |
-
|
| 178 |
-
# attn = flash_attn_varlen_func(
|
| 179 |
-
# q,
|
| 180 |
-
# k,
|
| 181 |
-
# v,
|
| 182 |
-
# cu_seqlens_q,
|
| 183 |
-
# cu_seqlens_kv,
|
| 184 |
-
# max_seqlen_q,
|
| 185 |
-
# max_seqlen_kv,
|
| 186 |
-
# )
|
| 187 |
-
attn = attn.view(img_k.shape[0], max_seqlen_q, -1).contiguous()
|
| 188 |
-
else:
|
| 189 |
-
# attn, _ = parallel_attention(
|
| 190 |
-
# (img_q, txt_q),
|
| 191 |
-
# (img_k, txt_k),
|
| 192 |
-
# (img_v, txt_v),
|
| 193 |
-
# img_q_len=img_q.shape[1],
|
| 194 |
-
# img_kv_len=img_k.shape[1],
|
| 195 |
-
# cu_seqlens_q=cu_seqlens_q,
|
| 196 |
-
# cu_seqlens_kv=cu_seqlens_kv,
|
| 197 |
-
# max_seqlen_q=max_seqlen_q,
|
| 198 |
-
# max_seqlen_kv=max_seqlen_kv,
|
| 199 |
-
# )
|
| 200 |
-
img_attn, txt_attn = attn[:, :img.shape[1]], attn[:, img.shape[1]:]
|
| 201 |
-
|
| 202 |
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
| 203 |
|
| 204 |
# Calculate the img bloks.
|
|
|
|
| 166 |
v = torch.cat((img_v, txt_v), dim=1)
|
| 167 |
|
| 168 |
# Compute attention.
|
| 169 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
| 171 |
|
| 172 |
# Calculate the img bloks.
|