import os, sys
import gradio as gr
from difpoint.inference import Inferencer
from TTS.api import TTS
import torch
import time
from flask import send_from_directory
from huggingface_hub import snapshot_download
import spaces
import tensorrt
import multiprocessing as mp
import pickle

# Use the 'spawn' start method so CUDA can be initialised safely in subprocesses.
mp.set_start_method('spawn', force=True)
# Fetch the KDTalker checkpoints and assets from the Hugging Face Hub.
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)

print("\nFiles downloaded:")
for root, dirs, files in os.walk(local_dir):
    for file in files:
        file_path = os.path.join(root, file)
        print(file_path)
result_dir = "results"

# Callbacks that record which audio source is currently selected;
# their return values are written into hidden Textbox components below.
def set_upload():
    return "upload"

def set_microphone():
    return "microphone"

def set_tts():
    return "tts"

def create_kd_talker():
    return Inferencer()
example_folder = "example"
example_choices = ["Example 1", "Example 2", "Example 3"]
example_mapping = {
    "Example 1": {"audio": os.path.join(example_folder, "example1.wav"), "image": os.path.join(example_folder, "example1.png")},
    "Example 2": {"audio": os.path.join(example_folder, "example2.wav"), "image": os.path.join(example_folder, "example2.png")},
    "Example 3": {"audio": os.path.join(example_folder, "example3.wav"), "image": os.path.join(example_folder, "example3.png")},
}
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    """Synthesise speech from text, cloning the voice of the selected reference audio."""
    global result_dir
    output_file_path = os.path.join('./downloaded_repo/', 'output.wav')
    if reference_audio_type == 'upload':
        audio_file_pth = upload_reference_audio
    elif reference_audio_type == 'microphone':
        audio_file_pth = microphone_reference_audio
    tts = TTS('tts_models/multilingual/multi-dataset/your_tts')
    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')
def generate(upload_driven_audio, tts_driven_audio, driven_audio_type, source_image, smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t):
    """Run KDTalker: animate the source portrait with the selected driving audio."""
    return Inferencer().generate_with_audio_img(upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                                smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t)
def main():
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    with gr.Blocks(analytics_enabled=False) as interface:
        with gr.Row():
            gr.HTML(
                """
                <div align='center'>
                    <h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>
                    <div style="display: flex; justify-content: center; align-items: center; gap: 20px;">
                        <img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>
                        <img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>
                    </div>
                </div>
                """
            )
        # Hidden state tracking which audio source drives generation / voice cloning.
        driven_audio_type = gr.Textbox(value="upload", visible=False)
        reference_audio_type = gr.Textbox(value="upload", visible=False)
        with gr.Row():
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="kdtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)
                with gr.Tabs(elem_id="kdtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(set_upload, outputs=driven_audio_type)
                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(set_upload, outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(set_microphone, outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="Please enter some text here; the audio is generated from it using @Coqui.ai TTS."
                        )
                        tts_button = gr.Button("Generate audio", elem_id="kdtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        tts_button.click(fn=predict, inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type], outputs=[tts_driven_audio])
                        tts_button.click(set_tts, outputs=driven_audio_type)
            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        kd_submit.click(
                            fn=generate,
                            inputs=[
                                upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t
                            ],
                            outputs=[gen_video]
                        )
                    with gr.TabItem("Example"):
                        example_choice = gr.Dropdown(choices=example_choices, label="Choose an example")

                        def load_example(choice):
                            example = example_mapping.get(choice, {})
                            audio_path = example.get("audio", "")
                            image_path = example.get("image", "")
                            return [audio_path, image_path]

                        example_choice.change(
                            fn=load_example,
                            inputs=[example_choice],
                            outputs=[upload_driven_audio, source_image]
                        )
                        example_choice.change(set_upload, outputs=driven_audio_type)
    return interface


demo = main()
demo.queue().launch(share=True)
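
The same pipeline can also be exercised without the web UI, which is handy for smoke-testing the downloaded weights. The sketch below reuses generate() and the bundled example assets defined above; it assumes those example files exist and that generate() returns something the gr.Video component can display (e.g. a path to an mp4), and it would be run in place of demo.launch(), not after it.

# Headless smoke test (a sketch, not part of the Space app).
example = example_mapping["Example 1"]
video = generate(
    example["audio"],    # uploaded-style driving audio
    None,                # no TTS-generated audio in this path
    "upload",            # driven_audio_type flag
    example["image"],    # source portrait
    0.8, 0.8, 0.8, 0.8,  # pitch / yaw / roll / t smoothing factors
)
print("Generated:", video)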