import gc
import inspect
import logging
from typing import Optional, Tuple, Union

import torch
# Module-level logger. NOTE(review): the original called `get_logger(__name__)`,
# but no such name is imported anywhere in this file, which raises NameError at
# import time. Stdlib `logging.getLogger` provides the equivalent named logger.
# If this file is meant to use accelerate's `get_logger`, restore that import instead.
logger = logging.getLogger(__name__)
def reset_memory(device: Union[str, torch.device]) -> None:
    """Collect host garbage, release cached CUDA memory, and reset CUDA memory stats.

    Args:
        device: CUDA device (string like ``"cuda:0"`` or ``torch.device``) whose
            peak/accumulated memory statistics are reset.

    Notes:
        On machines without CUDA this degrades to a plain ``gc.collect()``
        instead of raising, so the helper is safe to call unconditionally.
    """
    gc.collect()
    # Guard: the original crashed on CPU-only hosts because every torch.cuda
    # call below requires a working CUDA runtime.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.reset_accumulated_memory_stats(device)
def print_memory(device: Union[str, torch.device]) -> None:
    """Print current allocated, peak allocated, and peak reserved CUDA memory in GiB.

    Args:
        device: CUDA device (string like ``"cuda:0"`` or ``torch.device``) to query.

    Notes:
        Prints a short notice and returns early when CUDA is unavailable,
        instead of raising as the original did on CPU-only hosts.
    """
    # Guard: torch.cuda.memory_* queries require a working CUDA runtime.
    if not torch.cuda.is_available():
        print("CUDA is not available; no memory stats to report.")
        return
    GiB = 1024**3  # hoisted: the divisor was repeated three times
    memory_allocated = torch.cuda.memory_allocated(device) / GiB
    max_memory_allocated = torch.cuda.max_memory_allocated(device) / GiB
    max_memory_reserved = torch.cuda.max_memory_reserved(device) / GiB
    print(f"{memory_allocated=:.3f} GB")
    print(f"{max_memory_allocated=:.3f} GB")
    print(f"{max_memory_reserved=:.3f} GB")