Spaces:
Build error
Build error
| import os | |
| from typing import Optional, Literal | |
| from types import ModuleType | |
| import enum | |
| from packaging import version | |
| import torch | |
# Collect system information: which attention backends can be used here.

# SDP mode is gated on the installed torch version being 2.0.0 or newer.
SDP_IS_AVAILABLE = version.parse(torch.__version__) >= version.parse("2.0.0")

# xformers is optional; any failure while importing it means "not available".
# ``except Exception`` (rather than a bare ``except``) still covers a missing
# or broken install, but no longer swallows KeyboardInterrupt / SystemExit.
# NOTE: the flag keeps its original spelling ("AVAILBLE") because the rest of
# this module refers to it by that exact name.
try:
    import xformers
    import xformers.ops
except Exception:
    XFORMERS_IS_AVAILBLE = False
else:
    XFORMERS_IS_AVAILBLE = True
@enum.unique
class AttnMode(enum.Enum):
    """Identifiers for the attention implementations this module can select."""

    SDP = 0
    XFORMERS = 1
    VANILLA = 2
class Config:
    """Module-wide attention settings, held as class attributes."""

    # Currently selected attention implementation; VANILLA until reassigned.
    attn_mode: AttnMode = AttnMode.VANILLA
    # Handle to the imported xformers module, or None while it is unset.
    xformers: Optional[ModuleType] = None
# Initialize the default attention mode: xformers first, then SDP, otherwise
# stay on vanilla attention.
if XFORMERS_IS_AVAILBLE:
    # Keep a handle to the imported module next to the selected mode.
    Config.xformers = xformers
    Config.attn_mode = AttnMode.XFORMERS
    print("use xformers attention as default")
elif SDP_IS_AVAILABLE:
    Config.attn_mode = AttnMode.SDP
    print("use sdp attention as default")
else:
    # Config.attn_mode is not reassigned in this branch.
    print(
        "both sdp attention and xformers are not available, "
        "use vanilla attention (very expensive) as default"
    )
# User-specified attention mode: when the ATTN_MODE environment variable is
# set, it replaces whatever Config.attn_mode currently holds.
ATTN_MODE = os.environ.get("ATTN_MODE", None)
if ATTN_MODE is None:
    print("keep default attention mode")
else:
    # Validate with explicit exceptions instead of ``assert``: assertions are
    # stripped under ``python -O``, which would let an unknown value silently
    # fall through to vanilla attention or select a backend that is missing.
    if ATTN_MODE not in ("vanilla", "sdp", "xformers"):
        raise ValueError(
            f"invalid ATTN_MODE {ATTN_MODE!r}; "
            "expected one of 'vanilla', 'sdp', 'xformers'"
        )
    if ATTN_MODE == "sdp":
        if not SDP_IS_AVAILABLE:
            raise RuntimeError(
                "ATTN_MODE=sdp was requested, but sdp attention is not available"
            )
        Config.attn_mode = AttnMode.SDP
    elif ATTN_MODE == "xformers":
        if not XFORMERS_IS_AVAILBLE:
            raise RuntimeError(
                "ATTN_MODE=xformers was requested, but xformers is not available"
            )
        Config.attn_mode = AttnMode.XFORMERS
    else:
        Config.attn_mode = AttnMode.VANILLA
    print(f"set attention mode to {ATTN_MODE}")