# appkelvo / app.py
# savvyf's picture
# Upload 4 files
# 4f66ca8 verified
import os
import threading
import time
import gradio as gr
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from datetime import datetime, timedelta
from openai import OpenAI
import spaces
import moviepy as mp
# Module-level setup: pick precision and device, load the CogVideoX pipeline
# once at import time, and make sure the working directories exist.
dtype = torch.float16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype)
pipe = pipe.to(device)

# "./output" receives the rendered videos/GIFs; both directories are swept by
# the background cleanup thread.
os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)
# System message sent with every prompt-enhancement request in convert_prompt().
# It is runtime text read by the chat model, so any edit changes the captions
# that come back.
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
For example, outputting "a beautiful morning in the woods with the sun peeking through the trees" will trigger your partner bot to output a video of a forest morning, as described. You will be prompted by people looking to create detailed, amazing videos.
You will only ever output a single video description per user request.
When modifications are requested, refactor the entire description to integrate suggestions.
Other times the user will not want modifications but a new video. In that case, ignore previous conversation history.
"""
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
    """Ask the chat model to expand ``prompt`` into a detailed video caption.

    Enhancement is best-effort: whenever it cannot be performed the original
    ``prompt`` is returned unchanged so video generation can still proceed.

    Args:
        prompt: Raw prompt typed by the user.
        retry_times: Maximum number of chat-completion attempts.

    Returns:
        The caption produced by the model, or ``prompt`` itself when
        ``OPENAI_API_KEY`` is not set, the prompt is blank, or every attempt
        fails or yields no text.
    """
    if not os.environ.get("OPENAI_API_KEY"):
        return prompt

    text = prompt.strip()
    if not text:
        # Nothing to enhance; avoid a pointless API round-trip.
        return prompt

    # Function-scope imports keep this block self-contained; both modules are
    # either stdlib or already a dependency of this file.
    import logging

    from openai import OpenAIError

    client = OpenAI()
    for attempt in range(1, retry_times + 1):
        try:
            response = client.chat.completions.create(
                messages=[
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": f'Create a detailed imaginative video caption for: "{text}"'},
                ],
                model="glm-4-0520",
                temperature=0.01,
                top_p=0.7,
                stream=False,
                max_tokens=250,
            )
        except OpenAIError as err:
            # Without this handler the first API/network error escaped the
            # loop, so the retry budget was never actually used.
            logging.getLogger(__name__).warning(
                "Prompt enhancement attempt %d/%d failed: %s", attempt, retry_times, err
            )
            continue

        if response.choices:
            content = response.choices[0].message.content
            # ``content`` may be None or empty; never hand that back as the
            # new prompt.
            if content:
                return content

    return prompt
@spaces.GPU(duration=240)
def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)):
    """Run the CogVideoX pipeline for a single prompt.

    Args:
        prompt: Text description of the video to generate.
        num_inference_steps: Number of denoising steps. Values coming from a
            ``gr.Number`` field may arrive as floats, so the value is
            truncated to ``int`` before it reaches the pipeline.
        guidance_scale: Guidance strength forwarded to the pipeline.
        progress: Gradio progress tracker. It is not referenced in the body;
            presumably it is declared so Gradio mirrors the pipeline's tqdm
            bar (``track_tqdm=True``) -- confirm against the Gradio docs.

    Returns:
        The first entry of the pipeline output's ``frames`` (one video is
        requested per prompt).
    """
    # Release cached CUDA blocks left over from a previous request.
    torch.cuda.empty_cache()

    result = pipe(
        prompt=prompt,
        num_videos_per_prompt=1,
        # The step count must be an integer; UI number inputs can be floats.
        num_inference_steps=int(num_inference_steps),
        num_frames=49,
        guidance_scale=float(guidance_scale),
    )
    return result.frames[0]
def save_video(tensor):
    """Export generated frames to a timestamped MP4 file under ``./output``.

    Args:
        tensor: Frames as returned by ``infer``; passed straight to
            ``export_to_video``.

    Returns:
        Path of the written ``.mp4`` file.
    """
    # Microsecond resolution: with second-level names, two generations that
    # finished within the same second wrote to (and overwrote) the same file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    video_path = f"./output/{timestamp}.mp4"
    # The cleanup thread or an operator may have removed the directory.
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video(tensor, video_path)
    return video_path
def convert_to_gif(video_path):
    """Render a downscaled 8 fps GIF next to ``video_path``.

    Args:
        video_path: Path of the source video file.

    Returns:
        Path of the written GIF: ``video_path`` with its extension replaced
        by ``.gif``.
    """
    # Swap only the trailing extension. ``str.replace`` would also rewrite
    # ".mp4" appearing elsewhere in the path, and would return the input path
    # unchanged (overwriting the video) if the extension differed.
    gif_path = os.path.splitext(video_path)[0] + ".gif"

    # The context manager closes the underlying ffmpeg reader even when the
    # GIF export fails; previously the clip was never closed.
    with mp.VideoFileClip(video_path) as source:
        clip = source.with_fps(8).resized(height=240)
        clip.write_gif(gif_path, fps=8)

    return gif_path
def _remove_stale_files(directory, cutoff):
    """Delete regular files in ``directory`` whose mtime (epoch seconds) is older than ``cutoff``."""
    try:
        filenames = os.listdir(directory)
    except OSError:
        # Directory missing or unreadable right now; try again next sweep.
        return

    for filename in filenames:
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) and os.path.getmtime(file_path) < cutoff:
                os.remove(file_path)
        except OSError:
            # The file vanished or is locked; skip it rather than let the
            # exception terminate the cleanup thread.
            continue


def delete_old_files():
    """Sweep the output and temp directories forever, removing stale files.

    Every ten minutes, regular files in ``./output`` and ``./gradio_tmp``
    whose modification time is more than ten minutes old are deleted. The
    function never returns and is meant to run in a daemon thread.

    Filesystem errors are skipped per file/directory: an unhandled exception
    here would silently end the thread and stop all future cleanup.
    """
    max_age_seconds = 10 * 60
    while True:
        # Compare epoch timestamps instead of naive local datetimes so the
        # age check is not skewed by DST changes.
        cutoff = time.time() - max_age_seconds
        for directory in ["./output", "./gradio_tmp"]:
            _remove_stale_files(directory, cutoff)
        time.sleep(600)
# Start the periodic cleanup at import time. daemon=True means the thread will
# not keep the interpreter alive when the main program exits.
threading.Thread(target=delete_old_files, daemon=True).start()
# Gradio UI: prompt and parameters on the left, generated video and download
# links on the right. Handlers are defined inside the Blocks context so the
# event wiring at the bottom can reference the components directly.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt (Less than 200 Words)",
                placeholder="Enter your prompt here",
                lines=5,
            )
            with gr.Row():
                gr.Markdown("✨ Click enhance to polish your prompt with GLM-4.")
                enhance_button = gr.Button("✨ Enhance Prompt (Optional)")
            with gr.Column():
                gr.Markdown("**Optional Parameters:** Default values are recommended.")
                with gr.Row():
                    # precision=0 makes Gradio round the value and deliver an
                    # int; the step count is not meaningful as a fraction.
                    num_inference_steps = gr.Number(label="Inference Steps", value=50, precision=0)
                    guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
                generate_button = gr.Button("🎬 Generate Video")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", width=720, height=480)
            with gr.Row():
                # Labels previously showed mojibake ("πŸ“₯") in place of the
                # intended inbox-tray emoji.
                download_video_button = gr.File(label="📥 Download Video", visible=False)
                download_gif_button = gr.File(label="📥 Download GIF", visible=False)

    def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
        """Generate a video, save it as MP4 and GIF, and reveal the download widgets.

        Returns a 3-tuple matching the ``outputs`` of ``generate_button.click``:
        the MP4 path for the video player, then updates that make the MP4 and
        GIF file components visible with their respective paths.
        """
        tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
        video_path = save_video(tensor)
        video_update = gr.update(visible=True, value=video_path)
        gif_path = convert_to_gif(video_path)
        gif_update = gr.update(visible=True, value=gif_path)
        return video_path, video_update, gif_update

    def enhance_prompt_func(prompt):
        """Return the enhanced prompt; a single attempt keeps the button responsive."""
        return convert_prompt(prompt, retry_times=1)

    generate_button.click(
        generate,
        inputs=[prompt, num_inference_steps, guidance_scale],
        outputs=[video_output, download_video_button, download_gif_button],
    )
    # The enhanced text overwrites the prompt box in place.
    enhance_button.click(
        enhance_prompt_func,
        inputs=[prompt],
        outputs=[prompt],
    )
# Launch the Gradio server only when this file is run as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()