from subprocess import check_output
from threading import Timer
from typing import Callable, List, Optional
def get_gpu_memory() -> List[int]:
    """
    Get the free GPU memory (VRAM) in MiB, as reported by ``nvidia-smi``.

    Runs ``nvidia-smi --query-gpu=memory.free`` and parses every non-blank
    output line as an integer (presumably one line per GPU, in nvidia-smi's
    order).

    :return memory_free_values: List of free GPU memory (VRAM) in MiB
    :raises FileNotFoundError: if ``nvidia-smi`` cannot be found on PATH
    :raises subprocess.CalledProcessError: if ``nvidia-smi`` exits non-zero
    :raises ValueError: if an output line is not an integer
    """
    command = ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"]
    output = check_output(command).decode("ascii")
    # splitlines() handles both "\n" and "\r\n" line endings. Unlike
    # split("\n")[:-1], it does not drop the last value when the output has
    # no trailing newline. Blank lines are skipped.
    memory_free_values = [int(line) for line in output.splitlines() if line.strip()]
    return memory_free_values
class RepeatingTimer(Timer):
    """
    A Timer that keeps firing instead of running its function only once.

    ``function(*args, **kwargs)`` is invoked every ``interval`` seconds. The
    first invocation happens one interval after ``start()``, and invocations
    continue until ``cancel()`` sets the ``finished`` event.
    """

    def run(self):
        while True:
            # Sleep for one interval; cancel() wakes this wait early.
            self.finished.wait(self.interval)
            if self.finished.is_set():
                return
            self.function(*self.args, **self.kwargs)
# Handle to the most recently started GPU memory watcher (None until one is started).
gpu_memory_watcher: Optional[RepeatingTimer] = None


def watch_gpu_memory(
    interval: float = 1,
    callback: Optional[Callable[[List[int]], None]] = None,
) -> RepeatingTimer:
    """
    Start a repeating timer to watch the GPU memory usage.

    Every ``interval`` seconds the timer calls ``get_gpu_memory()`` and passes
    the result to ``callback``. Only one watcher may be running at a time.
    Once the previous watcher has been cancelled, or its thread has exited,
    a new one can be started.

    :param interval: Interval in seconds between polls; the first poll happens
        one interval after the timer starts
    :param callback: Called with the list of free GPU memory values in MiB;
        defaults to ``print``
    :return timer: The started RepeatingTimer; call ``cancel()`` on it to stop
    :raises RuntimeError: if the previously started watcher is still running
    """
    global gpu_memory_watcher
    previous = gpu_memory_watcher
    # A cancelled timer (its `finished` event is set) is no longer running.
    # Neither is one whose thread has exited, e.g. because the callback raised.
    # Neither kind may block a restart.
    if previous is not None and previous.is_alive() and not previous.finished.is_set():
        raise RuntimeError("GPU memory watcher is already running")
    if callback is None:
        callback = print
    gpu_memory_watcher = RepeatingTimer(interval, lambda: callback(get_gpu_memory()))
    gpu_memory_watcher.start()
    return gpu_memory_watcher
if __name__ == "__main__":
    from time import sleep

    # Demo: report free VRAM once per second for 20 seconds. Halfway through,
    # check that starting a second watcher is rejected while one is running.
    watch_gpu_memory()
    try:
        for counter in range(1, 21):
            sleep(1)
            if counter == 10:
                try:
                    watch_gpu_memory()
                except RuntimeError:
                    print("Got exception")
    finally:
        # Stop the current watcher; at most one can be running at a time.
        # Its thread is not marked as a daemon, so if the loop is interrupted
        # (e.g. by Ctrl-C) and the watcher is left running, the interpreter
        # waits on it forever at exit.
        gpu_memory_watcher.cancel()