Spaces:
Runtime error
Runtime error
| import os | |
| import spaces | |
| import time | |
# Resolve the Hugging Face access token: prefer the HF_TOKEN environment
# variable, otherwise fall back to a placeholder that has to be edited by hand.
try:
    token = os.environ['HF_TOKEN']
except KeyError:
    # Only a missing variable is an expected failure of the lookup above;
    # any other exception should surface instead of being swallowed.
    print("paste your hf token here!")
    token = "hf_xxxxxxxxxxxxxxxxxxx"
    os.environ['HF_TOKEN'] = token
| import torch | |
| import gradio as gr | |
| from gradio.themes.utils import colors, fonts, sizes | |
| from faster_whisper import WhisperModel | |
| from moviepy.editor import VideoFileClip | |
| from transformers import AutoTokenizer, AutoModel | |
| # ======================================== | |
| # Model Initialization | |
| # ======================================== | |
# Load the speech recognizer and the video-chat model at import time.
# NOTE(review): per the Gradio reload-mode docs, code under `gr.NO_RELOAD` is
# not re-executed on hot reloads - confirm for the pinned Gradio version.
# NOTE(review): the original indentation was lost; everything below is assumed
# to be nested under the `gr.NO_RELOAD` guard.
if gr.NO_RELOAD:
    # Whisper runs in float16 on the GPU when CUDA is available, else on CPU.
    if torch.cuda.is_available():
        speech_model = WhisperModel("large-v3", device="cuda", compute_type="float16")
    else:
        speech_model = WhisperModel("large-v3", device="cpu")
    model_path = 'OpenGVLab/VideoChat-Flash-Qwen2-7B_res448'
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # NOTE(review): unlike the speech model there is no CPU fallback here;
    # `.cuda()` is called unconditionally.
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
    model.config.mm_llm_compress = False
| # ======================================== | |
| # Define Utils | |
| # ======================================== | |
def extract_audio(name):
    """Write the audio track of a video to a WAV file next to the video.

    Args:
        name: Path of the video file.

    Returns:
        Path of the written WAV file (the video path with its extension
        replaced by ``.wav``), or ``None`` when the video has no audio track.
    """
    with VideoFileClip(name) as video:
        audio = video.audio
        if audio is None:
            return None
        # splitext copes with extensions of any length (".webm", ".mpeg") and
        # with extension-less names; slicing off the last four characters
        # only works for three-letter extensions.
        audio_name = os.path.splitext(name)[0] + '.wav'
        audio.write_audiofile(audio_name, fps=16000)
    return audio_name
def audio2text(audio):
    """Transcribe an audio file into a single timestamped string.

    Every segment produced by ``speech_model.transcribe`` is rendered as
    ``[<start>s -> <end>s] <text> `` and the pieces are concatenated in order.
    Returns an empty string when there are no segments.
    """
    segments, _info = speech_model.transcribe(audio)
    pieces = [
        "[%.2fs -> %.2fs] %s " % (seg.start, seg.end, seg.text)
        for seg in segments
    ]
    return "".join(pieces)
| # ======================================== | |
| # Gradio Setting | |
| # ======================================== | |
def gradio_reset():
    """Return the values that put the interface back into its initial state.

    The tuple is ordered like the outputs it is wired to: chatbot, video,
    text box, send button, clear button, upload button, chat state, ASR text.
    """
    chatbot_value = None
    video_update = gr.update(value=None, interactive=True)
    textbox_update = gr.update(placeholder='Please upload your video first', interactive=False)
    send_update = gr.update(interactive=False)
    clear_update = gr.update(interactive=False)
    upload_update = gr.update(value="Upload & Start Chat", interactive=True)
    chat_state_value = []
    asr_text = ""
    return (
        chatbot_value,
        video_update,
        textbox_update,
        send_update,
        clear_update,
        upload_update,
        chat_state_value,
        asr_text,
    )
