Eueuiaa commited on
Commit
8fd9bdd
·
verified ·
1 Parent(s): db24d9a

Create gpu_manager.py

Browse files
Files changed (1) hide show
  1. api/gpu_manager.py +56 -0
api/gpu_manager.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/gpu_manager.py
2
+
3
+ import os
4
+ import torch
5
+
6
+ class GPUManager:
7
+ """
8
+ Gerencia e aloca GPUs disponíveis para diferentes serviços (LTX, SeedVR).
9
+ """
10
+ def __init__(self):
11
+ self.total_gpus = torch.cuda.device_count()
12
+ self.ltx_gpus = []
13
+ self.seedvr_gpus = []
14
+ self._allocate_gpus()
15
+
16
+ def _allocate_gpus(self):
17
+ """
18
+ Divide as GPUs disponíveis entre os serviços LTX e SeedVR.
19
+ """
20
+ print("="*50)
21
+ print("🤖 Gerenciador de GPUs inicializado.")
22
+ print(f" > Total de GPUs detectadas: {self.total_gpus}")
23
+
24
+ if self.total_gpus == 0:
25
+ print(" > Nenhuma GPU detectada. Operando em modo CPU.")
26
+ elif self.total_gpus == 1:
27
+ print(" > 1 GPU detectada. Modo de compartilhamento de memória será usado.")
28
+ # Ambos usarão a GPU 0, mas precisarão gerenciar a memória
29
+ self.ltx_gpus = [0]
30
+ self.seedvr_gpus = [0]
31
+ else:
32
+ # Divide as GPUs entre os dois serviços
33
+ mid_point = self.total_gpus // 2
34
+ self.ltx_gpus = list(range(0, mid_point))
35
+ self.seedvr_gpus = list(range(mid_point, self.total_gpus))
36
+ print(f" > Alocação: LTX usará GPUs {self.ltx_gpus}, SeedVR usará GPUs {self.seedvr_gpus}.")
37
+
38
+ print("="*50)
39
+
40
+ def get_ltx_device(self):
41
+ """Retorna o dispositivo principal para o LTX."""
42
+ if not self.ltx_gpus:
43
+ return torch.device("cpu")
44
+ # Por padrão, o modelo principal do LTX roda na primeira GPU do seu grupo
45
+ return torch.device(f"cuda:{self.ltx_gpus[0]}")
46
+
47
+ def get_seedvr_devices(self) -> list:
48
+ """Retorna a lista de IDs de GPU para o SeedVR."""
49
+ return self.seedvr_gpus
50
+
51
+ def requires_memory_swap(self) -> bool:
52
+ """Verifica se é necessário mover modelos entre CPU e GPU."""
53
+ return self.total_gpus < 2
54
+
55
+ # Instância global para ser importada por outros módulos
56
+ gpu_manager = GPUManager()