Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -107,9 +107,8 @@ class ModelManager:
|
|
| 107 |
self.realesrgan_x2 = None
|
| 108 |
self.realesrgan_x4 = None
|
| 109 |
|
| 110 |
-
def load_models(self
|
| 111 |
-
if self.pipe is None:
|
| 112 |
-
progress(0, desc="Loading Stable Diffusion pipeline...")
|
| 113 |
self.pipe = self.setup_pipeline()
|
| 114 |
self.pipe.to(device)
|
| 115 |
self.pipe.unet.set_attn_processor(AttnProcessor2_0())
|
|
@@ -118,17 +117,14 @@ class ModelManager:
|
|
| 118 |
progress(0.5, desc="Compiling the model...")
|
| 119 |
self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 120 |
|
| 121 |
-
if self.realesrgan_x2 is None:
|
| 122 |
-
progress(0.7, desc="Loading RealESRGAN x2 model...")
|
| 123 |
self.realesrgan_x2 = RealESRGAN(device, scale=2)
|
| 124 |
self.realesrgan_x2.load_weights('models/upscalers/RealESRGAN_x2.pth', download=False)
|
| 125 |
|
| 126 |
-
if self.realesrgan_x4 is None:
|
| 127 |
-
progress(0.9, desc="Loading RealESRGAN x4 model...")
|
| 128 |
self.realesrgan_x4 = RealESRGAN(device, scale=4)
|
| 129 |
self.realesrgan_x4.load_weights('models/upscalers/RealESRGAN_x4.pth', download=False)
|
| 130 |
|
| 131 |
-
progress(1.0, desc="All models loaded successfully")
|
| 132 |
|
| 133 |
def setup_pipeline(self):
|
| 134 |
controlnet = ControlNetModel.from_single_file(
|
|
@@ -233,6 +229,7 @@ class ModelManager:
|
|
| 233 |
return hdr_result
|
| 234 |
|
| 235 |
model_manager = ModelManager()
|
|
|
|
| 236 |
|
| 237 |
def extract_frames(video_path, output_folder):
|
| 238 |
os.makedirs(output_folder, exist_ok=True)
|
|
@@ -274,11 +271,12 @@ def frames_to_video(input_folder, output_path, fps, original_video_path):
|
|
| 274 |
# Remove the temporary file
|
| 275 |
os.remove(temp_output_path)
|
| 276 |
|
|
|
|
| 277 |
@timer_func
|
| 278 |
def process_video(input_video, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames=None, frame_interval=1, preserve_frames=False, batch_size=8, progress=gr.Progress()):
|
| 279 |
abort_event.clear() # Clear the abort flag at the start of a new job
|
| 280 |
print("Starting video processing...")
|
| 281 |
-
|
| 282 |
|
| 283 |
# Create a new job folder
|
| 284 |
job_id = str(uuid.uuid4())
|
|
|
|
| 107 |
self.realesrgan_x2 = None
|
| 108 |
self.realesrgan_x4 = None
|
| 109 |
|
| 110 |
+
def load_models(self):
|
| 111 |
+
if self.pipe is None:
|
|
|
|
| 112 |
self.pipe = self.setup_pipeline()
|
| 113 |
self.pipe.to(device)
|
| 114 |
self.pipe.unet.set_attn_processor(AttnProcessor2_0())
|
|
|
|
| 117 |
progress(0.5, desc="Compiling the model...")
|
| 118 |
self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True)
|
| 119 |
|
| 120 |
+
if self.realesrgan_x2 is None:
|
|
|
|
| 121 |
self.realesrgan_x2 = RealESRGAN(device, scale=2)
|
| 122 |
self.realesrgan_x2.load_weights('models/upscalers/RealESRGAN_x2.pth', download=False)
|
| 123 |
|
| 124 |
+
if self.realesrgan_x4 is None:
|
|
|
|
| 125 |
self.realesrgan_x4 = RealESRGAN(device, scale=4)
|
| 126 |
self.realesrgan_x4.load_weights('models/upscalers/RealESRGAN_x4.pth', download=False)
|
| 127 |
|
|
|
|
| 128 |
|
| 129 |
def setup_pipeline(self):
|
| 130 |
controlnet = ControlNetModel.from_single_file(
|
|
|
|
| 229 |
return hdr_result
|
| 230 |
|
| 231 |
model_manager = ModelManager()
|
| 232 |
+
model_manager.load_models() # Ensure models are loaded
|
| 233 |
|
| 234 |
def extract_frames(video_path, output_folder):
|
| 235 |
os.makedirs(output_folder, exist_ok=True)
|
|
|
|
| 271 |
# Remove the temporary file
|
| 272 |
os.remove(temp_output_path)
|
| 273 |
|
| 274 |
+
|
| 275 |
@timer_func
|
| 276 |
def process_video(input_video, resolution, num_inference_steps, strength, hdr, guidance_scale, max_frames=None, frame_interval=1, preserve_frames=False, batch_size=8, progress=gr.Progress()):
|
| 277 |
abort_event.clear() # Clear the abort flag at the start of a new job
|
| 278 |
print("Starting video processing...")
|
| 279 |
+
|
| 280 |
|
| 281 |
# Create a new job folder
|
| 282 |
job_id = str(uuid.uuid4())
|