def upload_video(gr_video, text_input="Type and press Enter"):
    """Prepare the chat controls for an uploaded video and transcribe its audio.

    Args:
        gr_video: Path of the uploaded video, or ``None`` if nothing was uploaded.
        text_input: Text used as the placeholder of the message box.

    Returns:
        Updates for the send button, clear button, video widget, message box
        and upload button, followed by the ASR transcript ("" when there is
        no video or the video has no audio track).
    """
    if gr_video is None:
        # Nothing uploaded: keep the chat controls disabled.
        return (
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            gr.update(interactive=False),
            gr.update(value="Upload & Start Chat", interactive=True),
            "",
        )

    # The transcript is always extracted here; whether it is used is decided
    # later from the "Use ASR" checkbox.
    audio_name = extract_audio(gr_video)
    asr_msg = "" if audio_name is None else audio2text(audio_name)

    return (
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True, placeholder=text_input),
        gr.update(value="Start Chatting", interactive=False),
        asr_msg,
    )
def clear_():
    """Return a fresh empty chatbot history and a fresh empty chat state."""
    empty_chatbot, empty_chat_state = [], []
    return empty_chatbot, empty_chat_state
def gradio_ask(user_message, chatbot):
    """Append the user's message to the chat history as an unanswered turn.

    Returns the message unchanged together with a new history list whose last
    entry is ``[user_message, None]``; the list passed in is not modified.
    """
    pending_turn = [user_message, None]
    updated_history = chatbot + [pending_turn]
    return user_message, updated_history
def gradio_answer(chatbot, text_input, video_path, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, do_sample, num_beams, top_p, temperature):
    """Ask the model for a reply and stream it into the last chatbot turn.

    On the first turn (empty ``chat_state``), when ``check_asr`` is set and a
    transcript is available, the transcript is prepended to the prompt. The
    reply is then revealed one character at a time with a trailing cursor.

    Yields:
        ``(chatbot, chat_state)`` after every revealed character and once more
        with the final reply without the cursor.
    """
    first_turn = chat_state is None or len(chat_state) == 0
    has_transcript = asr_msg is not None and len(asr_msg) != 0
    if first_turn and has_transcript and check_asr:
        text_input = f"The speech extracted from the video via ASR is as follows: {asr_msg}\n{text_input}"
    print(f"\033[91m== text_input: \033[0m\n{text_input}\n")

    generation_config = {
        'max_new_tokens': max_new_tokens,
        'do_sample': do_sample,
        'num_beams': num_beams,
        'top_p': top_p,
        'temperature': temperature,
    }
    response, chat_state = model.chat(
        video_path=video_path,
        tokenizer=tokenizer,
        user_prompt=text_input,
        chat_history=chat_state,
        return_history=True,
        max_num_frames=max_num_frames,
        generation_config=generation_config,
    )

    # Typewriter effect: show a growing prefix of the reply plus a cursor.
    for shown in range(1, len(response) + 1):
        chatbot[-1][1] = response[:shown] + "▌"
        yield chatbot, chat_state
        time.sleep(0.008)

    # Final frame: the whole reply, cursor removed.
    chatbot[-1][1] = response[:len(response)]
    yield chatbot, chat_state
