Commit 67ab16c
Duplicate from microsoft/HuggingGPT
Co-authored-by: Yongliang Shen <[email protected]>
- .gitattributes +34 -0
- .gitignore +5 -0
- README.md +13 -0
- app.py +202 -0
- awesome_chat.py +920 -0
- config.gradio.yaml +34 -0
- data/p0_models.jsonl +0 -0
- demos/demo_choose_model.json +10 -0
- demos/demo_parse_task.json +65 -0
- demos/demo_response_results.json +10 -0
- get_token_ids.py +49 -0
- models_server.py +618 -0
- packages.txt +1 -0
- public/examples/a.jpg +0 -0
- public/examples/b.jpg +0 -0
- public/examples/c.jpg +0 -0
- public/examples/d.jpg +0 -0
- public/examples/e.jpg +0 -0
- public/examples/f.jpg +0 -0
- public/examples/g.jpg +0 -0
- requirements.txt +18 -0
    	
.gitattributes ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
    	
.gitignore ADDED
@@ -0,0 +1,5 @@
+logs/
+models
+public/*
+*.pyc
+!public/examples
    	
README.md ADDED
@@ -0,0 +1,13 @@
+---
+title: HuggingGPT
+emoji: 😻
+colorFrom: gray
+colorTo: yellow
+sdk: gradio
+sdk_version: 3.24.1
+app_file: app.py
+pinned: false
+duplicated_from: microsoft/HuggingGPT
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
    	
app.py ADDED
@@ -0,0 +1,202 @@
+import uuid
+import gradio as gr
+import re
+from diffusers.utils import load_image
+import requests
+from awesome_chat import chat_huggingface
+import os
+
+os.makedirs("public/images", exist_ok=True)
+os.makedirs("public/audios", exist_ok=True)
+os.makedirs("public/videos", exist_ok=True)
+
+class Client:
+    def __init__(self) -> None:
+        self.OPENAI_KEY = ""
+        self.HUGGINGFACE_TOKEN = ""
+        self.all_messages = []
+
+    def set_key(self, openai_key):
+        self.OPENAI_KEY = openai_key
+        if len(self.HUGGINGFACE_TOKEN) > 0:
+            gr.update(visible=True)
+        return self.OPENAI_KEY
+
+    def set_token(self, huggingface_token):
+        self.HUGGINGFACE_TOKEN = huggingface_token
+        if len(self.OPENAI_KEY) > 0:
+            gr.update(visible=True)
+        return self.HUGGINGFACE_TOKEN
+
+    def add_message(self, content, role):
+        message = {"role": role, "content": content}
+        self.all_messages.append(message)
+
+    def extract_medias(self, message):
+        # url_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?")
+        urls = []
+        # for match in url_pattern.finditer(message):
+        #     if match.group(0) not in urls:
+        #         urls.append(match.group(0))
+
+        image_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)")
+        image_urls = []
+        for match in image_pattern.finditer(message):
+            if match.group(0) not in image_urls:
+                image_urls.append(match.group(0))
+
+        audio_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
+        audio_urls = []
+        for match in audio_pattern.finditer(message):
+            if match.group(0) not in audio_urls:
+                audio_urls.append(match.group(0))
+
+        video_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(mp4)")
+        video_urls = []
+        for match in video_pattern.finditer(message):
+            if match.group(0) not in video_urls:
+                video_urls.append(match.group(0))
+
+        return urls, image_urls, audio_urls, video_urls
+
+    def add_text(self, messages, message):
+        if len(self.OPENAI_KEY) == 0 or not self.OPENAI_KEY.startswith("sk-") or len(self.HUGGINGFACE_TOKEN) == 0 or not self.HUGGINGFACE_TOKEN.startswith("hf_"):
+            return messages, "Please set your OpenAI API key and Hugging Face token first!!!"
+        self.add_message(message, "user")
+        messages = messages + [(message, None)]
+        urls, image_urls, audio_urls, video_urls = self.extract_medias(message)
+
+        for image_url in image_urls:
+            if not image_url.startswith("http") and not image_url.startswith("public"):
+                image_url = "public/" + image_url
+            image = load_image(image_url)
+            name = f"public/images/{str(uuid.uuid4())[:4]}.jpg"
+            image.save(name)
+            messages = messages + [((f"{name}",), None)]
+        for audio_url in audio_urls:
+            if not audio_url.startswith("http") and not audio_url.startswith("public"):
+                audio_url = "public/" + audio_url
+            ext = audio_url.split(".")[-1]
+            name = f"public/audios/{str(uuid.uuid4())[:4]}.{ext}"
+            response = requests.get(audio_url)
+            with open(name, "wb") as f:
+                f.write(response.content)
+            messages = messages + [((f"{name}",), None)]
+        for video_url in video_urls:
+            if not video_url.startswith("http") and not video_url.startswith("public"):
+                video_url = "public/" + video_url
+            ext = video_url.split(".")[-1]
+            name = f"public/videos/{str(uuid.uuid4())[:4]}.{ext}"
+            response = requests.get(video_url)
+            with open(name, "wb") as f:
+                f.write(response.content)
+            messages = messages + [((f"{name}",), None)]
+        return messages, ""
+
+    def bot(self, messages):
+        if len(self.OPENAI_KEY) == 0 or not self.OPENAI_KEY.startswith("sk-") or len(self.HUGGINGFACE_TOKEN) == 0 or not self.HUGGINGFACE_TOKEN.startswith("hf_"):
+            return messages, {}
+        message, results = chat_huggingface(self.all_messages, self.OPENAI_KEY, self.HUGGINGFACE_TOKEN)
+        urls, image_urls, audio_urls, video_urls = self.extract_medias(message)
+        self.add_message(message, "assistant")
+        messages[-1][1] = message
+        for image_url in image_urls:
+            if not image_url.startswith("http"):
+                image_url = image_url.replace("public/", "")
+                messages = messages + [((None, (f"public/{image_url}",)))]
+            # else:
+            #     messages = messages + [((None, (f"{image_url}",)))]
+        for audio_url in audio_urls:
+            if not audio_url.startswith("http"):
+                audio_url = audio_url.replace("public/", "")
+                messages = messages + [((None, (f"public/{audio_url}",)))]
+            # else:
+            #     messages = messages + [((None, (f"{audio_url}",)))]
+        for video_url in video_urls:
+            if not video_url.startswith("http"):
+                video_url = video_url.replace("public/", "")
+                messages = messages + [((None, (f"public/{video_url}",)))]
+            # else:
+            #     messages = messages + [((None, (f"{video_url}",)))]
+        # replace int key to string key
+        results = {str(k): v for k, v in results.items()}
+        return messages, results
+
+css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
+with gr.Blocks(css=css) as demo:
+    state = gr.State(value={"client": Client()})
+    gr.Markdown("<h1><center>HuggingGPT</center></h1>")
+    gr.Markdown("<p align='center'><img src='https://i.ibb.co/qNH3Jym/logo.png' height='25' width='95'></p>")
+    gr.Markdown("<p align='center' style='font-size: 20px;'>A system to connect LLMs with ML community. See our <a href='https://github.com/microsoft/JARVIS'>Project</a> and <a href='http://arxiv.org/abs/2303.17580'>Paper</a>.</p>")
+    gr.HTML('''<center><a href="https://huggingface.co/spaces/microsoft/HuggingGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space and run securely with your OpenAI API Key and Hugging Face Token</center>''')
+    with gr.Row().style():
+        with gr.Column(scale=0.85):
+            openai_api_key = gr.Textbox(
+                show_label=False,
+                placeholder="Set your OpenAI API key here and press Enter",
+                lines=1,
+                type="password"
+            ).style(container=False)
+        with gr.Column(scale=0.15, min_width=0):
+            btn1 = gr.Button("Submit").style(full_height=True)
+
+    with gr.Row().style():
+        with gr.Column(scale=0.85):
+            hugging_face_token = gr.Textbox(
+                show_label=False,
+                placeholder="Set your Hugging Face Token here and press Enter",
+                lines=1,
+                type="password"
+            ).style(container=False)
+        with gr.Column(scale=0.15, min_width=0):
+            btn3 = gr.Button("Submit").style(full_height=True)
+
+    with gr.Row().style():
+        with gr.Column(scale=0.6):
+            chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
+        with gr.Column(scale=0.4):
+            results = gr.JSON(elem_classes="json")
+
+    with gr.Row().style():
+        with gr.Column(scale=0.85):
+            txt = gr.Textbox(
+                show_label=False,
+                placeholder="Enter text and press enter. The url of the multimedia resource must contain the extension name.",
+                lines=1,
+            ).style(container=False)
+        with gr.Column(scale=0.15, min_width=0):
+            btn2 = gr.Button("Send").style(full_height=True)
+
+    def set_key(state, openai_api_key):
+        return state["client"].set_key(openai_api_key)
+
+    def add_text(state, chatbot, txt):
+        return state["client"].add_text(chatbot, txt)
+
+    def set_token(state, hugging_face_token):
+        return state["client"].set_token(hugging_face_token)
+
+    def bot(state, chatbot):
+        return state["client"].bot(chatbot)
+
+    openai_api_key.submit(set_key, [state, openai_api_key], [openai_api_key])
+    txt.submit(add_text, [state, chatbot, txt], [chatbot, txt]).then(bot, [state, chatbot], [chatbot, results])
+    hugging_face_token.submit(set_token, [state, hugging_face_token], [hugging_face_token])
+    btn1.click(set_key, [state, openai_api_key], [openai_api_key])
+    btn2.click(add_text, [state, chatbot, txt], [chatbot, txt]).then(bot, [state, chatbot], [chatbot, results])
+    btn3.click(set_token, [state, hugging_face_token], [hugging_face_token])
+
+    gr.Examples(
+        examples=["Given a collection of image A: /examples/a.jpg, B: /examples/b.jpg, C: /examples/c.jpg, please tell me how many zebras in these picture?",
+                "Please generate a canny image based on /examples/f.jpg",
+                "show me a joke and an image of cat",
+                "what is in the examples/a.jpg",
+                "based on the /examples/a.jpg, please generate a video and audio",
+                "based on pose of /examples/d.jpg and content of /examples/e.jpg, please show me a new image",
+                ],
+        inputs=txt
+    )
+
+demo.launch()
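
For reference, below is a minimal standalone sketch (not part of the commit) of the extension-based media extraction that Client.extract_medias performs in app.py above. It reuses the same regular expressions and de-duplication loop; the sample message in the __main__ block is hypothetical and only illustrates the matching behaviour.

import re

# Same extension-based patterns as Client.extract_medias in app.py above.
IMAGE_RE = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)")
AUDIO_RE = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
VIDEO_RE = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(mp4)")

def extract_medias(message):
    # Collect de-duplicated matches for each media type, preserving order of appearance.
    def find_all(pattern):
        found = []
        for match in pattern.finditer(message):
            if match.group(0) not in found:
                found.append(match.group(0))
        return found
    return find_all(IMAGE_RE), find_all(AUDIO_RE), find_all(VIDEO_RE)

if __name__ == "__main__":
    # Hypothetical chat message mixing local example images and an audio file.
    msg = "compare /examples/a.jpg and /examples/b.jpg, then play /examples/test.wav"
    print(extract_medias(msg))
    # (['/examples/a.jpg', '/examples/b.jpg'], ['/examples/test.wav'], [])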
    	
awesome_chat.py ADDED
@@ -0,0 +1,920 @@
| 1 | 
            +
            import base64
         | 
| 2 | 
            +
            import copy
         | 
| 3 | 
            +
            from io import BytesIO
         | 
| 4 | 
            +
            import io
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            import random
         | 
| 7 | 
            +
            import time
         | 
| 8 | 
            +
            import traceback
         | 
| 9 | 
            +
            import uuid
         | 
| 10 | 
            +
            import requests
         | 
| 11 | 
            +
            import re
         | 
| 12 | 
            +
            import json
         | 
| 13 | 
            +
            import logging
         | 
| 14 | 
            +
            import argparse
         | 
| 15 | 
            +
            import yaml
         | 
| 16 | 
            +
            from PIL import Image, ImageDraw
         | 
| 17 | 
            +
            from diffusers.utils import load_image
         | 
| 18 | 
            +
            from pydub import AudioSegment
         | 
| 19 | 
            +
            import threading
         | 
| 20 | 
            +
            from queue import Queue
         | 
| 21 | 
            +
            from get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length
         | 
| 22 | 
            +
            from huggingface_hub.inference_api import InferenceApi
         | 
| 23 | 
            +
            from huggingface_hub.inference_api import ALL_TASKS
         | 
| 24 | 
            +
            from models_server import models, status
         | 
| 25 | 
            +
            from functools import partial
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            parser = argparse.ArgumentParser()
         | 
| 28 | 
            +
            parser.add_argument("--config", type=str, default="config.yaml.dev")
         | 
| 29 | 
            +
            parser.add_argument("--mode", type=str, default="cli")
         | 
| 30 | 
            +
            args = parser.parse_args()
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            if __name__ != "__main__":
         | 
| 33 | 
            +
                args.config = "config.gradio.yaml"
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            if not os.path.exists("logs"):
         | 
| 38 | 
            +
                os.mkdir("logs")
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 41 | 
            +
            logger.setLevel(logging.DEBUG)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            handler = logging.StreamHandler()
         | 
| 44 | 
            +
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
         | 
| 45 | 
            +
            handler.setFormatter(formatter)
         | 
| 46 | 
            +
            if not config["debug"]:
         | 
| 47 | 
            +
                handler.setLevel(logging.INFO)
         | 
| 48 | 
            +
            logger.addHandler(handler)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            log_file = config["log_file"]
         | 
| 51 | 
            +
            if log_file:
         | 
| 52 | 
            +
                filehandler = logging.FileHandler(log_file)
         | 
| 53 | 
            +
                filehandler.setLevel(logging.DEBUG)
         | 
| 54 | 
            +
                filehandler.setFormatter(formatter)
         | 
| 55 | 
            +
                logger.addHandler(filehandler)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            LLM = config["model"]
         | 
| 58 | 
            +
            use_completion = config["use_completion"]
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            # consistent: wrong msra model name 
         | 
| 61 | 
            +
            LLM_encoding = LLM
         | 
| 62 | 
            +
            if LLM == "gpt-3.5-turbo":
         | 
| 63 | 
            +
                LLM_encoding = "text-davinci-003"
         | 
| 64 | 
            +
            task_parsing_highlight_ids = get_token_ids_for_task_parsing(LLM_encoding)
         | 
| 65 | 
            +
            choose_model_highlight_ids = get_token_ids_for_choose_model(LLM_encoding)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            # ENDPOINT	MODEL NAME	
         | 
| 68 | 
            +
            # /v1/chat/completions	gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301	
         | 
| 69 | 
            +
            # /v1/completions	text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001, text-ada-001, davinci, curie, babbage, ada
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            if use_completion:
         | 
| 72 | 
            +
                api_name = "completions"
         | 
| 73 | 
            +
            else:
         | 
| 74 | 
            +
                api_name = "chat/completions"
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            if not config["dev"]:
         | 
| 77 | 
            +
                if not config["openai"]["key"].startswith("sk-") and not config["openai"]["key"]=="gradio":
         | 
| 78 | 
            +
                    raise ValueError("Incrorrect OpenAI key. Please check your config.yaml file.")
         | 
| 79 | 
            +
                OPENAI_KEY = config["openai"]["key"]
         | 
| 80 | 
            +
                endpoint = f"https://api.openai.com/v1/{api_name}"
         | 
| 81 | 
            +
                if OPENAI_KEY.startswith("sk-"):
         | 
| 82 | 
            +
                    HEADER = {
         | 
| 83 | 
            +
                        "Authorization": f"Bearer {OPENAI_KEY}"
         | 
| 84 | 
            +
                    }
         | 
| 85 | 
            +
                else:
         | 
| 86 | 
            +
                    HEADER = None
         | 
| 87 | 
            +
            else:
         | 
| 88 | 
            +
                endpoint = f"{config['local']['endpoint']}/v1/{api_name}"
         | 
| 89 | 
            +
                HEADER = None
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            PROXY = None
         | 
| 92 | 
            +
            if config["proxy"]:
         | 
| 93 | 
            +
                PROXY = {
         | 
| 94 | 
            +
                    "https": config["proxy"],
         | 
| 95 | 
            +
                }
         | 
| 96 | 
            +
             | 
| 97 | 
            +
            inference_mode = config["inference_mode"]
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            parse_task_demos_or_presteps = open(config["demos_or_presteps"]["parse_task"], "r").read()
         | 
| 100 | 
            +
            choose_model_demos_or_presteps = open(config["demos_or_presteps"]["choose_model"], "r").read()
         | 
| 101 | 
            +
            response_results_demos_or_presteps = open(config["demos_or_presteps"]["response_results"], "r").read()
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            parse_task_prompt = config["prompt"]["parse_task"]
         | 
| 104 | 
            +
            choose_model_prompt = config["prompt"]["choose_model"]
         | 
| 105 | 
            +
            response_results_prompt = config["prompt"]["response_results"]
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            parse_task_tprompt = config["tprompt"]["parse_task"]
         | 
| 108 | 
            +
            choose_model_tprompt = config["tprompt"]["choose_model"]
         | 
| 109 | 
            +
            response_results_tprompt = config["tprompt"]["response_results"]
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            MODELS = [json.loads(line) for line in open("data/p0_models.jsonl", "r").readlines()]
         | 
| 112 | 
            +
            MODELS_MAP = {}
         | 
| 113 | 
            +
            for model in MODELS:
         | 
| 114 | 
            +
                tag = model["task"]
         | 
| 115 | 
            +
                if tag not in MODELS_MAP:
         | 
| 116 | 
            +
                    MODELS_MAP[tag] = []
         | 
| 117 | 
            +
                MODELS_MAP[tag].append(model)
         | 
| 118 | 
            +
            METADATAS = {}
         | 
| 119 | 
            +
            for model in MODELS:
         | 
| 120 | 
            +
                METADATAS[model["id"]] = model
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            def convert_chat_to_completion(data):
         | 
| 123 | 
            +
                messages = data.pop('messages', [])
         | 
| 124 | 
            +
                tprompt = ""
         | 
| 125 | 
            +
                if messages[0]['role'] == "system":
         | 
| 126 | 
            +
                    tprompt = messages[0]['content']
         | 
| 127 | 
            +
                    messages = messages[1:]
         | 
| 128 | 
            +
                final_prompt = ""
         | 
| 129 | 
            +
                for message in messages:
         | 
| 130 | 
            +
                    if message['role'] == "user":
         | 
| 131 | 
            +
                        final_prompt += ("<im_start>"+ "user" + "\n" + message['content'] + "<im_end>\n")
         | 
| 132 | 
            +
                    elif message['role'] == "assistant":
         | 
| 133 | 
            +
                        final_prompt += ("<im_start>"+ "assistant" + "\n" + message['content'] + "<im_end>\n")
         | 
| 134 | 
            +
                    else:
         | 
| 135 | 
            +
                        final_prompt += ("<im_start>"+ "system" + "\n" + message['content'] + "<im_end>\n")
         | 
| 136 | 
            +
                final_prompt = tprompt + final_prompt
         | 
| 137 | 
            +
                final_prompt = final_prompt + "<im_start>assistant"
         | 
| 138 | 
            +
                data["prompt"] = final_prompt
         | 
| 139 | 
            +
                data['stop'] = data.get('stop', ["<im_end>"])
         | 
| 140 | 
            +
                data['max_tokens'] = data.get('max_tokens', max(get_max_context_length(LLM) - count_tokens(LLM_encoding, final_prompt), 1))
         | 
| 141 | 
            +
                return data
         | 
| 142 | 
            +
             | 
| 143 | 
            +
            def send_request(data):
         | 
| 144 | 
            +
                global HEADER
         | 
| 145 | 
            +
                openaikey = data.pop("openaikey")
         | 
| 146 | 
            +
                if use_completion:
         | 
| 147 | 
            +
                    data = convert_chat_to_completion(data)
         | 
| 148 | 
            +
                if openaikey and openaikey.startswith("sk-"):
         | 
| 149 | 
            +
                    HEADER = {
         | 
| 150 | 
            +
                        "Authorization": f"Bearer {openaikey}"
         | 
| 151 | 
            +
                    }
         | 
| 152 | 
            +
                
         | 
| 153 | 
            +
                response = requests.post(endpoint, json=data, headers=HEADER, proxies=PROXY)
         | 
| 154 | 
            +
                logger.debug(response.text.strip())
         | 
| 155 | 
            +
                if "choices" not in response.json():
         | 
| 156 | 
            +
                    return response.json()
         | 
| 157 | 
            +
                if use_completion:
         | 
| 158 | 
            +
                    return response.json()["choices"][0]["text"].strip()
         | 
| 159 | 
            +
                else:
         | 
| 160 | 
            +
                    return response.json()["choices"][0]["message"]["content"].strip()
         | 
| 161 | 
            +
             | 
| 162 | 
            +
            def replace_slot(text, entries):
         | 
| 163 | 
            +
                for key, value in entries.items():
         | 
| 164 | 
            +
                    if not isinstance(value, str):
         | 
| 165 | 
            +
                        value = str(value)
         | 
| 166 | 
            +
                    text = text.replace("{{" + key +"}}", value.replace('"', "'").replace('\n', ""))
         | 
| 167 | 
            +
                return text
         | 
| 168 | 
            +
             | 
| 169 | 
            +
            def find_json(s):
         | 
| 170 | 
            +
                s = s.replace("\'", "\"")
         | 
| 171 | 
            +
                start = s.find("{")
         | 
| 172 | 
            +
                end = s.rfind("}")
         | 
| 173 | 
            +
                res = s[start:end+1]
         | 
| 174 | 
            +
                res = res.replace("\n", "")
         | 
| 175 | 
            +
                return res
         | 
| 176 | 
            +
             | 
| 177 | 
            +
            def field_extract(s, field):
         | 
| 178 | 
            +
                try:
         | 
| 179 | 
            +
                    field_rep = re.compile(f'{field}.*?:.*?"(.*?)"', re.IGNORECASE)
         | 
| 180 | 
            +
                    extracted = field_rep.search(s).group(1).replace("\"", "\'")
         | 
| 181 | 
            +
                except:
         | 
| 182 | 
            +
                    field_rep = re.compile(f'{field}:\ *"(.*?)"', re.IGNORECASE)
         | 
| 183 | 
            +
                    extracted = field_rep.search(s).group(1).replace("\"", "\'")
         | 
| 184 | 
            +
                return extracted
         | 
| 185 | 
            +
             | 
| 186 | 
            +
            def get_id_reason(choose_str):
         | 
| 187 | 
            +
                reason = field_extract(choose_str, "reason")
         | 
| 188 | 
            +
                id = field_extract(choose_str, "id")
         | 
| 189 | 
            +
                choose = {"id": id, "reason": reason}
         | 
| 190 | 
            +
                return id.strip(), reason.strip(), choose
         | 
| 191 | 
            +
             | 
| 192 | 
            +
            def record_case(success, **args):
         | 
| 193 | 
            +
                if success:
         | 
| 194 | 
            +
                    f = open("logs/log_success.jsonl", "a")
         | 
| 195 | 
            +
                else:
         | 
| 196 | 
            +
                    f = open("logs/log_fail.jsonl", "a")
         | 
| 197 | 
            +
                log = args
         | 
| 198 | 
            +
                f.write(json.dumps(log) + "\n")
         | 
| 199 | 
            +
                f.close()
         | 
| 200 | 
            +
             | 
| 201 | 
            +
            def image_to_bytes(img_url):
         | 
| 202 | 
            +
                img_byte = io.BytesIO()
         | 
| 203 | 
            +
                type = img_url.split(".")[-1]
         | 
| 204 | 
            +
                load_image(img_url).save(img_byte, format="png")
         | 
| 205 | 
            +
                img_data = img_byte.getvalue()
         | 
| 206 | 
            +
                return img_data
         | 
| 207 | 
            +
             | 
| 208 | 
            +
            def resource_has_dep(command):
         | 
| 209 | 
            +
                args = command["args"]
         | 
| 210 | 
            +
                for _, v in args.items():
         | 
| 211 | 
            +
                    if "<GENERATED>" in v:
         | 
| 212 | 
            +
                        return True
         | 
| 213 | 
            +
                return False
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            def fix_dep(tasks):
         | 
| 216 | 
            +
                for task in tasks:
         | 
| 217 | 
            +
                    args = task["args"]
         | 
| 218 | 
            +
                    task["dep"] = []
         | 
| 219 | 
            +
                    for k, v in args.items():
         | 
| 220 | 
            +
                        if "<GENERATED>" in v:
         | 
| 221 | 
            +
                            dep_task_id = int(v.split("-")[1])
         | 
| 222 | 
            +
                            if dep_task_id not in task["dep"]:
         | 
| 223 | 
            +
                                task["dep"].append(dep_task_id)
         | 
| 224 | 
            +
                    if len(task["dep"]) == 0:
         | 
| 225 | 
            +
                        task["dep"] = [-1]
         | 
| 226 | 
            +
                return tasks
         | 
| 227 | 
            +
             | 
| 228 | 
            +
            def unfold(tasks):
         | 
| 229 | 
            +
                flag_unfold_task = False
         | 
| 230 | 
            +
                try:
         | 
| 231 | 
            +
                    for task in tasks:
         | 
| 232 | 
            +
                        for key, value in task["args"].items():
         | 
| 233 | 
            +
                            if "<GENERATED>" in value:
         | 
| 234 | 
            +
                                generated_items = value.split(",")
         | 
| 235 | 
            +
                                if len(generated_items) > 1:
         | 
| 236 | 
            +
                                    flag_unfold_task = True
         | 
| 237 | 
            +
                                    for item in generated_items:
         | 
| 238 | 
            +
                                        new_task = copy.deepcopy(task)
         | 
| 239 | 
            +
                                        dep_task_id = int(item.split("-")[1])
         | 
| 240 | 
            +
                                        new_task["dep"] = [dep_task_id]
         | 
| 241 | 
            +
                                        new_task["args"][key] = item
         | 
| 242 | 
            +
                                        tasks.append(new_task)
         | 
| 243 | 
            +
                                    tasks.remove(task)
         | 
| 244 | 
            +
                except Exception as e:
         | 
| 245 | 
            +
                    print(e)
         | 
| 246 | 
            +
                    traceback.print_exc()
         | 
| 247 | 
            +
                    logger.debug("unfold task failed.")
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                if flag_unfold_task:
         | 
| 250 | 
            +
                    logger.debug(f"unfold tasks: {tasks}")
         | 
| 251 | 
            +
                    
         | 
| 252 | 
            +
                return tasks
         | 
| 253 | 
            +
             | 
| 254 | 
            +
            def chitchat(messages, openaikey=None):
         | 
| 255 | 
            +
                data = {
         | 
| 256 | 
            +
                    "model": LLM,
         | 
| 257 | 
            +
                    "messages": messages,
         | 
| 258 | 
            +
                    "openaikey": openaikey
         | 
| 259 | 
            +
                }
         | 
| 260 | 
            +
                return send_request(data)
         | 
| 261 | 
            +
             | 
| 262 | 
            +
            def parse_task(context, input, openaikey=None):
         | 
| 263 | 
            +
                demos_or_presteps = parse_task_demos_or_presteps
         | 
| 264 | 
            +
                messages = json.loads(demos_or_presteps)
         | 
| 265 | 
            +
                messages.insert(0, {"role": "system", "content": parse_task_tprompt})
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                # cut chat logs
         | 
| 268 | 
            +
                start = 0
         | 
| 269 | 
            +
                while start <= len(context):
         | 
| 270 | 
            +
                    history = context[start:]
         | 
| 271 | 
            +
                    prompt = replace_slot(parse_task_prompt, {
         | 
| 272 | 
            +
                        "input": input,
         | 
| 273 | 
            +
                        "context": history 
         | 
| 274 | 
            +
                    })
         | 
| 275 | 
            +
                    messages.append({"role": "user", "content": prompt})
         | 
| 276 | 
            +
                    history_text = "<im_end>\nuser<im_start>".join([m["content"] for m in messages])
         | 
| 277 | 
            +
                    num = count_tokens(LLM_encoding, history_text)
         | 
| 278 | 
            +
                    if get_max_context_length(LLM) - num > 800:
         | 
| 279 | 
            +
                        break
         | 
| 280 | 
            +
                    messages.pop()
         | 
| 281 | 
            +
                    start += 2
         | 
| 282 | 
            +
                
         | 
| 283 | 
            +
                logger.debug(messages)
         | 
| 284 | 
            +
                data = {
         | 
| 285 | 
            +
                    "model": LLM,
         | 
| 286 | 
            +
                    "messages": messages,
         | 
| 287 | 
            +
                    "temperature": 0,
         | 
| 288 | 
            +
                    "logit_bias": {item: config["logit_bias"]["parse_task"] for item in task_parsing_highlight_ids},
         | 
| 289 | 
            +
                    "openaikey": openaikey
         | 
| 290 | 
            +
                }
         | 
| 291 | 
            +
                return send_request(data)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
def choose_model(input, task, metas, openaikey = None):
    prompt = replace_slot(choose_model_prompt, {
        "input": input,
        "task": task,
        "metas": metas,
    })
    demos_or_presteps = replace_slot(choose_model_demos_or_presteps, {
        "input": input,
        "task": task,
        "metas": metas
    })
    messages = json.loads(demos_or_presteps)
    messages.insert(0, {"role": "system", "content": choose_model_tprompt})
    messages.append({"role": "user", "content": prompt})
    logger.debug(messages)
    data = {
        "model": LLM,
        "messages": messages,
        "temperature": 0,
        "logit_bias": {item: config["logit_bias"]["choose_model"] for item in choose_model_highlight_ids}, # 5
        "openaikey": openaikey
    }
    return send_request(data)

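# Response generation stage: summarize the collected task results into a final answer for the user.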
def response_results(input, results, openaikey=None):
    results = [v for k, v in sorted(results.items(), key=lambda item: item[0])]
    prompt = replace_slot(response_results_prompt, {
        "input": input,
    })
    demos_or_presteps = replace_slot(response_results_demos_or_presteps, {
        "input": input,
        "processes": results
    })
    messages = json.loads(demos_or_presteps)
    messages.insert(0, {"role": "system", "content": response_results_tprompt})
    messages.append({"role": "user", "content": prompt})
    logger.debug(messages)
    data = {
        "model": LLM,
        "messages": messages,
        "temperature": 0,
        "openaikey": openaikey
    }
    return send_request(data)

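# Run inference on the Hugging Face Inference API, dispatching on the task type (NLP, CV, audio).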
def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
    if huggingfacetoken is None:
        HUGGINGFACE_HEADERS = {}
    else:
        HUGGINGFACE_HEADERS = {
            "Authorization": f"Bearer {huggingfacetoken}",
        }
    task_url = f"https://api-inference.huggingface.co/models/{model_id}"  # InferenceApi does not yet support some tasks
    inference = InferenceApi(repo_id=model_id, token=huggingfacetoken)

    # NLP tasks
    if task == "question-answering":
        inputs = {"question": data["text"], "context": (data["context"] if "context" in data else "")}
        result = inference(inputs)
    if task == "sentence-similarity":
        inputs = {"source_sentence": data["text1"], "target_sentence": data["text2"]}
        result = inference(inputs)
    if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
        inputs = data["text"]
        result = inference(inputs)

    # CV tasks
    if task == "visual-question-answering" or task == "document-question-answering":
        img_url = data["image"]
        text = data["text"]
        img_data = image_to_bytes(img_url)
        img_base64 = base64.b64encode(img_data).decode("utf-8")
        json_data = {}
        json_data["inputs"] = {}
        json_data["inputs"]["question"] = text
        json_data["inputs"]["image"] = img_base64
        result = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json=json_data).json()
        # result = inference(inputs) # not support

    if task == "image-to-image":
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        # result = inference(data=img_data) # not support
        HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
        r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
        result = r.json()
        if "path" in result:
            result["generated image"] = result.pop("path")

    if task == "text-to-image":
        inputs = data["text"]
        img = inference(inputs)
        name = str(uuid.uuid4())[:4]
        img.save(f"public/images/{name}.png")
        result = {}
        result["generated image"] = f"/images/{name}.png"

    if task == "image-segmentation":
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        image = Image.open(BytesIO(img_data))
        predicted = inference(data=img_data)
        colors = []
        for i in range(len(predicted)):
            colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 155))
        for i, pred in enumerate(predicted):
            label = pred["label"]
            mask = pred.pop("mask").encode("utf-8")
            mask = base64.b64decode(mask)
            mask = Image.open(BytesIO(mask), mode='r')
            mask = mask.convert('L')

            layer = Image.new('RGBA', mask.size, colors[i])
            image.paste(layer, (0, 0), mask)
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        result = {}
        result["generated image with segmentation mask"] = f"/images/{name}.jpg"
        result["predicted"] = predicted

    if task == "object-detection":
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        predicted = inference(data=img_data)
        image = Image.open(BytesIO(img_data))
        draw = ImageDraw.Draw(image)
        labels = list(item['label'] for item in predicted)
        color_map = {}
        for label in labels:
            if label not in color_map:
                color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
        for label in predicted:
            box = label["box"]
            draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2)
            draw.text((box["xmin"]+5, box["ymin"]-15), label["label"], fill=color_map[label["label"]])
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        result = {}
        result["generated image with predicted box"] = f"/images/{name}.jpg"
        result["predicted"] = predicted

    if task in ["image-classification"]:
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        result = inference(data=img_data)

    if task == "image-to-text":
        img_url = data["image"]
        img_data = image_to_bytes(img_url)
        HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
        r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data)
        result = {}
        if "generated_text" in r.json()[0]:
            result["generated text"] = r.json()[0].pop("generated_text")

    # AUDIO tasks
    if task == "text-to-speech":
        inputs = data["text"]
        response = inference(inputs, raw_response=True)
        # response = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json={"inputs": text})
        name = str(uuid.uuid4())[:4]
        with open(f"public/audios/{name}.flac", "wb") as f:
            f.write(response.content)
        result = {"generated audio": f"/audios/{name}.flac"}
    if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
        audio_url = data["audio"]
        audio_data = requests.get(audio_url, timeout=10).content
        response = inference(data=audio_data, raw_response=True)
        result = response.json()
        if task == "audio-to-audio":
            content = None
            type = None
            for k, v in result[0].items():
                if k == "blob":
                    content = base64.b64decode(v.encode("utf-8"))
                if k == "content-type":
                    type = "audio/flac".split("/")[-1]
            audio = AudioSegment.from_file(BytesIO(content))
            name = str(uuid.uuid4())[:4]
            audio.export(f"public/audios/{name}.{type}", format=type)
            result = {"generated audio": f"/audios/{name}.{type}"}
    return result

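# Run inference with locally deployed models via the `models` handle, dispatching on model id and task type.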
def local_model_inference(model_id, data, task):
    inference = partial(models, model_id)
    # controlnet
    if model_id.startswith("lllyasviel/sd-controlnet-"):
        img_url = data["image"]
        text = data["text"]
        results = inference({"img_url": img_url, "text": text})
        if "path" in results:
            results["generated image"] = results.pop("path")
        return results
    if model_id.endswith("-control"):
        img_url = data["image"]
        results = inference({"img_url": img_url})
        if "path" in results:
            results["generated image"] = results.pop("path")
        return results

    if task == "text-to-video":
        results = inference(data)
        if "path" in results:
            results["generated video"] = results.pop("path")
        return results

    # NLP tasks
    if task == "question-answering" or task == "sentence-similarity":
        results = inference(json=data)
        return results
    if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]:
        results = inference(json=data)
        return results

    # CV tasks
    if task == "depth-estimation":
        img_url = data["image"]
        results = inference({"img_url": img_url})
        if "path" in results:
            results["generated depth image"] = results.pop("path")
        return results
    if task == "image-segmentation":
        img_url = data["image"]
        results = inference({"img_url": img_url})
        results["generated image with segmentation mask"] = results.pop("path")
        return results
    if task == "image-to-image":
        img_url = data["image"]
        results = inference({"img_url": img_url})
        if "path" in results:
            results["generated image"] = results.pop("path")
        return results
    if task == "text-to-image":
        results = inference(data)
        if "path" in results:
            results["generated image"] = results.pop("path")
        return results
    if task == "object-detection":
        img_url = data["image"]
        predicted = inference({"img_url": img_url})
        if "error" in predicted:
            return predicted
        image = load_image(img_url)
        draw = ImageDraw.Draw(image)
        labels = list(item['label'] for item in predicted)
        color_map = {}
        for label in labels:
            if label not in color_map:
                color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
        for label in predicted:
            box = label["box"]
            draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2)
            draw.text((box["xmin"]+5, box["ymin"]-15), label["label"], fill=color_map[label["label"]])
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        results = {}
        results["generated image with predicted box"] = f"/images/{name}.jpg"
        results["predicted"] = predicted
        return results
    if task in ["image-classification", "image-to-text", "document-question-answering", "visual-question-answering"]:
        img_url = data["image"]
        text = None
        if "text" in data:
            text = data["text"]
        results = inference({"img_url": img_url, "text": text})
        return results
    # AUDIO tasks
    if task == "text-to-speech":
        results = inference(data)
        if "path" in results:
            results["generated audio"] = results.pop("path")
        return results
    if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
        audio_url = data["audio"]
        results = inference({"audio_url": audio_url})
        return results

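# Dispatch a task to a local or Hugging Face endpoint, probing endpoint status when the host is unknown.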
def model_inference(model_id, data, hosted_on, task, huggingfacetoken=None):
    if huggingfacetoken:
        HUGGINGFACE_HEADERS = {
            "Authorization": f"Bearer {huggingfacetoken}",
        }
    else:
        HUGGINGFACE_HEADERS = None
    if hosted_on == "unknown":
        r = status(model_id)
        logger.debug("Local Server Status: " + str(r))
        if "loaded" in r and r["loaded"]:
            hosted_on = "local"
        else:
            huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            r = requests.get(huggingfaceStatusUrl, headers=HUGGINGFACE_HEADERS, proxies=PROXY).json()
            logger.debug("Huggingface Status: " + str(r))
            if "loaded" in r and r["loaded"]:
                hosted_on = "huggingface"
    try:
        if hosted_on == "local":
            inference_result = local_model_inference(model_id, data, task)
        elif hosted_on == "huggingface":
            inference_result = huggingface_model_inference(model_id, data, task, huggingfacetoken)
    except Exception as e:
        print(e)
        traceback.print_exc()
        inference_result = {"error": {"message": str(e)}}
    return inference_result

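# Check whether a model is loaded on a given endpoint; optionally report the result through a queue.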
def get_model_status(model_id, url, headers, queue = None):
    endpoint_type = "huggingface" if "huggingface" in url else "local"
    if "huggingface" in url:
        r = requests.get(url, headers=headers, proxies=PROXY).json()
    else:
        r = status(model_id)
    if "loaded" in r and r["loaded"]:
        if queue:
            queue.put((model_id, True, endpoint_type))
        return True
    else:
        if queue:
            queue.put((model_id, False, None))
        return False

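# Probe candidate models in parallel threads and collect up to topk that are available locally or on the Inference API.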
def get_avaliable_models(candidates, topk=10, huggingfacetoken = None):
    all_available_models = {"local": [], "huggingface": []}
    threads = []
    result_queue = Queue()
    HUGGINGFACE_HEADERS = {
        "Authorization": f"Bearer {huggingfacetoken}",
    }
    for candidate in candidates:
        model_id = candidate["id"]

        if inference_mode != "local":
            huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            thread = threading.Thread(target=get_model_status, args=(model_id, huggingfaceStatusUrl, HUGGINGFACE_HEADERS, result_queue))
            threads.append(thread)
            thread.start()

        if inference_mode != "huggingface" and config["local_deployment"] != "minimal":
            thread = threading.Thread(target=get_model_status, args=(model_id, "", {}, result_queue))
            threads.append(thread)
            thread.start()

    result_count = len(threads)
    while result_count:
        model_id, status, endpoint_type = result_queue.get()
        if status and model_id not in all_available_models:
            all_available_models[endpoint_type].append(model_id)
        if len(all_available_models["local"] + all_available_models["huggingface"]) >= topk:
            break
        result_count -= 1

    for thread in threads:
        thread.join()

    return all_available_models

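# Package a task's command, chosen model, and inference output into a single result record.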
def collect_result(command, choose, inference_result):
    result = {"task": command}
    result["inference result"] = inference_result
    result["choose model result"] = choose
    logger.debug(f"inference result: {inference_result}")
    return result

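# Execute a single parsed task: resolve <GENERATED> resources from dependency tasks, select a model, and run inference.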
def run_task(input, command, results, openaikey = None, huggingfacetoken = None):
    id = command["id"]
    args = command["args"]
    task = command["task"]
    deps = command["dep"]
    if deps[0] != -1:
        dep_tasks = [results[dep] for dep in deps]
    else:
        dep_tasks = []

    logger.debug(f"Run task: {id} - {task}")
    logger.debug("Deps: " + json.dumps(dep_tasks))

    if deps[0] != -1:
        if "image" in args and "<GENERATED>-" in args["image"]:
            resource_id = int(args["image"].split("-")[1])
            if "generated image" in results[resource_id]["inference result"]:
                args["image"] = results[resource_id]["inference result"]["generated image"]
        if "audio" in args and "<GENERATED>-" in args["audio"]:
            resource_id = int(args["audio"].split("-")[1])
            if "generated audio" in results[resource_id]["inference result"]:
                args["audio"] = results[resource_id]["inference result"]["generated audio"]
        if "text" in args and "<GENERATED>-" in args["text"]:
            resource_id = int(args["text"].split("-")[1])
            if "generated text" in results[resource_id]["inference result"]:
                args["text"] = results[resource_id]["inference result"]["generated text"]

    text = image = audio = None
    for dep_task in dep_tasks:
        if "generated text" in dep_task["inference result"]:
            text = dep_task["inference result"]["generated text"]
            logger.debug("Detect the generated text of dependency task (from results):" + text)
        elif "text" in dep_task["task"]["args"]:
            text = dep_task["task"]["args"]["text"]
            logger.debug("Detect the text of dependency task (from args): " + text)
        if "generated image" in dep_task["inference result"]:
            image = dep_task["inference result"]["generated image"]
            logger.debug("Detect the generated image of dependency task (from results): " + image)
        elif "image" in dep_task["task"]["args"]:
            image = dep_task["task"]["args"]["image"]
            logger.debug("Detect the image of dependency task (from args): " + image)
        if "generated audio" in dep_task["inference result"]:
            audio = dep_task["inference result"]["generated audio"]
            logger.debug("Detect the generated audio of dependency task (from results): " + audio)
        elif "audio" in dep_task["task"]["args"]:
            audio = dep_task["task"]["args"]["audio"]
            logger.debug("Detect the audio of dependency task (from args): " + audio)

    if "image" in args and "<GENERATED>" in args["image"]:
        if image:
            args["image"] = image
    if "audio" in args and "<GENERATED>" in args["audio"]:
        if audio:
            args["audio"] = audio
    if "text" in args and "<GENERATED>" in args["text"]:
        if text:
            args["text"] = text

    for resource in ["image", "audio"]:
        if resource in args and not args[resource].startswith("public/") and len(args[resource]) > 0 and not args[resource].startswith("http"):
            args[resource] = f"public/{args[resource]}"

    if "-text-to-image" in command['task'] and "text" not in args:
        logger.debug("control-text-to-image task, but text is empty, so we use control-generation instead.")
        control = task.split("-")[0]

        if control == "seg":
            task = "image-segmentation"
            command['task'] = task
        elif control == "depth":
            task = "depth-estimation"
            command['task'] = task
        else:
            task = f"{control}-control"

    command["args"] = args
    logger.debug(f"parsed task: {command}")

    if task.endswith("-text-to-image") or task.endswith("-control"):
        if inference_mode != "huggingface":
            if task.endswith("-text-to-image"):
                control = task.split("-")[0]
                best_model_id = f"lllyasviel/sd-controlnet-{control}"
            else:
                best_model_id = task
            hosted_on = "local"
            reason = "ControlNet is the best model for this task."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")
        else:
            logger.warning(f"Task {command['task']} is not available. ControlNet need to be deployed locally.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"Task {command['task']} is not available. ControlNet need to be deployed locally.", "op":"message"})
            inference_result = {"error": f"service related to ControlNet is not available."}
            results[id] = collect_result(command, "", inference_result)
            return False
    elif task in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]: # ChatGPT Can do
        best_model_id = "ChatGPT"
        reason = "ChatGPT performs well on some NLP tasks as well."
        choose = {"id": best_model_id, "reason": reason}
        messages = [{
            "role": "user",
            "content": f"[ {input} ] contains a task in JSON format {command}, 'task' indicates the task type and 'args' indicates the arguments required for the task. Don't explain the task to me, just help me do it and give me the result. The result must be in text form without any urls."
        }]
        response = chitchat(messages, openaikey)
        results[id] = collect_result(command, choose, {"response": response})
        return True
    else:
        if task not in MODELS_MAP:
            logger.warning(f"no available models on {task} task.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op":"message"})
            inference_result = {"error": f"{command['task']} not found in available tasks."}
            results[id] = collect_result(command, "", inference_result)
            return False

        candidates = MODELS_MAP[task][:20]
        all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"], huggingfacetoken)
        all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
        logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")

        if len(all_avaliable_model_ids) == 0:
            logger.warning(f"no available models on {command['task']}")
            record_case(success=False, **{"input": input, "task": command, "reason": f"no available models: {command['task']}", "op":"message"})
            inference_result = {"error": f"no available models on {command['task']} task."}
            results[id] = collect_result(command, "", inference_result)
            return False

        if len(all_avaliable_model_ids) == 1:
            best_model_id = all_avaliable_model_ids[0]
            hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            reason = "Only one model available."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")
        else:
            cand_models_info = [
                {
                    "id": model["id"],
                    "inference endpoint": all_avaliable_models.get(
                        "local" if model["id"] in all_avaliable_models["local"] else "huggingface"
                    ),
                    "likes": model.get("likes"),
                    "description": model.get("description", "")[:config["max_description_length"]],
                    "language": model.get("language"),
                    "tags": model.get("tags"),
                }
                for model in candidates
                if model["id"] in all_avaliable_model_ids
            ]

            choose_str = choose_model(input, command, cand_models_info, openaikey)
            logger.debug(f"chosen model: {choose_str}")
            try:
                choose = json.loads(choose_str)
                reason = choose["reason"]
                best_model_id = choose["id"]
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            except Exception as e:
                logger.warning(f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response.")
                choose_str = find_json(choose_str)
                best_model_id, reason, choose = get_id_reason(choose_str)
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
    inference_result = model_inference(best_model_id, args, hosted_on, command['task'], huggingfacetoken)

    if "error" in inference_result:
        logger.warning(f"Inference error: {inference_result['error']}")
        record_case(success=False, **{"input": input, "task": command, "reason": f"inference error: {inference_result['error']}", "op":"message"})
        results[id] = collect_result(command, choose, inference_result)
        return False

    results[id] = collect_result(command, choose, inference_result)
    return True

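# Top-level pipeline: plan tasks from the chat history, run them in dependency order on worker threads, then generate the final response.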
def chat_huggingface(messages, openaikey = None, huggingfacetoken = None, return_planning = False, return_results = False):
    start = time.time()
    context = messages[:-1]
    input = messages[-1]["content"]
    logger.info("*"*80)
    logger.info(f"input: {input}")

    task_str = parse_task(context, input, openaikey)
    logger.info(task_str)

    if "error" in task_str:
        return str(task_str), {}
    else:
        task_str = task_str.strip()

    try:
        tasks = json.loads(task_str)
    except Exception as e:
        logger.debug(e)
        response = chitchat(messages, openaikey)
        record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op":"chitchat"})
        return response, {}

    if task_str == "[]":  # using LLM response for empty task
        record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"})
        response = chitchat(messages, openaikey)
        return response, {}

    if len(tasks)==1 and tasks[0]["task"] in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]:
        record_case(success=True, **{"input": input, "task": tasks, "reason": "task parsing fail: empty", "op": "chitchat"})
        response = chitchat(messages, openaikey)
        best_model_id = "ChatGPT"
        reason = "ChatGPT performs well on some NLP tasks as well."
        choose = {"id": best_model_id, "reason": reason}
        return response, collect_result(tasks[0], choose, {"response": response})

    tasks = unfold(tasks)
    tasks = fix_dep(tasks)
    logger.debug(tasks)

    if return_planning:
        return tasks

    results = {}
    threads = []
    tasks = tasks[:]
    d = dict()
    retry = 0
    while True:
        num_threads = len(threads)
        for task in tasks:
            dep = task["dep"]
            # logger.debug(f"d.keys(): {d.keys()}, dep: {dep}")
            for dep_id in dep:
                if dep_id >= task["id"]:
                    task["dep"] = [-1]
                    dep = [-1]
                    break
            if len(list(set(dep).intersection(d.keys()))) == len(dep) or dep[0] == -1:
                tasks.remove(task)
                thread = threading.Thread(target=run_task, args=(input, task, d, openaikey, huggingfacetoken))
                thread.start()
                threads.append(thread)
        if num_threads == len(threads):
            time.sleep(0.5)
            retry += 1
        if retry > 160:
            logger.debug("User has waited too long, Loop break.")
            break
        if len(tasks) == 0:
            break
    for thread in threads:
        thread.join()

    results = d.copy()

    logger.debug(results)
    if return_results:
        return results

    response = response_results(input, results, openaikey).strip()

    end = time.time()
    during = end - start

    answer = {"message": response}
    record_case(success=True, **{"input": input, "task": task_str, "results": results, "response": response, "during": during, "op":"response"})
    logger.info(f"response: {response}")
    return response, results
    	
        config.gradio.yaml
    ADDED
    
    | @@ -0,0 +1,34 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
openai:
  key: gradio # "gradio" (set when request) or your_personal_key
huggingface:
  token: # required: huggingface token @ https://huggingface.co/settings/tokens
dev: false
debug: true
log_file: logs/debug.log
model: text-davinci-003 # text-davinci-003
use_completion: true
inference_mode: hybrid # local, huggingface or hybrid
local_deployment: standard # minimal, standard or full
num_candidate_models: 5
max_description_length: 100
proxy:
logit_bias:
  parse_task: 0.5
  choose_model: 5
tprompt:
  parse_task: >-
    #1 Task Planning Stage: The AI assistant can parse user input to several tasks: [{"task": task, "id": task_id, "dep": dependency_task_id, "args": {"text": text or <GENERATED>-dep_id, "image": image_url or <GENERATED>-dep_id, "audio": audio_url or <GENERATED>-dep_id}}]. The special tag "<GENERATED>-dep_id" refers to the text/image/audio generated by the dependency task (please consider whether the dependency task generates a resource of this type), and "dep_id" must be in the "dep" list. The "dep" field denotes the ids of the previous prerequisite tasks which generate a new resource that the current task relies on. The "args" field must be in ["text", "image", "audio"], nothing else. The task MUST be selected from the following options: "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "text-to-video", "visual-question-answering", "document-question-answering", "image-segmentation", "depth-estimation", "text-to-speech", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image". There may be multiple tasks of the same type. Think step by step about all the tasks needed to resolve the user's request. Parse out as few tasks as possible while ensuring that the user request can be resolved. Pay attention to the dependencies and order among tasks. If the user input can't be parsed, you need to reply with the empty JSON [].
  choose_model: >-
    #2 Model Selection Stage: Given the user request and the parsed tasks, the AI assistant helps the user to select a suitable model from a list of models to process the user request. The assistant should focus more on the description of the model and find the model that has the most potential to solve requests and tasks. Also, prefer models with local inference endpoints for speed and stability.
  response_results: >-
    #4 Response Generation Stage: With the task execution logs, the AI assistant needs to describe the process and inference results.
demos_or_presteps:
  parse_task: demos/demo_parse_task.json
  choose_model: demos/demo_choose_model.json
  response_results: demos/demo_response_results.json
prompt:
  parse_task: The chat log [ {{context}} ] may contain the resources I mentioned. Now I input { {{input}} }. Pay attention to the input and output types of tasks and the dependencies between tasks.
  choose_model: >-
    Please choose the most suitable model from {{metas}} for the task {{task}}. The output must be in a strict JSON format: {"id": "id", "reason": "your detailed reasons for the choice"}.
  response_results: >-
    Yes. Please first think carefully and directly answer my request based on the inference results. Some of the inferences may not always turn out to be correct and require you to make careful consideration in making decisions. Then please detail your workflow, including the models used and their inference results, for my request in your friendly tone. Please filter out information that is not relevant to my request. Tell me the complete path or urls of files in the inference results. If there is nothing in the results, please tell me you can't make it.
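
Read together, these fields drive the LLM calls made elsewhere in the Space: model/use_completion pick the endpoint, tprompt/demos_or_presteps/prompt supply the three prompting stages, and logit_bias weights the stage-specific token ids produced by get_token_ids.py. Below is a hedged sketch of how the file could be consumed, assuming the pre-1.0 openai Completion API that text-davinci-003 used; the Space's actual request logic lives in awesome_chat.py, and this is only an illustration of the config fields.

# Hedged sketch: not the Space's actual code path (see awesome_chat.py for that).
import yaml
import openai
from get_token_ids import get_token_ids_for_task_parsing

config = yaml.load(open("config.gradio.yaml"), Loader=yaml.FullLoader)
openai.api_key = config["openai"]["key"]  # "gradio" means the key is requested at call time

# Bias the completion toward the task-schema tokens used during task parsing.
bias_value = config["logit_bias"]["parse_task"]
logit_bias = {str(tid): bias_value for tid in get_token_ids_for_task_parsing(config["model"])}

if config["use_completion"]:
    resp = openai.Completion.create(
        engine=config["model"],    # text-davinci-003
        prompt="...",              # assembled from tprompt + demos_or_presteps + prompt templates
        logit_bias=logit_bias,
        temperature=0,
    )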
    	
        data/p0_models.jsonl
    ADDED
    
    The diff for this file is too large to render. See raw diff.
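
Although the diff itself is too large to show, the file name and the config fields num_candidate_models and max_description_length suggest one JSON object per line describing a candidate Hugging Face model. A generic JSONL reader is sketched below; any field names (e.g. an "id" or "description" key) are assumptions and are not confirmed by this diff.

# Illustrative only: the per-record schema is an assumption, not shown in the diff.
import json

models = []
with open("data/p0_models.jsonl") as f:
    for line in f:
        line = line.strip()
        if line:
            models.append(json.loads(line))  # one model record per line

print(len(models), "model records loaded")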
    	
        demos/demo_choose_model.json
    ADDED
    
    @@ -0,0 +1,10 @@
[
    {
        "role": "user",
        "content": "{{input}}"
    },
    {
        "role": "assistant",
        "content": "{{task}}"
    }
]
    	
        demos/demo_parse_task.json
    ADDED
    
    @@ -0,0 +1,65 @@
[
    {
        "role": "user",
        "content": "Give you some pictures e1.jpg, e2.png, e3.jpg, help me count the number of sheep?"
    },
    {
        "role": "assistant",
        "content": "[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"How many sheep in the picture\"}}, {\"task\": \"image-to-text\", \"id\": 3, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"object-detection\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"visual-question-answering\", \"id\": 5, \"dep\": [4], \"args\": {\"image\": \"<GENERATED>-4\", \"text\": \"How many sheep in the picture\"}}, {\"task\": \"image-to-text\", \"id\": 6, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"object-detection\", \"id\": 7, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 8, \"dep\": [7], \"args\": {\"image\": \"<GENERATED>-7\", \"text\": \"How many sheep in the picture\"}}]"
    },
    {
        "role": "user",
        "content": "Look at /e.jpg, can you tell me how many objects are in the picture? Give me a picture and video similar to this one."
    },
    {
        "role": "assistant",
        "content": "[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"how many objects in the picture?\" }}, {\"task\": \"text-to-image\", \"id\": 3, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}, {\"task\": \"image-to-image\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"text-to-video\", \"id\": 5, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}]"
    },
    {
        "role": "user",
        "content": "given a document /images/e.jpeg, answer me what is the student amount? And describe the image with your voice"
    },
    {
        "role": "assistant",
        "content": "[{\"task\": \"document-question-answering\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"visual-question-answering\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"image-to-text\", \"id\": 2, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
    },
    {
        "role": "user",
        "content": "Given an image /example.jpg, first generate a hed image, then based on the hed image generate a new image where a girl is reading a book"
    },
    {
        "role": "assistant",
        "content": "[{\"task\": \"hed-control\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/example.jpg\" }}, {\"task\": \"hed-text-to-image\", \"id\": 1, \"dep\": [0], \"args\": {\"text\": \"a girl is reading a book\", \"image\": \"<GENERATED>-0\" }}]"
    },
    {
        "role": "user",
        "content": "please show me a video and an image of (based on the text) 'a boy is running' and dub it"
    },
    {
        "role": "assistant",
        "content": "[{\"task\": \"text-to-video\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-speech\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-image\", \"id\": 2, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}]"
    },
    {
        "role": "user",
        "content": "please show me a joke and an image of cat"
    },
    {
        "role": "assistant",
        "content": "[{\"task\": \"conversational\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"please show me a joke of cat\" }}, {\"task\": \"text-to-image\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a photo of cat\" }}]"
    },
    {
        "role": "user",
        "content": "give me a picture of a cute dog, then describe the image to me and tell a story about it"
    },
    {
        "role": "assistant",
        "content": "[{\"task\": \"text-to-image\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"a picture of a cute dog\" }}, {\"task\": \"image-to-text\", \"id\": 1, \"dep\": [0], \"args\": {\"image\": \"<GENERATED>-0\" }}, {\"task\": \"text-generation\", \"id\": 2, \"dep\": [1], \"args\": {\"text\": \"<GENERATED>-1\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
    }
]
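
The assistant messages above illustrate the planning schema: each task carries an "id", a "dep" list of prerequisite ids (-1 means none), and args that may reference an upstream output via "<GENERATED>-dep_id". The minimal sketch below shows how such a plan could be ordered by its "dep" field; it is not the project's actual scheduler, just an illustration of the data structure.

# Hedged sketch of dependency-ordered execution over a parsed plan.
import json

plan = json.loads('[{"task": "image-to-text", "id": 0, "dep": [-1], "args": {"image": "/e.jpg"}},'
                  ' {"task": "text-to-image", "id": 1, "dep": [0], "args": {"text": "<GENERATED>-0"}}]')

done = set()
while len(done) < len(plan):
    for task in plan:
        deps = [d for d in task["dep"] if d != -1]
        if task["id"] not in done and all(d in done for d in deps):
            # a real executor would run the selected model here and store its output
            print("run", task["task"], "with", task["args"])
            done.add(task["id"])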
    	
        demos/demo_response_results.json
    ADDED
    
    @@ -0,0 +1,10 @@
[
    {
        "role": "user",
        "content": "{{input}}"
    },
    {
        "role": "assistant",
        "content": "Before giving you a response, I want to introduce my workflow for your request, which is shown in the following JSON data: {{processes}}. Do you have any demands regarding my response?"
    }
]
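
Unlike demo_parse_task.json, which holds literal few-shot examples, this file (and demo_choose_model.json) is a message template whose {{input}} and {{processes}} slots are filled per request. A hedged sketch of that substitution, where the fill helper and the example slot values are assumptions rather than the Space's actual code:

# Hedged sketch of filling the {{...}} slots in the demo messages.
import json

def fill(messages, slots):
    filled = []
    for m in messages:
        content = m["content"]
        for name, value in slots.items():
            content = content.replace("{{" + name + "}}", value)
        filled.append({"role": m["role"], "content": content})
    return filled

demos = json.load(open("demos/demo_response_results.json"))
messages = fill(demos, {
    "input": "please show me a joke and an image of cat",          # example value
    "processes": json.dumps([{"task": "text-to-image"}]),          # example execution log
})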
    	
        get_token_ids.py
    ADDED
    
    @@ -0,0 +1,49 @@
import tiktoken

encodings = {
    "gpt-3.5-turbo": tiktoken.get_encoding("cl100k_base"),
    "gpt-3.5-turbo-0301": tiktoken.get_encoding("cl100k_base"),
    "text-davinci-003": tiktoken.get_encoding("p50k_base"),
    "text-davinci-002": tiktoken.get_encoding("p50k_base"),
    "text-davinci-001": tiktoken.get_encoding("r50k_base"),
    "text-curie-001": tiktoken.get_encoding("r50k_base"),
    "text-babbage-001": tiktoken.get_encoding("r50k_base"),
    "text-ada-001": tiktoken.get_encoding("r50k_base"),
    "davinci": tiktoken.get_encoding("r50k_base"),
    "curie": tiktoken.get_encoding("r50k_base"),
    "babbage": tiktoken.get_encoding("r50k_base"),
    "ada": tiktoken.get_encoding("r50k_base"),
}

max_length = {
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-0301": 4096,
    "text-davinci-003": 4096,
    "text-davinci-002": 4096,
    "text-davinci-001": 2049,
    "text-curie-001": 2049,
    "text-babbage-001": 2049,
    "text-ada-001": 2049,
    "davinci": 2049,
    "curie": 2049,
    "babbage": 2049,
    "ada": 2049
}

def count_tokens(model_name, text):
    return len(encodings[model_name].encode(text))

def get_max_context_length(model_name):
    return max_length[model_name]

def get_token_ids_for_task_parsing(model_name):
    text = '''{"task": "text-classification",  "token-classification", "text2text-generation", "summarization", "translation",  "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "text-to-video", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image", "args", "text", "path", "dep", "id", "<GENERATED>-"}'''
    res = encodings[model_name].encode(text)
    res = list(set(res))
    return res

def get_token_ids_for_choose_model(model_name):
    text = '''{"id": "reason"}'''
    res = encodings[model_name].encode(text)
    res = list(set(res))
    return res
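
Besides supplying the token ids that back the config's logit_bias values, these helpers let callers trim prompts to a model's context window. A small usage sketch under assumed values (the 512-token reply reserve and the trimming policy are illustrative, not the project's actual logic):

# Hedged usage sketch for count_tokens / get_max_context_length.
from get_token_ids import count_tokens, get_max_context_length

model_name = "text-davinci-003"
budget = get_max_context_length(model_name) - 512   # reserve room for the reply (assumed value)
history = ["user: hi", "assistant: hello", "user: draw a cat"]

kept = []
for turn in reversed(history):                      # keep the most recent turns that still fit
    if count_tokens(model_name, "\n".join(kept + [turn])) > budget:
        break
    kept.insert(0, turn)
print(kept)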
    	
        models_server.py
    ADDED
    
    @@ -0,0 +1,618 @@
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import logging
         | 
| 3 | 
            +
            import random
         | 
| 4 | 
            +
            import uuid
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from transformers import pipeline
         | 
| 7 | 
            +
            from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
         | 
| 8 | 
            +
            from diffusers.utils import load_image
         | 
| 9 | 
            +
            from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
         | 
| 10 | 
            +
            from diffusers.utils import export_to_video
         | 
| 11 | 
            +
            from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5ForSpeechToSpeech
         | 
| 12 | 
            +
            from transformers import BlipProcessor, BlipForConditionalGeneration
         | 
| 13 | 
            +
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
         | 
| 14 | 
            +
            from datasets import load_dataset
         | 
| 15 | 
            +
            from PIL import Image
         | 
| 16 | 
            +
            import io
         | 
| 17 | 
            +
            from torchvision import transforms
         | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
            import torchaudio
         | 
| 20 | 
            +
            from speechbrain.pretrained import WaveformEnhancement
         | 
| 21 | 
            +
            import joblib
         | 
| 22 | 
            +
            from huggingface_hub import hf_hub_url, cached_download
         | 
| 23 | 
            +
            from transformers import AutoImageProcessor, TimesformerForVideoClassification
         | 
| 24 | 
            +
            from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, AutoFeatureExtractor
         | 
| 25 | 
            +
            from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector
         | 
| 26 | 
            +
            from controlnet_aux.open_pose.body import Body
         | 
| 27 | 
            +
            from controlnet_aux.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
         | 
| 28 | 
            +
            from controlnet_aux.hed import Network
         | 
| 29 | 
            +
            from transformers import DPTForDepthEstimation, DPTFeatureExtractor
         | 
| 30 | 
            +
            import warnings
         | 
| 31 | 
            +
            import time
         | 
| 32 | 
            +
            from espnet2.bin.tts_inference import Text2Speech
         | 
| 33 | 
            +
            import soundfile as sf
         | 
| 34 | 
            +
            from asteroid.models import BaseModel
         | 
| 35 | 
            +
            import traceback
         | 
| 36 | 
            +
            import os
         | 
| 37 | 
            +
            import yaml
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            warnings.filterwarnings("ignore")
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            parser = argparse.ArgumentParser()
         | 
| 42 | 
            +
            parser.add_argument("--config", type=str, default="config.yaml")
         | 
| 43 | 
            +
            args = parser.parse_args()
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            if __name__ != "__main__":
         | 
| 46 | 
            +
                args.config = "config.gradio.yaml"
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 49 | 
            +
            logger.setLevel(logging.INFO)
         | 
| 50 | 
            +
            handler = logging.StreamHandler()
         | 
| 51 | 
            +
            handler.setLevel(logging.INFO)
         | 
| 52 | 
            +
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
         | 
| 53 | 
            +
            handler.setFormatter(formatter)
         | 
| 54 | 
            +
            logger.addHandler(handler)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            local_deployment = config["local_deployment"]
         | 
| 59 | 
            +
            if config["inference_mode"] == "huggingface":
         | 
| 60 | 
            +
                local_deployment = "none"
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            PROXY = None
         | 
| 63 | 
            +
            if config["proxy"]:
         | 
| 64 | 
            +
                PROXY = {
         | 
| 65 | 
            +
                    "https": config["proxy"],
         | 
| 66 | 
            +
                }
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            start = time.time()
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            # local_models = "models/"
         | 
| 71 | 
            +
            local_models = ""
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            def load_pipes(local_deployment):
         | 
| 75 | 
            +
                other_pipes = {}
         | 
| 76 | 
            +
                standard_pipes = {}
         | 
| 77 | 
            +
                controlnet_sd_pipes = {}
         | 
| 78 | 
            +
                if local_deployment in ["full"]:
         | 
| 79 | 
            +
                    other_pipes = {
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                        # "Salesforce/blip-image-captioning-large": {
         | 
| 82 | 
            +
                        #     "model": BlipForConditionalGeneration.from_pretrained(f"Salesforce/blip-image-captioning-large"),
         | 
| 83 | 
            +
                        #     "processor": BlipProcessor.from_pretrained(f"Salesforce/blip-image-captioning-large"),
         | 
| 84 | 
            +
                        #     "device": "cuda:0"
         | 
| 85 | 
            +
                        # },
         | 
| 86 | 
            +
                        "damo-vilab/text-to-video-ms-1.7b": {
         | 
| 87 | 
            +
                            "model": DiffusionPipeline.from_pretrained(f"{local_models}damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"),
         | 
| 88 | 
            +
                            "device": "cuda:0"
         | 
| 89 | 
            +
                        },
         | 
| 90 | 
            +
                        # "facebook/maskformer-swin-large-ade": {
         | 
| 91 | 
            +
                        #     "model": MaskFormerForInstanceSegmentation.from_pretrained(f"facebook/maskformer-swin-large-ade"),
         | 
| 92 | 
            +
                        #     "feature_extractor" : AutoFeatureExtractor.from_pretrained("facebook/maskformer-swin-large-ade"),
         | 
| 93 | 
            +
                        #     "device": "cuda:0"
         | 
| 94 | 
            +
                        # },
         | 
| 95 | 
            +
                        # "microsoft/trocr-base-printed": {
         | 
| 96 | 
            +
                        #     "processor": TrOCRProcessor.from_pretrained(f"microsoft/trocr-base-printed"),
         | 
| 97 | 
            +
                        #     "model": VisionEncoderDecoderModel.from_pretrained(f"microsoft/trocr-base-printed"),
         | 
| 98 | 
            +
                        #     "device": "cuda:0"
         | 
| 99 | 
            +
                        # },
         | 
| 100 | 
            +
                        # "microsoft/trocr-base-handwritten": {
         | 
| 101 | 
            +
                        #     "processor": TrOCRProcessor.from_pretrained(f"microsoft/trocr-base-handwritten"),
         | 
| 102 | 
            +
                        #     "model": VisionEncoderDecoderModel.from_pretrained(f"microsoft/trocr-base-handwritten"),
         | 
| 103 | 
            +
                        #     "device": "cuda:0"
         | 
| 104 | 
            +
                        # },
         | 
| 105 | 
            +
                        "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": {
         | 
| 106 | 
            +
                            "model": BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"),
         | 
| 107 | 
            +
                            "device": "cuda:0"
         | 
| 108 | 
            +
                        },
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                        # "CompVis/stable-diffusion-v1-4": {
         | 
| 111 | 
            +
                        #     "model": DiffusionPipeline.from_pretrained(f"CompVis/stable-diffusion-v1-4"),
         | 
| 112 | 
            +
                        #     "device": "cuda:0"
         | 
| 113 | 
            +
                        # },
         | 
| 114 | 
            +
                        # "stabilityai/stable-diffusion-2-1": {
         | 
| 115 | 
            +
                        #     "model": DiffusionPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1"),
         | 
| 116 | 
            +
                        #     "device": "cuda:0"
         | 
| 117 | 
            +
                        # },
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        # "microsoft/speecht5_tts":{
         | 
| 120 | 
            +
                        #     "processor": SpeechT5Processor.from_pretrained(f"microsoft/speecht5_tts"),
         | 
| 121 | 
            +
                        #     "model": SpeechT5ForTextToSpeech.from_pretrained(f"microsoft/speecht5_tts"),
         | 
| 122 | 
            +
                        #     "vocoder":  SpeechT5HifiGan.from_pretrained(f"microsoft/speecht5_hifigan"),
         | 
| 123 | 
            +
                        #     "embeddings_dataset": load_dataset(f"Matthijs/cmu-arctic-xvectors", split="validation"),
         | 
| 124 | 
            +
                        #     "device": "cuda:0"
         | 
| 125 | 
            +
                        # },
         | 
| 126 | 
            +
                        # "speechbrain/mtl-mimic-voicebank": {
         | 
| 127 | 
            +
                        #     "model": WaveformEnhancement.from_hparams(source="speechbrain/mtl-mimic-voicebank", savedir="models/mtl-mimic-voicebank"),
         | 
| 128 | 
            +
                        #     "device": "cuda:0"
         | 
| 129 | 
            +
                        # },
         | 
| 130 | 
            +
                        "microsoft/speecht5_vc":{
         | 
| 131 | 
            +
                            "processor": SpeechT5Processor.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
         | 
| 132 | 
            +
                            "model": SpeechT5ForSpeechToSpeech.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
         | 
| 133 | 
            +
                            "vocoder": SpeechT5HifiGan.from_pretrained(f"{local_models}microsoft/speecht5_hifigan"),
         | 
| 134 | 
            +
                            "embeddings_dataset": load_dataset(f"{local_models}Matthijs/cmu-arctic-xvectors", split="validation"),
         | 
| 135 | 
            +
                            "device": "cuda:0"
         | 
| 136 | 
            +
                        },
         | 
| 137 | 
            +
                        # "julien-c/wine-quality": {
         | 
| 138 | 
            +
                        #     "model": joblib.load(cached_download(hf_hub_url("julien-c/wine-quality", "sklearn_model.joblib")))
         | 
| 139 | 
            +
                        # },
         | 
| 140 | 
            +
                        # "facebook/timesformer-base-finetuned-k400": {
         | 
| 141 | 
            +
                        #     "processor": AutoImageProcessor.from_pretrained(f"facebook/timesformer-base-finetuned-k400"),
         | 
| 142 | 
            +
                        #     "model": TimesformerForVideoClassification.from_pretrained(f"facebook/timesformer-base-finetuned-k400"),
         | 
| 143 | 
            +
                        #     "device": "cuda:0"
         | 
| 144 | 
            +
                        # },
         | 
| 145 | 
            +
                        "facebook/maskformer-swin-base-coco": {
         | 
| 146 | 
            +
                            "feature_extractor": MaskFormerFeatureExtractor.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
         | 
| 147 | 
            +
                            "model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
         | 
| 148 | 
            +
                            "device": "cuda:0"
         | 
| 149 | 
            +
                        },
         | 
| 150 | 
            +
                        "Intel/dpt-hybrid-midas": {
         | 
| 151 | 
            +
                            "model": DPTForDepthEstimation.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas", low_cpu_mem_usage=True),
         | 
| 152 | 
            +
                            "feature_extractor": DPTFeatureExtractor.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas"),
         | 
| 153 | 
            +
                            "device": "cuda:0"
         | 
| 154 | 
            +
                        }
         | 
| 155 | 
            +
                    }
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                if local_deployment in ["full", "standard"]:
         | 
| 158 | 
            +
                    standard_pipes = {
         | 
| 159 | 
            +
                        # "nlpconnect/vit-gpt2-image-captioning":{
         | 
| 160 | 
            +
                        #     "model": VisionEncoderDecoderModel.from_pretrained(f"{local_models}nlpconnect/vit-gpt2-image-captioning"),
         | 
| 161 | 
            +
                        #     "feature_extractor": ViTImageProcessor.from_pretrained(f"{local_models}nlpconnect/vit-gpt2-image-captioning"),
         | 
| 162 | 
            +
                        #     "tokenizer": AutoTokenizer.from_pretrained(f"{local_models}nlpconnect/vit-gpt2-image-captioning"),
         | 
| 163 | 
            +
                        #     "device": "cuda:0"
         | 
| 164 | 
            +
                        # },
         | 
| 165 | 
            +
                        "espnet/kan-bayashi_ljspeech_vits": {
         | 
| 166 | 
            +
                            "model": Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits"),
         | 
| 167 | 
            +
                            "device": "cuda:0"
         | 
| 168 | 
            +
                        },
         | 
| 169 | 
            +
                        # "lambdalabs/sd-image-variations-diffusers": {
         | 
| 170 | 
            +
                        #     "model": DiffusionPipeline.from_pretrained(f"{local_models}lambdalabs/sd-image-variations-diffusers"), #torch_dtype=torch.float16
         | 
| 171 | 
            +
                        #     "device": "cuda:0"
         | 
| 172 | 
            +
                        # },
         | 
| 173 | 
            +
                        "runwayml/stable-diffusion-v1-5": {
         | 
| 174 | 
            +
                            "model": DiffusionPipeline.from_pretrained(f"{local_models}runwayml/stable-diffusion-v1-5"),
         | 
| 175 | 
            +
                            "device": "cuda:0"
         | 
| 176 | 
            +
                        },
         | 
| 177 | 
            +
                        # "superb/wav2vec2-base-superb-ks": {
         | 
| 178 | 
            +
                        #     "model": pipeline(task="audio-classification", model=f"superb/wav2vec2-base-superb-ks"), 
         | 
| 179 | 
            +
                        #     "device": "cuda:0"
         | 
| 180 | 
            +
                        # },
         | 
| 181 | 
            +
                        "openai/whisper-base": {
         | 
| 182 | 
            +
                            "model": pipeline(task="automatic-speech-recognition", model=f"{local_models}openai/whisper-base"), 
         | 
| 183 | 
            +
                            "device": "cuda:0"
         | 
| 184 | 
            +
                        },
         | 
| 185 | 
            +
                        # "microsoft/speecht5_asr": {
         | 
| 186 | 
            +
                        #     "model": pipeline(task="automatic-speech-recognition", model=f"{local_models}microsoft/speecht5_asr"), 
         | 
| 187 | 
            +
                        #     "device": "cuda:0"
         | 
| 188 | 
            +
                        # },
         | 
| 189 | 
            +
                        "Intel/dpt-large": {
         | 
| 190 | 
            +
                            "model": pipeline(task="depth-estimation", model=f"{local_models}Intel/dpt-large"), 
         | 
| 191 | 
            +
                            "device": "cuda:0"
         | 
| 192 | 
            +
                        },
         | 
| 193 | 
            +
                        # "microsoft/beit-base-patch16-224-pt22k-ft22k": {
         | 
| 194 | 
            +
                        #     "model": pipeline(task="image-classification", model=f"microsoft/beit-base-patch16-224-pt22k-ft22k"), 
         | 
| 195 | 
            +
                        #     "device": "cuda:0"
         | 
| 196 | 
            +
                        # },
         | 
| 197 | 
            +
                        "facebook/detr-resnet-50-panoptic": {
         | 
| 198 | 
            +
                            "model": pipeline(task="image-segmentation", model=f"{local_models}facebook/detr-resnet-50-panoptic"), 
         | 
| 199 | 
            +
                            "device": "cuda:0"
         | 
| 200 | 
            +
                        },
         | 
| 201 | 
            +
                        "facebook/detr-resnet-101": {
         | 
| 202 | 
            +
                            "model": pipeline(task="object-detection", model=f"{local_models}facebook/detr-resnet-101"), 
         | 
| 203 | 
            +
                            "device": "cuda:0"
         | 
| 204 | 
            +
                        },
         | 
| 205 | 
            +
                        # "openai/clip-vit-large-patch14": {
         | 
| 206 | 
            +
                        #     "model": pipeline(task="zero-shot-image-classification", model=f"openai/clip-vit-large-patch14"), 
         | 
| 207 | 
            +
                        #     "device": "cuda:0"
         | 
| 208 | 
            +
                        # },
         | 
| 209 | 
            +
                        # "google/owlvit-base-patch32": {
         | 
| 210 | 
            +
                        #     "model": pipeline(task="zero-shot-object-detection", model=f"{local_models}google/owlvit-base-patch32"), 
         | 
| 211 | 
            +
                        #     "device": "cuda:0"
         | 
| 212 | 
            +
                        # },
         | 
| 213 | 
            +
                        # "microsoft/DialoGPT-medium": {
         | 
| 214 | 
            +
                        #     "model": pipeline(task="conversational", model=f"microsoft/DialoGPT-medium"), 
         | 
| 215 | 
            +
                        #     "device": "cuda:0"
         | 
| 216 | 
            +
                        # },
         | 
| 217 | 
            +
                        # "bert-base-uncased": {
         | 
| 218 | 
            +
                        #     "model": pipeline(task="fill-mask", model=f"bert-base-uncased"), 
         | 
| 219 | 
            +
                        #     "device": "cuda:0"
         | 
| 220 | 
            +
                        # },
         | 
| 221 | 
            +
                        # "deepset/roberta-base-squad2": {
         | 
| 222 | 
            +
                        #     "model": pipeline(task = "question-answering", model=f"deepset/roberta-base-squad2"), 
         | 
| 223 | 
            +
                        #     "device": "cuda:0"
         | 
| 224 | 
            +
                        # },
         | 
| 225 | 
            +
                        # "facebook/bart-large-cnn": {
         | 
| 226 | 
            +
                        #     "model": pipeline(task="summarization", model=f"facebook/bart-large-cnn"), 
         | 
| 227 | 
            +
                        #     "device": "cuda:0"
         | 
| 228 | 
            +
                        # },
         | 
| 229 | 
            +
                        # "google/tapas-base-finetuned-wtq": {
         | 
| 230 | 
            +
                        #     "model": pipeline(task="table-question-answering", model=f"google/tapas-base-finetuned-wtq"), 
         | 
| 231 | 
            +
                        #     "device": "cuda:0"
         | 
| 232 | 
            +
                        # },
         | 
| 233 | 
            +
                        # "distilbert-base-uncased-finetuned-sst-2-english": {
         | 
| 234 | 
            +
                        #     "model": pipeline(task="text-classification", model=f"distilbert-base-uncased-finetuned-sst-2-english"), 
         | 
| 235 | 
            +
                        #     "device": "cuda:0"
         | 
| 236 | 
            +
                        # },
         | 
| 237 | 
            +
                        # "gpt2": {
         | 
| 238 | 
            +
                        #     "model": pipeline(task="text-generation", model="gpt2"), 
         | 
| 239 | 
            +
                        #     "device": "cuda:0"
         | 
| 240 | 
            +
                        # },
         | 
| 241 | 
            +
                        # "mrm8488/t5-base-finetuned-question-generation-ap": {
         | 
| 242 | 
            +
                        #     "model": pipeline(task="text2text-generation", model=f"mrm8488/t5-base-finetuned-question-generation-ap"), 
         | 
| 243 | 
            +
                        #     "device": "cuda:0"
         | 
| 244 | 
            +
                        # },
         | 
| 245 | 
            +
                        # "Jean-Baptiste/camembert-ner": {
         | 
| 246 | 
            +
                        #     "model": pipeline(task="token-classification", model=f"Jean-Baptiste/camembert-ner", aggregation_strategy="simple"), 
         | 
| 247 | 
            +
                        #     "device": "cuda:0"
         | 
| 248 | 
            +
                        # },
         | 
| 249 | 
            +
                        # "t5-base": {
         | 
| 250 | 
            +
                        #     "model": pipeline(task="translation", model=f"t5-base"), 
         | 
| 251 | 
            +
                        #     "device": "cuda:0"
         | 
| 252 | 
            +
                        # },
         | 
| 253 | 
            +
                        "impira/layoutlm-document-qa": {
         | 
| 254 | 
            +
                            "model": pipeline(task="document-question-answering", model=f"{local_models}impira/layoutlm-document-qa"), 
         | 
| 255 | 
            +
                            "device": "cuda:0"
         | 
| 256 | 
            +
                        },
         | 
| 257 | 
            +
                        "ydshieh/vit-gpt2-coco-en": {
         | 
| 258 | 
            +
                            "model": pipeline(task="image-to-text", model=f"{local_models}ydshieh/vit-gpt2-coco-en"), 
         | 
| 259 | 
            +
                            "device": "cuda:0"
         | 
| 260 | 
            +
                        },
         | 
| 261 | 
            +
                        "dandelin/vilt-b32-finetuned-vqa": {
         | 
| 262 | 
            +
                            "model": pipeline(task="visual-question-answering", model=f"{local_models}dandelin/vilt-b32-finetuned-vqa"), 
         | 
| 263 | 
            +
                            "device": "cuda:0"
         | 
| 264 | 
            +
                        }
         | 
| 265 | 
            +
                    }
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                if local_deployment in ["full", "standard", "minimal"]:
         | 
| 268 | 
            +
                    
         | 
| 269 | 
            +
                    controlnet = ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
         | 
| 270 | 
            +
                    controlnetpipe = StableDiffusionControlNetPipeline.from_pretrained(
         | 
| 271 | 
            +
                        f"{local_models}runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
         | 
| 272 | 
            +
                    )
         | 
| 273 | 
            +
             | 
| 274 | 
            +
             | 
| 275 | 
            +
                    hed_network = HEDdetector.from_pretrained('lllyasviel/ControlNet')
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    controlnet_sd_pipes = {
         | 
| 278 | 
            +
                        "openpose-control": {
         | 
| 279 | 
            +
                            "model": OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
         | 
| 280 | 
            +
                        },
         | 
| 281 | 
            +
                        "mlsd-control": {
         | 
| 282 | 
            +
                            "model": MLSDdetector.from_pretrained('lllyasviel/ControlNet')
         | 
| 283 | 
            +
                        },
         | 
| 284 | 
            +
                        "hed-control": {
         | 
| 285 | 
            +
                            "model": hed_network
         | 
| 286 | 
            +
                        },
         | 
| 287 | 
            +
                        "scribble-control": {
         | 
| 288 | 
            +
                            "model": hed_network
         | 
| 289 | 
            +
                        },
         | 
| 290 | 
            +
                        "midas-control": {
         | 
| 291 | 
            +
                            "model": MidasDetector.from_pretrained('lllyasviel/ControlNet')
         | 
| 292 | 
            +
                        },
         | 
| 293 | 
            +
                        "canny-control": {
         | 
| 294 | 
            +
                            "model": CannyDetector()
         | 
| 295 | 
            +
                        },
         | 
| 296 | 
            +
                        "lllyasviel/sd-controlnet-canny":{
         | 
| 297 | 
            +
                            "control": controlnet, 
         | 
| 298 | 
            +
                            "model": controlnetpipe,
         | 
| 299 | 
            +
                            "device": "cuda:0"
         | 
| 300 | 
            +
                        },
         | 
| 301 | 
            +
                        "lllyasviel/sd-controlnet-depth":{
         | 
| 302 | 
            +
                            "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16),
         | 
| 303 | 
            +
                            "model": controlnetpipe,
         | 
| 304 | 
            +
                            "device": "cuda:0"
         | 
| 305 | 
            +
                        },
         | 
| 306 | 
            +
                        "lllyasviel/sd-controlnet-hed":{
         | 
| 307 | 
            +
                            "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-hed", torch_dtype=torch.float16), 
         | 
| 308 | 
            +
                            "model": controlnetpipe,
         | 
| 309 | 
            +
                            "device": "cuda:0"
         | 
| 310 | 
            +
                        },
         | 
| 311 | 
            +
                        "lllyasviel/sd-controlnet-mlsd":{
         | 
| 312 | 
            +
                            "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-mlsd", torch_dtype=torch.float16), 
         | 
| 313 | 
            +
                            "model": controlnetpipe,
         | 
| 314 | 
            +
                            "device": "cuda:0"
         | 
| 315 | 
            +
                        },
         | 
| 316 | 
            +
                        "lllyasviel/sd-controlnet-openpose":{
         | 
| 317 | 
            +
                            "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16), 
         | 
| 318 | 
            +
                            "model": controlnetpipe,
         | 
| 319 | 
            +
                            "device": "cuda:0"
         | 
| 320 | 
            +
                        },
         | 
| 321 | 
            +
                        "lllyasviel/sd-controlnet-scribble":{
         | 
| 322 | 
            +
                            "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float16), 
         | 
| 323 | 
            +
                            "model": controlnetpipe,
         | 
| 324 | 
            +
                            "device": "cuda:0"
         | 
| 325 | 
            +
                        },
         | 
| 326 | 
            +
                        "lllyasviel/sd-controlnet-seg":{
         | 
| 327 | 
            +
                            "control": ControlNetModel.from_pretrained(f"{local_models}lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16), 
         | 
| 328 | 
            +
                            "model": controlnetpipe,
         | 
| 329 | 
            +
                            "device": "cuda:0"
         | 
| 330 | 
            +
                        }    
         | 
| 331 | 
            +
                    }
         | 
| 332 | 
            +
                pipes = {**standard_pipes, **other_pipes, **controlnet_sd_pipes}
         | 
| 333 | 
    return pipes

pipes = load_pipes(local_deployment)

end = time.time()
during = end - start

print(f"[ ready ] {during}s")

def running():
    return {"running": True}

def status(model_id):
    disabled_models = ["microsoft/trocr-base-printed", "microsoft/trocr-base-handwritten"]
    if model_id in pipes.keys() and model_id not in disabled_models:
        print(f"[ check {model_id} ] success")
        return {"loaded": True}
    else:
        print(f"[ check {model_id} ] failed")
        return {"loaded": False}

def models(model_id, data):
    while "using" in pipes[model_id] and pipes[model_id]["using"]:
        print(f"[ inference {model_id} ] waiting")
        time.sleep(0.1)
    pipes[model_id]["using"] = True
    print(f"[ inference {model_id} ] start")

    start = time.time()

    pipe = pipes[model_id]["model"]

    if "device" in pipes[model_id]:
        try:
            pipe.to(pipes[model_id]["device"])
        except:
            pipe.device = torch.device(pipes[model_id]["device"])
            pipe.model.to(pipes[model_id]["device"])

    result = None
    try:
        # text to video
        if model_id == "damo-vilab/text-to-video-ms-1.7b":
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
            # pipe.enable_model_cpu_offload()
            prompt = data["text"]
            video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames
            file_name = str(uuid.uuid4())[:4]
            video_path = export_to_video(video_frames, f"public/videos/{file_name}.mp4")

            new_file_name = str(uuid.uuid4())[:4]
            os.system(f"ffmpeg -i {video_path} -vcodec libx264 public/videos/{new_file_name}.mp4")

            if os.path.exists(f"public/videos/{new_file_name}.mp4"):
                result = {"path": f"/videos/{new_file_name}.mp4"}
            else:
                result = {"path": f"/videos/{file_name}.mp4"}

        # controlnet
        if model_id.startswith("lllyasviel/sd-controlnet-"):
            pipe.controlnet.to('cpu')
            pipe.controlnet = pipes[model_id]["control"].to(pipes[model_id]["device"])
            pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
            control_image = load_image(data["img_url"])
            # generator = torch.manual_seed(66)
            out_image: Image = pipe(data["text"], num_inference_steps=20, image=control_image).images[0]
            file_name = str(uuid.uuid4())[:4]
            out_image.save(f"public/images/{file_name}.png")
            result = {"path": f"/images/{file_name}.png"}

        if model_id.endswith("-control"):
            image = load_image(data["img_url"])
            if "scribble" in model_id:
                control = pipe(image, scribble=True)
            elif "canny" in model_id:
                control = pipe(image, low_threshold=100, high_threshold=200)
            else:
                control = pipe(image)
            file_name = str(uuid.uuid4())[:4]
            control.save(f"public/images/{file_name}.png")
            result = {"path": f"/images/{file_name}.png"}

        # image to image
        if model_id == "lambdalabs/sd-image-variations-diffusers":
            im = load_image(data["img_url"])
            file_name = str(uuid.uuid4())[:4]
            tform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize(
                    (224, 224),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                    antialias=False,
                ),
                transforms.Normalize(
                    [0.48145466, 0.4578275, 0.40821073],
                    [0.26862954, 0.26130258, 0.27577711]),
            ])
            inp = tform(im).to(pipes[model_id]["device"]).unsqueeze(0)
            out = pipe(inp, guidance_scale=3)
            out["images"][0].save(f"public/images/{file_name}.jpg")
            result = {"path": f"/images/{file_name}.jpg"}

        # image to text
        if model_id == "Salesforce/blip-image-captioning-large":
            raw_image = load_image(data["img_url"]).convert('RGB')
            text = data["text"]
            inputs = pipes[model_id]["processor"](raw_image, return_tensors="pt").to(pipes[model_id]["device"])
            out = pipe.generate(**inputs)
            caption = pipes[model_id]["processor"].decode(out[0], skip_special_tokens=True)
            result = {"generated text": caption}
        if model_id == "ydshieh/vit-gpt2-coco-en":
            img_url = data["img_url"]
            generated_text = pipe(img_url)[0]['generated_text']
            result = {"generated text": generated_text}
        if model_id == "nlpconnect/vit-gpt2-image-captioning":
            image = load_image(data["img_url"]).convert("RGB")
            pixel_values = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(pipes[model_id]["device"])
            generated_ids = pipe.generate(pixel_values, **{"max_length": 200, "num_beams": 1})
            generated_text = pipes[model_id]["tokenizer"].batch_decode(generated_ids, skip_special_tokens=True)[0]
            result = {"generated text": generated_text}
        # image to text: OCR
        if model_id == "microsoft/trocr-base-printed" or model_id == "microsoft/trocr-base-handwritten":
            image = load_image(data["img_url"]).convert("RGB")
            pixel_values = pipes[model_id]["processor"](image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(pipes[model_id]["device"])
            generated_ids = pipe.generate(pixel_values)
            generated_text = pipes[model_id]["processor"].batch_decode(generated_ids, skip_special_tokens=True)[0]
            result = {"generated text": generated_text}

        # text to image
        if model_id == "runwayml/stable-diffusion-v1-5":
            file_name = str(uuid.uuid4())[:4]
            text = data["text"]
            out = pipe(prompt=text)
            out["images"][0].save(f"public/images/{file_name}.jpg")
            result = {"path": f"/images/{file_name}.jpg"}

        # object detection
        if model_id == "google/owlvit-base-patch32" or model_id == "facebook/detr-resnet-101":
            img_url = data["img_url"]
            open_types = ["cat", "couch", "person", "car", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird"]
            result = pipe(img_url, candidate_labels=open_types)

        # VQA
        if model_id == "dandelin/vilt-b32-finetuned-vqa":
            question = data["text"]
            img_url = data["img_url"]
            result = pipe(question=question, image=img_url)

        # DQA
        if model_id == "impira/layoutlm-document-qa":
            question = data["text"]
            img_url = data["img_url"]
            result = pipe(img_url, question)

        # depth-estimation
        if model_id == "Intel/dpt-large":
            output = pipe(data["img_url"])
            image = output['depth']
            name = str(uuid.uuid4())[:4]
            image.save(f"public/images/{name}.jpg")
            result = {"path": f"/images/{name}.jpg"}

        if model_id == "Intel/dpt-hybrid-midas":
            image = load_image(data["img_url"])
            inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = pipe(**inputs)
                predicted_depth = outputs.predicted_depth
            prediction = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=image.size[::-1],
                mode="bicubic",
                align_corners=False,
            )
            output = prediction.squeeze().cpu().numpy()
            formatted = (output * 255 / np.max(output)).astype("uint8")
            image = Image.fromarray(formatted)
            name = str(uuid.uuid4())[:4]
            image.save(f"public/images/{name}.jpg")
            result = {"path": f"/images/{name}.jpg"}

        # TTS
        if model_id == "espnet/kan-bayashi_ljspeech_vits":
            text = data["text"]
            wav = pipe(text)["wav"]
            name = str(uuid.uuid4())[:4]
            sf.write(f"public/audios/{name}.wav", wav.cpu().numpy(), pipe.fs, "PCM_16")
            result = {"path": f"/audios/{name}.wav"}

        if model_id == "microsoft/speecht5_tts":
            text = data["text"]
            inputs = pipes[model_id]["processor"](text=text, return_tensors="pt")
            embeddings_dataset = pipes[model_id]["embeddings_dataset"]
            speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(pipes[model_id]["device"])
            pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
            speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
            name = str(uuid.uuid4())[:4]
            sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
            result = {"path": f"/audios/{name}.wav"}

        # ASR
        if model_id == "openai/whisper-base" or model_id == "microsoft/speecht5_asr":
            audio_url = data["audio_url"]
            result = {"text": pipe(audio_url)["text"]}

        # audio to audio
        if model_id == "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k":
            audio_url = data["audio_url"]
            wav, sr = torchaudio.load(audio_url)
            with torch.no_grad():
                result_wav = pipe(wav.to(pipes[model_id]["device"]))
            name = str(uuid.uuid4())[:4]
            sf.write(f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr)
            result = {"path": f"/audios/{name}.wav"}

        if model_id == "microsoft/speecht5_vc":
            audio_url = data["audio_url"]
            wav, sr = torchaudio.load(audio_url)
            inputs = pipes[model_id]["processor"](audio=wav, sampling_rate=sr, return_tensors="pt")
            embeddings_dataset = pipes[model_id]["embeddings_dataset"]
            speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
            pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
            speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"])
            name = str(uuid.uuid4())[:4]
            sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
            result = {"path": f"/audios/{name}.wav"}

        # segmentation
        if model_id == "facebook/detr-resnet-50-panoptic":
            segments = pipe(data["img_url"])
            image = load_image(data["img_url"])

            colors = []
            for i in range(len(segments)):
                colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 50))

            # overlay each segment's mask with its own color
            for i, segment in enumerate(segments):
                mask = segment["mask"]
                mask = mask.convert('L')
                layer = Image.new('RGBA', mask.size, colors[i])
                image.paste(layer, (0, 0), mask)
            name = str(uuid.uuid4())[:4]
            image.save(f"public/images/{name}.jpg")
            result = {"path": f"/images/{name}.jpg"}

        if model_id == "facebook/maskformer-swin-base-coco" or model_id == "facebook/maskformer-swin-large-ade":
            image = load_image(data["img_url"])
            inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").to(pipes[model_id]["device"])
            outputs = pipe(**inputs)
            result = pipes[model_id]["feature_extractor"].post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
            predicted_panoptic_map = result["segmentation"].cpu().numpy()
            predicted_panoptic_map = Image.fromarray(predicted_panoptic_map.astype(np.uint8))
            name = str(uuid.uuid4())[:4]
            predicted_panoptic_map.save(f"public/images/{name}.jpg")
            result = {"path": f"/images/{name}.jpg"}

    except Exception as e:
        print(e)
        traceback.print_exc()
        result = {"error": {"message": "Error when running the model inference."}}

    if "device" in pipes[model_id]:
        try:
            pipe.to("cpu")
            torch.cuda.empty_cache()
        except:
            pipe.device = torch.device("cpu")
            pipe.model.to("cpu")
            torch.cuda.empty_cache()

    pipes[model_id]["using"] = False

    if result is None:
        result = {"error": {"message": "model not found"}}

    end = time.time()
    during = end - start
    print(f"[ complete {model_id} ] {during}s")
    print(f"[ result {model_id} ] {result}")

    return result
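
A minimal usage sketch (not part of the repository) for the helpers defined above: status reports whether a pipe is loaded, and models dispatches a request dict whose keys ("text", "img_url", "audio_url") follow the branches above. The model ids, prompt, and example image below are illustrative choices only.

# Illustrative only: drive the module-level helpers directly once pipes are loaded.
print(running())                                   # {"running": True}

sd_model = "runwayml/stable-diffusion-v1-5"
if status(sd_model)["loaded"]:
    print(models(sd_model, {"text": "a cat sitting on a red sofa"}))   # e.g. {"path": "/images/ab12.jpg"}

caption_model = "nlpconnect/vit-gpt2-image-captioning"
if status(caption_model)["loaded"]:
    print(models(caption_model, {"img_url": "public/examples/a.jpg"}))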
    	
        packages.txt
    ADDED
    
tesseract-ocr
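
On a Hugging Face Space, each line of packages.txt is installed as a system (apt) package; tesseract-ocr provides the OCR binary that the pytesseract pin in requirements.txt shells out to. A small illustrative check (not part of the repository), using one of the bundled example images:

# Illustrative only: confirm the tesseract binary from packages.txt is visible to pytesseract.
import pytesseract
from PIL import Image

print(pytesseract.get_tesseract_version())
print(pytesseract.image_to_string(Image.open("public/examples/a.jpg")))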
    	
        public/examples/a.jpg
    ADDED
    
    	
        public/examples/b.jpg
    ADDED
    
    	
        public/examples/c.jpg
    ADDED
    
    	
        public/examples/d.jpg
    ADDED
    
    	
        public/examples/e.jpg
    ADDED
    
    	
        public/examples/f.jpg
    ADDED
    
    	
        public/examples/g.jpg
    ADDED
    
    	
        requirements.txt
    ADDED
    
git+https://github.com/huggingface/diffusers.git@8c530fc2f6a76a2aefb6b285dce6df1675092ac6#egg=diffusers
git+https://github.com/huggingface/transformers@c612628045822f909020f7eb6784c79700813eda#egg=transformers
git+https://github.com/patrickvonplaten/controlnet_aux@78efc716868a7f5669c288233d65b471f542ce40#egg=controlnet_aux
tiktoken==0.3.3
pydub==0.25.1
espnet==202301
espnet_model_zoo==0.1.7
flask==2.2.3
flask_cors==3.0.10
waitress==2.1.2
datasets==2.11.0
asteroid==0.6.0
speechbrain==0.5.14
timm==0.6.13
typeguard==2.13.3
accelerate==0.18.0
pytesseract==0.3.10
basicsr==1.4.2
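
diffusers, transformers, and controlnet_aux are pinned to exact git commits above, so the installed versions matter for reproducing the model server's behavior. A quick illustrative sanity check (not part of the repository) before launching models_server.py:

# Illustrative only: print the versions the pinned requirements resolved to.
from importlib.metadata import version

for pkg in ["diffusers", "transformers", "flask", "datasets", "accelerate", "timm"]:
    print(pkg, version(pkg))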
