Spaces:
Runtime error
Runtime error
| import gc | |
| import inspect | |
| import threading | |
| from functools import wraps | |
| import torch | |
| from ia_check_versions import ia_check_versions | |
# Module-wide semaphore with a single permit: at most one holder at a time.
# async_post_clear_cache() hands it to post_clear_cache(), which holds it for
# the duration of the cleanup. Presumably model-using code elsewhere acquires
# it too, so that cleanup never overlaps with model access -- confirm callers.
model_access_sem = threading.Semaphore(value=1)
def torch_gc():
    """Ask PyTorch to release cached accelerator memory.

    When CUDA is available, the CUDA allocator cache is emptied and CUDA IPC
    memory is collected. When ``ia_check_versions.torch_mps_is_available`` is
    truthy and this torch build exposes ``torch.mps.empty_cache``, the MPS
    cache is emptied as well. Does nothing on builds with neither backend.
    """
    cuda = torch.cuda
    if cuda.is_available():
        cuda.empty_cache()
        cuda.ipc_collect()

    # The hasattr guards keep this safe on torch builds that lack the
    # ``torch.mps`` module or its ``empty_cache`` function.
    mps_cache_supported = (
        ia_check_versions.torch_mps_is_available
        and hasattr(torch, "mps")
        and hasattr(torch.mps, "empty_cache")
    )
    if mps_cache_supported:
        torch.mps.empty_cache()
def clear_cache():
    """Run a Python garbage collection pass, then release torch caches.

    The order matters: collecting Python garbage first drops unreachable
    objects so that the subsequent ``torch_gc()`` call has the chance to
    release whatever memory they were holding.
    """
    gc.collect()  # Python-level garbage first
    torch_gc()    # then the accelerator caches
def post_clear_cache(sem):
    """Clear caches while holding ``sem``.

    Args:
        sem: A context-manager lock object (``async_post_clear_cache`` passes
            the module-level ``model_access_sem``). It is acquired before the
            cleanup starts and released when it finishes, even on error.
    """
    with sem:
        # clear_cache() performs the same gc.collect() + torch_gc() sequence.
        clear_cache()
def async_post_clear_cache():
    """Start a background thread that clears caches under ``model_access_sem``.

    Returns immediately; the spawned thread runs ``post_clear_cache`` and is
    neither joined nor returned to the caller.
    """
    threading.Thread(
        target=post_clear_cache,
        args=(model_access_sem,),
    ).start()
def clear_cache_decorator(func):
    """Decorate ``func`` so that ``clear_cache()`` runs before and after it.

    Generator functions are detected with ``inspect.isgeneratorfunction`` and
    wrapped by a generator, so the decorated callable still yields lazily;
    every other callable gets a plain wrapper that returns ``func``'s result.
    Both wrappers carry ``func``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``) via ``functools.wraps``.

    Note: the trailing ``clear_cache()`` is skipped if ``func`` raises, or --
    for generators -- if the consumer stops iterating before exhaustion.

    Args:
        func: The function or generator function to wrap.

    Returns:
        A wrapper with the same call signature as ``func``.
    """

    @wraps(func)
    def yield_wrapper(*args, **kwargs):
        clear_cache()
        yield from func(*args, **kwargs)
        clear_cache()

    @wraps(func)
    def wrapper(*args, **kwargs):
        clear_cache()
        res = func(*args, **kwargs)
        clear_cache()
        return res

    if inspect.isgeneratorfunction(func):
        return yield_wrapper
    return wrapper