class OpenGVLab(gr.themes.base.Base):
    """Gradio theme used by the demo.

    All constructor arguments are keyword-only and are forwarded unchanged to
    the base theme; afterwards the page background is set to ``*neutral_50``.
    """

    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(
            fonts.GoogleFont("Noto Sans"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono=(
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        base_settings = dict(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().__init__(**base_settings)
        # Override the body background on top of the base theme values.
        super().set(body_background_fill="*neutral_50")
# Theme instance handed to gr.Blocks below.
gvlabtheme = OpenGVLab(primary_hue=colors.blue,
                       secondary_hue=colors.sky,
                       neutral_hue=colors.gray,
                       spacing_size=sizes.spacing_md,
                       radius_size=sizes.radius_sm,
                       text_size=sizes.text_md,
                       )
# HTML header: a centered logo image linking to the project's GitHub repository.
title = """<h1 align="center"><a href="https://github.com/OpenGVLab/VideoChat-Flash"><img src="https://s1.ax1x.com/2023/05/07/p9dBMOU.png" alt="VideoChat-Flash" border="0" style="margin: 0 auto; height: 100px;" /></a> </h1>"""
# Description shown under the header, with a badge linking to the repository.
description ="""
VideoChat-Flash-7B@448 powered by InternVideo!<br><p><a href='https://github.com/OpenGVLab/VideoChat-Flash'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>
"""
# Build the demo UI: video upload and generation settings on the left, the
# chat window on the right, example videos underneath, then the event wiring.
# NOTE(review): the original indentation was lost; the layout nesting below is
# a reconstruction and should be checked against the intended page layout.
with gr.Blocks(
    title="VideoChat-Flash",
    theme=gvlabtheme,
    # `visibility: hidden` is the valid CSS value for hiding the footer; the
    # previous `visibility: none` is invalid CSS and had no effect.
    css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}",
) as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Left column: video upload and generation settings.
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Column(elem_id="image", scale=0.5) as img_part:
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            restart = gr.Button("Restart")
            max_num_frames = gr.Slider(
                minimum=4,
                maximum=1024,
                value=512,
                step=4,
                interactive=True,
                label="Max Input Frames",
            )
            max_new_tokens = gr.Slider(
                minimum=1,
                maximum=4096,
                value=1024,
                step=1,
                interactive=True,
                label="Max Output Tokens",
            )
            check_asr = gr.Checkbox(label="Use ASR", info="Whether to extract speech using ASR.")
            check_do_sample = gr.Checkbox(label="Do Sample", info="Whether to do sample during decoding.")
            # The three sampling controls start hidden and are shown when
            # "Do Sample" is ticked (see toggle_slide).
            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                visible=False,
                label="beam search numbers)",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.1,
                visible=False,
                interactive=True, label="Top_P",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.1,
                step=0.1,
                visible=False,
                interactive=True, label="Temperature",
            )

            def toggle_slide(is_checked):
                """Show or hide the three sampling sliders together."""
                return gr.update(visible=is_checked), gr.update(visible=is_checked), gr.update(visible=is_checked)

            check_do_sample.select(fn=toggle_slide, inputs=check_do_sample, outputs=[num_beams, top_p, temperature])

        # Right column: conversation state, chat window and message controls.
        with gr.Column(visible=True) as input_raws:
            chat_state = gr.State([])
            asr_msg = gr.State()
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label='VideoChat',
                avatar_images=[
                    "human.jpg",  # user avatar
                    "assistant.png",  # assistant avatar
                ])
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send", interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear️", interactive=False)

    with gr.Row():
        # Each row supplies exactly one value per component in `inputs`
        # (video, prompt). A third, boolean column would be matched against
        # the text box and push the prompt out of range.
        examples = gr.Examples(
            examples=[
                ["demo_videos/basketball.mp4", "Describe this video in detail."],
                ["demo_videos/cup1.mp4", "Describe this video in detail."],
                ["demo_videos/dog.mp4", "Describe this video in detail."],
            ],
            inputs=[up_video, text_input],
            outputs=[run, clear, up_video, text_input, upload_button, asr_msg],
            fn=upload_video,
            run_on_click=True
        )

    # Component lists shared by several events; their order mirrors the
    # parameter / return order of the callbacks they are passed to.
    reset_outputs = [chatbot, up_video, text_input, run, clear, upload_button, chat_state, asr_msg]
    answer_inputs = [chatbot, text_input, up_video, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, check_do_sample, num_beams, top_p, temperature]

    up_video.clear(gradio_reset, None, reset_outputs, queue=False)
    upload_button.click(upload_video, [up_video], [run, clear, up_video, text_input, upload_button, asr_msg])

    # Pressing Enter and clicking Send run the same chain: record the
    # question, stream the answer, then empty the text box.
    text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, answer_inputs, [chatbot, chat_state]
    ).then(lambda: "", None, text_input)
    run.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, answer_inputs, [chatbot, chat_state]
    ).then(lambda: "", None, text_input)

    clear.click(clear_, None, [chatbot, chat_state])
    restart.click(gradio_reset, None, reset_outputs, queue=False)

demo.launch(server_name='0.0.0.0', server_port=7864)