nazdridoy commited on
Commit
d80127a
·
verified ·
1 Parent(s): 6244d01

refactor(video-output): coerce video output for Gradio

Browse files

- [add] Import `tempfile` and `io` modules (video_handler.py:6-7)
- [refactor] Apply `_coerce_video_output` to video before returning (video_handler.py:89,92)
- [add] Implement `_coerce_video_output` for various video return types (video_handler.py:154-180)
- [add] Add `_write_temp_video` to save video data to a temporary file (video_handler.py:183-191)
- [add] Add `_guess_video_suffix` to infer video file extensions from data (video_handler.py:194-203)

Files changed (1) hide show
  1. video_handler.py +59 -1
video_handler.py CHANGED
@@ -5,6 +5,8 @@ Handles text-to-video generation with multiple providers.
5
 
6
  import os
7
  import gradio as gr
 
 
8
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
9
  from huggingface_hub import InferenceClient
10
  from huggingface_hub.errors import HfHubHTTPError
@@ -86,11 +88,14 @@ def generate_video(
86
 
87
  print(f"🎞️ Video: Generation completed! Type: {type(video)}")
88
 
 
 
 
89
  # Report successful token usage
90
  if token_id:
91
  report_token_status(token_id, "success", api_key=proxy_api_key)
92
 
93
- return video, format_success_message("Video generated", f"using {model_name} on {provider}")
94
 
95
  except ConnectionError as e:
96
  error_msg = f"Cannot connect to HF-Inferoxy server: {str(e)}"
@@ -150,3 +155,56 @@ def handle_video_generation(prompt_val, model_val, provider_val, steps_val, guid
150
  )
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  import os
7
  import gradio as gr
8
+ import tempfile
9
+ import io
10
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
11
  from huggingface_hub import InferenceClient
12
  from huggingface_hub.errors import HfHubHTTPError
 
88
 
89
  print(f"🎞️ Video: Generation completed! Type: {type(video)}")
90
 
91
+ # Convert output to a path or URL Gradio can handle
92
+ video_output = _coerce_video_output(video)
93
+
94
  # Report successful token usage
95
  if token_id:
96
  report_token_status(token_id, "success", api_key=proxy_api_key)
97
 
98
+ return video_output, format_success_message("Video generated", f"using {model_name} on {provider}")
99
 
100
  except ConnectionError as e:
101
  error_msg = f"Cannot connect to HF-Inferoxy server: {str(e)}"
 
155
  )
156
 
157
 
158
+ def _coerce_video_output(value):
159
+ """Coerce various return types (bytes, str path/URL, BytesIO) into a filepath/URL for gr.Video."""
160
+ # Case 1: Direct URL or existing file path
161
+ if isinstance(value, str):
162
+ if value.startswith("http://") or value.startswith("https://"):
163
+ return value
164
+ if os.path.exists(value):
165
+ return value
166
+ # Unknown string; fall through to save as file
167
+
168
+ # Case 2: Bytes-like content
169
+ if isinstance(value, (bytes, bytearray)):
170
+ data = bytes(value)
171
+ suffix = _guess_video_suffix(data)
172
+ return _write_temp_video(data, suffix)
173
+
174
+ # Case 3: File-like object
175
+ if isinstance(value, io.IOBase) or hasattr(value, "read"):
176
+ try:
177
+ data = value.read()
178
+ if isinstance(data, (bytes, bytearray)):
179
+ suffix = _guess_video_suffix(data)
180
+ return _write_temp_video(bytes(data), suffix)
181
+ except Exception:
182
+ pass
183
+
184
+ # Fallback: save string representation for debugging
185
+ debug_bytes = str(type(value)).encode("utf-8")
186
+ return _write_temp_video(debug_bytes, ".mp4")
187
+
188
+
189
+ def _write_temp_video(data: bytes, suffix: str) -> str:
190
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
191
+ try:
192
+ tmp.write(data)
193
+ tmp.flush()
194
+ finally:
195
+ tmp.close()
196
+ return tmp.name
197
+
198
+
199
+ def _guess_video_suffix(data: bytes) -> str:
200
+ header = data[:64]
201
+ # MP4 often contains 'ftyp' box near start
202
+ if b"ftyp" in header:
203
+ return ".mp4"
204
+ # WebM/Matroska magic number starts with 0x1A45DFA3 and often contains 'webm'
205
+ if header.startswith(b"\x1aE\xdf\xa3") or b"webm" in header.lower():
206
+ return ".webm"
207
+ # Default to mp4
208
+ return ".mp4"
209
+
210
+