Spaces:
Sleeping
Sleeping
| import base64 | |
| import io | |
| import json | |
| import logging | |
| import os | |
| import pathlib | |
| import tempfile | |
| import time | |
| from datetime import datetime | |
| import requests | |
| import tiktoken | |
| from PIL import Image | |
| from modules.config import retrieve_proxy | |
| from modules.models.XMChat import XMChat | |
# Midjourney settings, read once from the process environment at import time.
# Each value is None when the corresponding variable is not set.

# Base URL that Midjourney_Client.request_mj prefixes to every API path.
mj_proxy_api_base = os.environ.get("MIDJOURNEY_PROXY_API_BASE")
# Replacement for the "https://cdn.discordapp.com/" prefix of result image URLs.
mj_discord_proxy_url = os.environ.get("MIDJOURNEY_DISCORD_PROXY_URL")
# When truthy, the client creates a local temp directory for split result images.
mj_temp_folder = os.environ.get("MIDJOURNEY_TEMP_FOLDER")
class Midjourney_Client(XMChat):
    """Chat client that turns ``/mj`` chat messages into Midjourney proxy API calls.

    The last history entry is parsed as ``/mj [ACTION::]...``, submitted to the
    proxy configured by ``MIDJOURNEY_PROXY_API_BASE``, and the task is then
    polled until it finishes, fails, or times out.
    """

    # Actions accepted after the "/mj" prefix.
    _SUPPORTED_ACTIONS = ("IMAGINE", "DESCRIBE", "UPSCALE", "VARIATION", "BLEND", "REROLL")
    # Actions that refer to an existing task: "ACTION::<index>::<task id>".
    _CHANGE_ACTIONS = ("UPSCALE", "VARIATION", "REROLL")
    # Seconds before an HTTP call is abandoned, so a stalled proxy cannot block
    # the polling loop forever.
    _REQUEST_TIMEOUT = 60

    class FetchDataPack:
        """
        A class to store data for current fetching data from Midjourney API
        """

        action: str  # current action, e.g. "IMAGINE", "UPSCALE", "VARIATION"
        prefix_content: str  # prefix content, task description and process hint
        task_id: str  # task id
        start_time: float  # task start timestamp
        timeout: int  # task timeout in seconds
        finished: bool  # whether the task is finished
        prompt: str  # prompt for the task

        def __init__(self, action, prefix_content, task_id, timeout=900, prompt=""):
            self.action = action
            self.prefix_content = prefix_content
            self.task_id = task_id
            self.start_time = time.time()
            self.timeout = timeout
            self.finished = False
            # Always initialised so the timeout message can read it safely.
            self.prompt = prompt

    def __init__(self, model_name, api_key, user_name=""):
        """Set up request headers, the proxy base URL and the optional temp folder.

        ``api_key`` is sent as the ``mj-api-secret`` header on every request.
        """
        super().__init__(api_key, user_name)
        self.model_name = model_name
        self.history = []
        self.api_key = api_key
        self.headers = {
            "Content-Type": "application/json",
            "mj-api-secret": f"{api_key}",
        }
        self.proxy_url = mj_proxy_api_base
        self.command_splitter = "::"

        if mj_temp_folder:
            # NOTE(review): the configured value is only used as an on/off flag;
            # files always go under ./tmp[/<user_name>] - confirm this is intended.
            temp = "./tmp"
            if user_name:
                temp = os.path.join(temp, user_name)
            os.makedirs(temp, exist_ok=True)
            self.temp_path = tempfile.mkdtemp(dir=temp)
            logging.info("mj temp folder: " + self.temp_path)
        else:
            self.temp_path = None

    def use_mj_self_proxy_url(self, img_url):
        """
        replace discord cdn url with mj self proxy url
        """
        return img_url.replace(
            "https://cdn.discordapp.com/",
            mj_discord_proxy_url or "https://cdn.discordapp.com/",
        )

    def split_image(self, image_url):
        """
        when enabling temp dir, split image into 4 parts

        Downloads ``image_url`` and returns its four quadrants in the order
        top-left, top-right, bottom-left, bottom-right.
        Raises on download failure or if the payload is not a readable image.
        """
        with retrieve_proxy():
            response = requests.get(image_url, timeout=self._REQUEST_TIMEOUT)
        # Fail early on an HTTP error page instead of trying to decode it.
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content))
        width, height = img.size
        half_width = width // 2
        half_height = height // 2
        # (top-left x, top-left y, bottom-right x, bottom-right y)
        coordinates = [
            (0, 0, half_width, half_height),
            (half_width, 0, width, half_height),
            (0, half_height, half_width, height),
            (half_width, half_height, width, height),
        ]
        return [img.crop(box) for box in coordinates]

    def auth_mj(self):
        """
        auth midjourney api
        """
        # TODO: check if secret is valid
        return {'status': 'ok'}

    def request_mj(self, path: str, action: str, data: str, retries=3):
        """
        request midjourney api

        ``action`` is the HTTP method and ``data`` the already-serialised body.
        The call is attempted up to ``retries`` times when it raises.

        Returns the response object when its status code is 200.
        Raises Exception if the proxy URL is not configured, if every attempt
        raised, or if the response status code is not 200.
        """
        mj_proxy_url = self.proxy_url
        if mj_proxy_url is None or not mj_proxy_url.startswith(("http://", "https://")):
            raise Exception('please set MIDJOURNEY_PROXY_API_BASE in ENV or in config.json')

        auth_ = self.auth_mj()
        if auth_.get('error'):
            raise Exception('auth not set')

        fetch_url = f"{mj_proxy_url}/{path}"
        res = None
        last_error = None
        for attempt in range(1, retries + 1):
            try:
                with retrieve_proxy():
                    res = requests.request(
                        method=action,
                        url=fetch_url,
                        headers=self.headers,
                        data=data,
                        timeout=self._REQUEST_TIMEOUT,
                    )
                break
            except Exception as e:
                # Best-effort retry on any failure, as before, but keep the error
                # so it can be reported if no attempt succeeds.
                last_error = e
                logging.warning("[MJ Proxy] %s %s attempt %d/%d failed: %s",
                                action, fetch_url, attempt, retries, e)

        if res is None:
            # Every attempt raised (or retries < 1); report that explicitly
            # rather than failing on an unbound response variable.
            raise Exception(
                f'request to {fetch_url} failed after {retries} attempts: {last_error}'
            ) from last_error
        if res.status_code != 200:
            raise Exception(f'{res.status_code} - {res.content}')
        return res

    def fetch_status(self, fetch_data: FetchDataPack, poll_interval=3):
        """
        fetch status of current task

        Sleeps ``poll_interval`` seconds, queries the task once and returns a
        markdown status string. Sets ``fetch_data.finished`` when the task
        succeeded, failed or exceeded ``fetch_data.timeout``.
        """
        if fetch_data.start_time + fetch_data.timeout < time.time():
            fetch_data.finished = True
            return "任务超时,请检查 dc 输出。描述:" + (fetch_data.prompt or "")

        time.sleep(poll_interval)
        status_res = self.request_mj(f"task/{fetch_data.task_id}/fetch", "GET", '')
        status_res_json = status_res.json()
        if not (200 <= status_res.status_code < 300):
            # Parenthesised: "+" binds tighter than "or", so without the
            # parentheses a missing 'error' field raised TypeError.
            raise Exception("任务状态获取失败:" + (
                status_res_json.get('error') or status_res_json.get('description') or '未知错误'))

        fetch_data.finished = False
        status = status_res_json.get('status')

        if status in ("FAILED", "FAILURE"):
            # Report the failure reason directly; a failed task has no result
            # image to format.
            fetch_data.finished = True
            return "任务处理失败,原因:" + (status_res_json.get('failReason') or '未知原因')

        if status == "SUCCESS":
            fetch_data.finished = True
            return self._format_finished(fetch_data, status_res_json)

        if status == "NOT_START":
            content = f'任务未开始,已等待 {time.time() - fetch_data.start_time:.2f} 秒'
        elif status == "IN_PROGRESS":
            content = '任务正在运行'
            if status_res_json.get('progress'):
                content += f",进度:{status_res_json['progress']}"
        elif status == "SUBMITTED":
            content = '任务已提交处理'
        else:
            content = status

        content = f"**任务状态:** [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] - {content}"
        content += f"\n\n花费时间:{time.time() - fetch_data.start_time:.2f} 秒"
        if status == 'IN_PROGRESS' and status_res_json.get('imageUrl'):
            img_url = status_res_json.get('imageUrl')
            # Preview of the partially rendered image, linked to itself.
            return f"{content}\n[![{fetch_data.task_id}]({img_url})]({img_url})"
        return content

    def _format_finished(self, fetch_data, status_res_json):
        """Build the final markdown answer for a successfully finished task."""
        if fetch_data.action == "DESCRIBE":
            return f"\n{status_res_json.get('prompt') or ''}"

        task_id = fetch_data.task_id
        img_url = self.use_mj_self_proxy_url(status_res_json.get('imageUrl') or '')
        time_cost_str = f"\n\n{fetch_data.action} 花费时间:{time.time() - fetch_data.start_time:.2f} 秒"
        upscale_str = ""
        variation_str = ""
        if fetch_data.action in ["IMAGINE", "UPSCALE", "VARIATION"]:
            # Ready-to-copy follow-up commands for each of the four images.
            upscale = [f'/mj UPSCALE{self.command_splitter}{i+1}{self.command_splitter}{task_id}'
                       for i in range(4)]
            upscale_str = '\n放大图片:\n\n' + '\n\n'.join(upscale)
            variation = [f'/mj VARIATION{self.command_splitter}{i+1}{self.command_splitter}{task_id}'
                         for i in range(4)]
            variation_str = '\n图片变体:\n\n' + '\n\n'.join(variation)

        if self.temp_path and fetch_data.action in ["IMAGINE", "VARIATION"]:
            try:
                images = self.split_image(img_url)
                # save images to temp path
                for i, image in enumerate(images):
                    image.save(pathlib.Path(self.temp_path) / f"{task_id}_{i}.png")
                # NOTE(review): the "/file=" prefix assumes the web UI serves
                # local files under that route - confirm against the front end.
                img_str = '\n'.join(
                    f"![{task_id}](/file={self.temp_path}/{task_id}_{i}.png)"
                    for i in range(len(images)))
                return fetch_data.prefix_content + f"{time_cost_str}\n\n{img_str}{upscale_str}{variation_str}"
            except Exception as e:
                # Splitting is optional: fall back to linking the full image.
                logging.error(e)

        return fetch_data.prefix_content + \
            f"{time_cost_str}[![{task_id}]({img_url})]({img_url}){upscale_str}{variation_str}"

    def handle_file_upload(self, files, chatbot, language):
        """
        handle file upload

        Passes each named file to ``try_read_image`` and, when an image path is
        available afterwards, appends it to ``chatbot``.
        Returns ``(None, chatbot, None)``.
        """
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
            if self.image_path is not None:
                chatbot = chatbot + [((self.image_path,), None)]
            if self.image_bytes is not None:
                logging.info("使用图片作为输入")
        return None, chatbot, None

    def reset(self, remain_system_prompt=False):
        """Forget the uploaded image and delegate the rest to the parent reset."""
        self.image_bytes = None
        self.image_path = None
        # NOTE(review): remain_system_prompt is not forwarded - confirm the
        # parent signature before passing it on.
        return super().reset()

    def _parse_command(self, prompt):
        """Parse the text after ``/mj`` into (action, prompt, index, task_id).

        Without a leading ``ACTION::`` the whole text is an IMAGINE prompt.
        For IMAGINE the returned prompt excludes the ``IMAGINE::`` prefix; for
        other actions it is the full command text. ``index`` and ``task_id``
        are None unless the action is UPSCALE, VARIATION or REROLL.

        Raises ValueError for an unknown action or a malformed change command.
        """
        splitter = self.command_splitter
        first_split_index = prompt.find(splitter)
        if first_split_index <= 0:
            return "IMAGINE", prompt, None, None

        action = prompt[:first_split_index].strip().upper()
        if action not in self._SUPPORTED_ACTIONS:
            raise ValueError("任务提交失败:未知的任务类型")

        rest = prompt[first_split_index + len(splitter):]
        if action == "IMAGINE":
            # Send only the description, not the "IMAGINE::" command prefix.
            return action, rest.strip(), None, None
        if action not in self._CHANGE_ACTIONS:
            return action, prompt, None, None

        # Split on the separator instead of fixed character offsets so
        # multi-digit indexes and surrounding spaces are handled.
        index_text, sep, task_id = rest.partition(splitter)
        index_text = index_text.strip()
        task_id = task_id.strip()
        if not sep or not index_text.isdigit() or not task_id:
            raise ValueError(
                f"任务提交失败:命令格式应为 {action}{splitter}序号{splitter}任务ID")
        return action, prompt, int(index_text), task_id

    def _submit_task(self, action, prompt, action_index, action_use_task_id):
        """Submit one task to the proxy and return the raw HTTP response."""
        if action == "IMAGINE":
            data = {"prompt": prompt}
            if self.image_bytes is not None:
                data["base64"] = 'data:image/png;base64,' + self.image_bytes
            return self.request_mj("submit/imagine", "POST", json.dumps(data))

        if action in ("DESCRIBE", "BLEND"):
            if self.image_bytes is None:
                # Clear message instead of a TypeError from concatenating None.
                raise Exception("请先上传图片")
            if action == "DESCRIBE":
                return self.request_mj(
                    "submit/describe", "POST",
                    json.dumps({"base64": 'data:image/png;base64,' + self.image_bytes}))
            return self.request_mj(
                "submit/blend", "POST",
                json.dumps({"base64Array": [self.image_bytes, self.image_bytes]}))

        # UPSCALE / VARIATION / REROLL on an existing task.
        return self.request_mj(
            "submit/change", "POST",
            json.dumps({"action": action, "index": action_index, "taskId": action_use_task_id}))

    def get_answer_at_once(self):
        """Run the ``/mj`` command in the last history entry to completion.

        Returns ``(answer, count)``. For a non-command message the answer is
        the help text and the count is ``len(content)``; otherwise the count
        is the number of cl100k_base tokens in the message.

        Raises ValueError (an Exception) when the command cannot be parsed.
        """
        content = self.history[-1]['content']
        answer = self.get_help()
        if not content.lower().startswith("/mj"):
            return answer, len(content)

        action, prompt, action_index, action_use_task_id = self._parse_command(content[3:].strip())
        try:
            res = self._submit_task(action, prompt, action_index, action_use_task_id)
            res_json = res.json()
            if not (200 <= res.status_code < 300) or (res_json.get('code') not in [1, 22]):
                answer = "任务提交失败:" + (
                    res_json.get('error') or res_json.get('description') or '未知错误')
            else:
                task_id = res_json['result']
                prefix_content = f"**画面描述:** {prompt}\n**任务ID:** {task_id}\n"
                fetch_data = Midjourney_Client.FetchDataPack(
                    action=action,
                    prefix_content=prefix_content,
                    task_id=task_id,
                    prompt=prompt,
                )
                while not fetch_data.finished:
                    answer = self.fetch_status(fetch_data)
        except Exception as e:
            logging.error("submit failed: %s", e)
            answer = "任务提交错误:" + (str(e.args[0]) if e.args else '未知错误')

        # Return the token count, not the token list.
        return answer, len(tiktoken.get_encoding("cl100k_base").encode(content))

    def get_answer_stream_iter(self):
        """Streaming variant: yield a status string after submit and each poll.

        A non-command message yields the help text; a command that cannot be
        parsed yields the parse error message. Errors raised while submitting
        or polling are yielded as a final message instead of propagating.
        """
        content = self.history[-1]['content']
        if not content.lower().startswith("/mj"):
            yield self.get_help()
            return

        try:
            action, prompt, action_index, action_use_task_id = self._parse_command(content[3:].strip())
        except ValueError as e:
            yield str(e)
            return

        try:
            res = self._submit_task(action, prompt, action_index, action_use_task_id)
            res_json = res.json()
            if not (200 <= res.status_code < 300) or (res_json.get('code') not in [1, 22]):
                yield "任务提交失败:" + (
                    res_json.get('error') or res_json.get('description') or '未知错误')
            else:
                task_id = res_json['result']
                prefix_content = f"**画面描述:** {prompt}\n**任务ID:** {task_id}\n"
                yield f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] - 任务提交成功:" + \
                    (res_json.get('description') or '请稍等片刻')
                fetch_data = Midjourney_Client.FetchDataPack(
                    action=action,
                    prefix_content=prefix_content,
                    task_id=task_id,
                    prompt=prompt,
                )
                while not fetch_data.finished:
                    yield self.fetch_status(fetch_data)
        except Exception as e:
            logging.error("submit failed: %s", e)
            yield "任务提交错误:" + (str(e.args[0]) if e.args else '未知错误')

    def get_help(self):
        """Return the usage text shown for messages that are not ``/mj`` commands."""
        return """```
【绘图帮助】
所有命令都需要以 /mj 开头,如:/mj a dog
IMAGINE - 绘图,可以省略该命令,后面跟上绘图内容
/mj a dog
/mj IMAGINE::a cat
DESCRIBE - 描述图片,需要在右下角上传需要描述的图片内容
/mj DESCRIBE::
UPSCALE - 确认后放大图片,第一个数值为需要放大的图片(1~4),第二参数为任务ID
/mj UPSCALE::1::123456789
请使用SD进行UPSCALE
VARIATION - 图片变体,第一个数值为需要放大的图片(1~4),第二参数为任务ID
/mj VARIATION::1::123456789
【绘图参数】
所有命令默认会带上参数--v 5.2
其他参数参照 https://docs.midjourney.com/docs/parameter-list
长宽比 --aspect/--ar
--ar 1:2
--ar 16:9
负面tag --no
--no plants
--no hands
随机种子 --seed
--seed 1
生成动漫风格(NijiJourney) --niji
--niji
```
"""