import os
import threading
import time
from datetime import datetime, timedelta

import gradio as gr
import moviepy as mp
import spaces
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from openai import OpenAI

dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype).to(device)

os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)
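
# Optional VRAM savers, a sketch not present in the original app; both are
# standard diffusers APIs for CogVideoX, but verify them against your
# installed diffusers version before enabling:
# pipe.enable_model_cpu_offload()
# pipe.vae.enable_tiling()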

sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.

For example, outputting "a beautiful morning in the woods with the sun peeking through the trees" will trigger your partner bot to output a video of a forest morning, as described. You will be prompted by people looking to create detailed, amazing videos.

You will only ever output a single video description per user request.
When modifications are requested, refactor the entire description to integrate the suggestions.
At other times the user will want a new video rather than a modification; in that case, ignore the previous conversation history.
"""


def convert_prompt(prompt: str, retry_times: int = 3) -> str:
    """Rewrite a short user prompt into a detailed video caption via GLM-4.

    Falls back to the raw prompt when no API key is set or every attempt fails.
    """
    if not os.environ.get("OPENAI_API_KEY"):
        return prompt
    client = OpenAI()
    text = prompt.strip()
    for _ in range(retry_times):
        try:
            response = client.chat.completions.create(
                messages=[
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": f'Create a detailed imaginative video caption for: "{text}"'},
                ],
                model="glm-4-0520",
                temperature=0.01,
                top_p=0.7,
                stream=False,
                max_tokens=250,
            )
        except Exception:
            continue  # a failed API call now consumes one retry instead of crashing
        if response.choices:
            return response.choices[0].message.content
    return prompt
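
# NOTE (assumption, not in the original code): "glm-4-0520" is a ZhipuAI model,
# so the bare OpenAI() client above presumably needs to target an
# OpenAI-compatible endpoint that serves GLM-4, e.g.:
#   client = OpenAI(api_key=os.environ["OPENAI_API_KEY"],
#                   base_url="https://open.bigmodel.cn/api/paas/v4/")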


@spaces.GPU(duration=240)
def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)):
    torch.cuda.empty_cache()  # release cached VRAM from any previous run (no-op on CPU)
    video = pipe(
        prompt=prompt,
        num_videos_per_prompt=1,
        num_inference_steps=num_inference_steps,
        num_frames=49,  # 49 frames at 8 fps is roughly a 6-second clip
        guidance_scale=guidance_scale,
    ).frames[0]
    return video
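
# For reproducible outputs one could pass a seeded generator, a standard
# diffusers pipeline parameter (illustrative, not part of the original app):
#   pipe(..., generator=torch.Generator(device=device).manual_seed(42))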


def save_video(tensor):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video(tensor, video_path, fps=8)  # CogVideoX generates 8 fps clips
    return video_path


def convert_to_gif(video_path):
    clip = mp.VideoFileClip(video_path)
    clip = clip.with_fps(8)
    clip = clip.resized(height=240)
    gif_path = video_path.replace(".mp4", ".gif")
    clip.write_gif(gif_path, fps=8)
    return gif_path
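
# NOTE: with_fps()/resized() are MoviePy 2.x method names; on MoviePy 1.x the
# equivalents are set_fps()/resize() (an assumption about the installed version).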


def delete_old_files():
    """Every 10 minutes, remove files older than 10 minutes from the temp dirs."""
    while True:
        now = datetime.now()
        cutoff = now - timedelta(minutes=10)
        for directory in ["./output", "./gradio_tmp"]:
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isfile(file_path):
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
        time.sleep(600)


# Daemon thread: it dies with the main process instead of blocking shutdown.
threading.Thread(target=delete_old_files, daemon=True).start()


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt (Less than 200 Words)",
                placeholder="Enter your prompt here",
                lines=5,
            )

            with gr.Row():
                gr.Markdown("✨ Click enhance to polish your prompt with GLM-4.")
                enhance_button = gr.Button("✨ Enhance Prompt (Optional)")

            with gr.Column():
                gr.Markdown("**Optional Parameters:** Default values are recommended.")
                with gr.Row():
                    num_inference_steps = gr.Number(label="Inference Steps", value=50)
                    guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
                generate_button = gr.Button("🎬 Generate Video")

        with gr.Column():
            video_output = gr.Video(label="Generated Video", width=720, height=480)
            with gr.Row():
                download_video_button = gr.File(label="📥 Download Video", visible=False)
                download_gif_button = gr.File(label="📥 Download GIF", visible=False)
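
    # The numeric inputs above are unbounded; gr.Number accepts minimum=/maximum=
    # keyword arguments if you want guardrails (the bounds would be a design
    # choice, e.g. minimum=1, maximum=100 for the inference steps).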

    def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
        """Run inference, save MP4 and GIF, and reveal the hidden download buttons."""
        tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
        video_path = save_video(tensor)
        video_update = gr.update(visible=True, value=video_path)
        gif_path = convert_to_gif(video_path)
        gif_update = gr.update(visible=True, value=gif_path)
        return video_path, video_update, gif_update

    def enhance_prompt_func(prompt):
        return convert_prompt(prompt, retry_times=1)

    generate_button.click(
        generate,
        inputs=[prompt, num_inference_steps, guidance_scale],
        outputs=[video_output, download_video_button, download_gif_button],
    )

    enhance_button.click(
        enhance_prompt_func,
        inputs=[prompt],
        outputs=[prompt],
    )
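
# NOTE (assumption): on Hugging Face Spaces it is common to enable request
# queuing before launching, e.g. demo.queue(max_size=15).launch(); queue() is
# a real Gradio API, though the max_size value here is illustrative.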

if __name__ == "__main__":
    demo.launch()