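"""Gradio Space for GLM-4.1V-9B-Thinking: multimodal chat with streamed output.

The model emits <think>...</think> reasoning followed by <answer>...</answer>;
the UI renders the thinking as a collapsible block above the visible answer.
"""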
import gradio as gr
import torch
from transformers import AutoProcessor, Glm4vForConditionalGeneration, TextIteratorStreamer
from pathlib import Path
import threading
import re
import copy
import spaces
import fitz  # PyMuPDF, used to rasterize PDF pages
import subprocess
import tempfile
import os
import time

MODEL_PATH = "THUDM/GLM-4.1V-9B-Thinking"

# Cooperative stop flag shared between the UI callbacks and the streaming loop
stop_generation = False
processor = None
model = None

def load_model():
    """Load the model and processor."""
    global processor, model
    processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
    model = Glm4vForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="sdpa",
    )
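
# Wraps prompt construction, document-to-image conversion, and streamed
# generation; instantiated once below as `glm4v`.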
class GLM4VModel:
    def __init__(self):
        pass

    def _strip_html(self, t):
        return re.sub(r"<[^>]+>", "", t).strip()

    def _wrap_text(self, t):
        return [{"type": "text", "text": t}]

    def _pdf_to_imgs(self, pdf_path):
        # Rasterize each PDF page to a PNG at 180 dpi in the temp directory
        doc = fitz.open(pdf_path)
        imgs = []
        for i in range(doc.page_count):
            pix = doc.load_page(i).get_pixmap(dpi=180)
            img_p = os.path.join(tempfile.gettempdir(), f"{Path(pdf_path).stem}_{i}.png")
            pix.save(img_p)
            imgs.append(img_p)
        doc.close()
        return imgs
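
    # PPT/PPTX is converted to PDF with headless LibreOffice (the `libreoffice`
    # binary must be available on PATH), then rasterized page by page.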
    def _ppt_to_imgs(self, ppt_path):
        tmp = tempfile.mkdtemp()
        subprocess.run(
            ["libreoffice", "--headless", "--convert-to", "pdf", "--outdir", tmp, ppt_path],
            check=True,
        )
        pdf_path = os.path.join(tmp, Path(ppt_path).stem + ".pdf")
        return self._pdf_to_imgs(pdf_path)
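
    # Map uploaded files to chat-template content entries by extension;
    # documents become one image entry per page/slide.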
    def _files_to_content(self, media):
        out = []
        for f in media or []:
            ext = Path(f.name).suffix.lower()
            if ext in [".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".m4v"]:
                out.append({"type": "video", "url": f.name})
            elif ext in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"]:
                out.append({"type": "image", "url": f.name})
            elif ext in [".ppt", ".pptx"]:
                for p in self._ppt_to_imgs(f.name):
                    out.append({"type": "image", "url": p})
            elif ext == ".pdf":
                for p in self._pdf_to_imgs(f.name):
                    out.append({"type": "image", "url": p})
        return out
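
    # Incremental renderer for the streaming buffer: a complete or partial
    # <think> span becomes a collapsible HTML block, and the <answer> span
    # (or the whole buffer, if no tags are present yet) becomes the reply.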
    def _stream_fragment(self, buf: str) -> str:
        def _think_block(body: str) -> str:
            return (
                "<details open><summary style='cursor:pointer;font-weight:bold;color:#bbbbbb;'>💭 Thinking</summary>"
                "<div style='color:#cccccc;line-height:1.4;padding:10px;border-left:3px solid #666;margin:5px 0;background-color:rgba(128,128,128,0.1);'>"
                + body.replace("\n", "<br>")
                + "</div></details>"
            )

        think_html = ""
        if "<think>" in buf:
            if "</think>" in buf:
                seg = re.search(r"<think>(.*?)</think>", buf, re.DOTALL)
                if seg:
                    think_html = _think_block(seg.group(1).strip())
            else:
                # Tag not closed yet: render the partial thinking text
                think_html = _think_block(buf.split("<think>", 1)[1])

        answer_html = ""
        if "<answer>" in buf:
            if "</answer>" in buf:
                seg = re.search(r"<answer>(.*?)</answer>", buf, re.DOTALL)
                if seg:
                    answer_html = seg.group(1).strip()
            else:
                answer_html = buf.split("<answer>", 1)[1]

        if not think_html and not answer_html:
            return self._strip_html(buf)
        return think_html + answer_html
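
    # Rebuild the prompt from raw history; rendered thinking blocks and HTML
    # are stripped from earlier assistant turns so the model only sees the
    # final answers.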
    def _build_messages(self, raw_hist, sys_prompt):
        msgs = []
        if sys_prompt.strip():
            msgs.append({"role": "system", "content": [{"type": "text", "text": sys_prompt.strip()}]})
        for h in raw_hist:
            if h["role"] == "user":
                msgs.append({"role": "user", "content": h["content"]})
            else:
                raw = h["content"]
                raw = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL)
                raw = re.sub(r"<details.*?</details>", "", raw, flags=re.DOTALL)
                clean = self._strip_html(raw).strip()
                msgs.append({"role": "assistant", "content": self._wrap_text(clean)})
        return msgs
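
    # Streamed generation: model.generate runs on a worker thread while tokens
    # are consumed from the streamer and rendered incrementally.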
    @spaces.GPU  # assumed: ZeroGPU Spaces run GPU work under this decorator, matching the otherwise-unused `spaces` import
    def stream_generate(self, raw_hist, sys_prompt):
        global stop_generation, processor, model
        stop_generation = False
        msgs = self._build_messages(raw_hist, sys_prompt)
        inputs = processor.apply_chat_template(
            msgs,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
        ).to(model.device)
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=False)
        # Merge the tokenized inputs with the generation settings
        gen_args = dict(
            inputs,
            max_new_tokens=8192,
            repetition_penalty=1.1,
            do_sample=True,
            top_k=2,
            temperature=None,
            top_p=1e-5,  # so small that only the top token survives: effectively greedy decoding
            streamer=streamer,
        )
        generation_thread = threading.Thread(target=model.generate, kwargs=gen_args)
        generation_thread.start()
        buf = ""
        for tok in streamer:
            if stop_generation:
                break
            buf += tok
            yield self._stream_fragment(buf)
        generation_thread.join()
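
# Display helpers: the chatbot shows a compact view of each turn, while the
# full multimodal history is kept in gr.State and replayed to the model.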
def format_display_content(content):
    if isinstance(content, list):
        text_parts = []
        file_count = 0
        for item in content:
            if item["type"] == "text":
                text_parts.append(item["text"])
            else:
                file_count += 1
        display_text = " ".join(text_parts)
        if file_count > 0:
            return f"[{file_count} file(s) uploaded]\n{display_text}"
        return display_text
    return content

def create_display_history(raw_hist):
    display_hist = []
    for h in raw_hist:
        if h["role"] == "user":
            display_content = format_display_content(h["content"])
            display_hist.append({"role": "user", "content": display_content})
        else:
            display_hist.append({"role": "assistant", "content": h["content"]})
    return display_hist

# Load the model and processor once at startup
load_model()
glm4v = GLM4VModel()
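
# Upload limits: at most 10 images, or a single video/PPT/PDF, and documents,
# videos, and images cannot be mixed in one message.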
def check_files(files):
    vids = imgs = ppts = pdfs = 0
    for f in files or []:
        ext = Path(f.name).suffix.lower()
        if ext in [".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".m4v"]:
            vids += 1
        elif ext in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"]:
            imgs += 1
        elif ext in [".ppt", ".pptx"]:
            ppts += 1
        elif ext == ".pdf":
            pdfs += 1
    if vids > 1 or ppts > 1 or pdfs > 1:
        return False, "Only one video, one PPT, or one PDF is allowed"
    if imgs > 10:
        return False, "A maximum of 10 images is allowed"
    if (ppts or pdfs) and (vids or imgs) or (vids and imgs):
        return False, "Documents, videos, and images cannot be mixed"
    return True, ""
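
# Main chat handler: validate uploads, append the user turn, then stream the
# assistant reply into a placeholder message that is mutated in place.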
def chat(files, msg, raw_hist, sys_prompt):
    global stop_generation
    stop_generation = False
    if raw_hist is None:
        raw_hist = []
    ok, err = check_files(files)
    if not ok:
        raw_hist.append({"role": "assistant", "content": err})
        display_hist = create_display_history(raw_hist)
        yield display_hist, copy.deepcopy(raw_hist), None, ""
        return
    payload = glm4v._files_to_content(files) if files else None
    if msg.strip():
        if payload is None:
            payload = glm4v._wrap_text(msg.strip())
        else:
            payload.append({"type": "text", "text": msg.strip()})
    user_rec = {"role": "user", "content": payload if payload else msg.strip()}
    raw_hist.append(user_rec)
    place = {"role": "assistant", "content": ""}
    raw_hist.append(place)
    display_hist = create_display_history(raw_hist)
    yield display_hist, copy.deepcopy(raw_hist), None, ""
    # raw_hist[:-1] excludes the empty placeholder from the prompt
    for chunk in glm4v.stream_generate(raw_hist[:-1], sys_prompt):
        if stop_generation:
            break
        place["content"] = chunk
        display_hist = create_display_history(raw_hist)
        yield display_hist, copy.deepcopy(raw_hist), None, ""
    display_hist = create_display_history(raw_hist)
    yield display_hist, copy.deepcopy(raw_hist), None, ""
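
# Clear: signal any in-flight generation to stop, give the stream a moment to
# exit, then reset every bound component.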
def reset():
    global stop_generation
    stop_generation = True
    time.sleep(0.1)
    return [], [], None, ""


css = """.chatbot-container .message-wrap .message{font-size:14px!important}
details summary{cursor:pointer;font-weight:bold}
details[open] summary{margin-bottom:10px}"""

demo = gr.Blocks(title="GLM-4.1V Chat", theme=gr.themes.Soft(), css=css)
with demo:
    gr.Markdown("""
    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
        GLM-4.1V-9B-Thinking Gradio Space🤗
    </div>
    <div style="text-align: center;">
        <a href="https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking">🤗 Model Hub</a> |
        <a href="https://github.com/THUDM/GLM-4.1V-Thinking">🌐 Github</a>
    </div>
    """)
    raw_history = gr.State([])

    with gr.Row():
        with gr.Column(scale=7):
            chatbox = gr.Chatbot(
                label="Conversation",
                type="messages",
                height=600,
                elem_classes="chatbot-container",
            )
            textbox = gr.Textbox(label="💭 Message", lines=3)
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")
        with gr.Column(scale=3):
            up = gr.File(
                label="📁 Upload",
                file_count="multiple",
                file_types=["file"],
                type="filepath",
            )
            gr.Markdown("Supports images / videos / PPT / PDF")
            gr.Markdown(
                "Maximum input: 10 images, or 1 video/PPT/PDF. Videos and images cannot be used in the same conversation."
            )
            sys = gr.Textbox(label="⚙️ System Prompt", lines=6)

    send.click(chat, inputs=[up, textbox, raw_history, sys], outputs=[chatbox, raw_history, up, textbox])
    textbox.submit(chat, inputs=[up, textbox, raw_history, sys], outputs=[chatbox, raw_history, up, textbox])
    clear.click(reset, outputs=[chatbox, raw_history, up, textbox])

if __name__ == "__main__":
    demo.launch()