import functools

import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available

from .lora import LoraInjectedLinear

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
# Cache results so each head size is only probed on the GPU once.
@functools.lru_cache(maxsize=None)
def test_xformers_backwards(size):
    # Probe whether xformers' memory-efficient attention supports a backward
    # pass for this head dimension on the current GPU.
    def _grad(size):
        q = torch.randn((1, 4, size), device="cuda")
        k = torch.randn((1, 4, size), device="cuda")
        v = torch.randn((1, 4, size), device="cuda")

        q = q.detach().requires_grad_()
        k = k.detach().requires_grad_()
        v = v.detach().requires_grad_()

        out = xformers.ops.memory_efficient_attention(q, k, v)
        loss = out.sum(2).mean(0).sum()

        return torch.autograd.grad(loss, v)

    try:
        _grad(size)
        print(size, "pass")
        return True
    except Exception:
        print(size, "fail")
        return False
def set_use_memory_efficient_attention_xformers(
    module: torch.nn.Module, valid: bool
) -> None:
    # Recursively toggle xformers attention on `module`, then verify that every
    # BasicTransformerBlock's head dimension actually supports the xformers
    # backward pass, falling back to standard attention where it does not.
    def fn_test_dim_head(module: torch.nn.Module):
        if isinstance(module, BasicTransformerBlock):
            # dim_head isn't stored anywhere, so back-calculate it from the
            # value projection and the number of heads.
            source = module.attn1.to_v
            if isinstance(source, LoraInjectedLinear):
                source = source.linear

            dim_head = source.out_features // module.attn1.heads

            result = test_xformers_backwards(dim_head)

            # If this head size fails the backward test, turn xformers off
            # for this block.
            if not result:
                module.set_use_memory_efficient_attention_xformers(False)

        for child in module.children():
            fn_test_dim_head(child)

    if not is_xformers_available() and valid:
        print("XFormers is not available. Skipping.")
        return

    module.set_use_memory_efficient_attention_xformers(valid)

    if valid:
        fn_test_dim_head(module)
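

# --- Usage sketch (illustrative, not part of the original helpers above) ---
# How a calling script might use set_use_memory_efficient_attention_xformers on a
# diffusers pipeline. The checkpoint id is an assumption; any Stable Diffusion
# model whose UNet contains BasicTransformerBlock modules works the same way.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # assumed checkpoint; swap in your own
        torch_dtype=torch.float16,
    ).to("cuda")

    # Enable xformers attention everywhere, then let the dim_head probe switch
    # blocks back to standard attention where the backward kernel is unsupported.
    set_use_memory_efficient_attention_xformers(pipe.unet, valid=True)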