Spaces:
Sleeping
Sleeping
| from ..presets import * | |
| from ..utils import * | |
| from .base_model import BaseLLMModel | |
class ERNIE_Client(BaseLLMModel):
    """Chat client for Baidu's ERNIE (Wenxin Yiyan) models.

    A fresh OAuth access token is requested before every chat call and
    appended to the model-specific chat endpoint.
    """

    # OAuth endpoint that exchanges the API key / secret key for an access token.
    _TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"

    # Chat endpoint per supported model name. Each URL deliberately ends in
    # "access_token=" so the token can be appended directly.
    _CHAT_ENDPOINTS = {
        "ERNIE-Bot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=",
        "ERNIE-Bot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=",
        "ERNIE-Bot-4": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=",
    }

    def __init__(self, model_name, api_key, secret_key) -> None:
        """Store the credentials and select the chat endpoint for the model.

        :param model_name: one of the keys of ``_CHAT_ENDPOINTS``
        :param api_key: Baidu API key (sent as OAuth ``client_id``)
        :param secret_key: Baidu secret key (sent as OAuth ``client_secret``)
        :raises Exception: if either credential is ``None``
        """
        super().__init__(model_name=model_name)
        self.api_key = api_key
        self.api_secret = secret_key
        if self.api_key is None or self.api_secret is None:
            raise Exception("请在配置文件或者环境变量中设置文心一言的API Key 和 Secret Key")
        # None for an unknown model name; reported with a clear error when a
        # request is built instead of an AttributeError on a missing attribute.
        self.ERNIE_url = self._CHAT_ENDPOINTS.get(self.model_name)

    def get_access_token(self):
        """Generate the auth signature (access token) from the AK / SK pair.

        :return: the access token, or ``None`` if the response carries none
        """
        # Passing the credentials through ``params`` lets requests URL-encode
        # them instead of splicing raw strings into the query.
        params = {
            "client_id": self.api_key,
            "client_secret": self.api_secret,
            "grant_type": "client_credentials",
        }
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        response = requests.request(
            "POST", self._TOKEN_URL, params=params, headers=headers, data=json.dumps("")
        )
        # .get() honours the documented "None on error" contract; indexing
        # would raise KeyError when the token is absent.
        return response.json().get("access_token")

    def _build_chat_request(self, stream):
        """Assemble the pieces of a chat request shared by both answer methods.

        :param stream: value of the ``stream`` flag in the JSON payload
        :return: ``(url, headers, payload)``, or ``None`` if no access token
            could be obtained
        :raises ValueError: if the model name has no known chat endpoint
        """
        if not self.ERNIE_url:
            raise ValueError(f"Unsupported ERNIE model: {self.model_name!r}")
        access_token = self.get_access_token()
        if not access_token:
            return None

        history = self.history
        if self.system_prompt is not None:
            history = [construct_system(self.system_prompt), *history]
        # Drop every message whose role is "system".
        # NOTE(review): presumably construct_system() yields a role == "system"
        # message, in which case the system prompt added above is removed again
        # here and never reaches the API - confirm this is intended.
        history = [message for message in history if message["role"] != "system"]

        payload = json.dumps({
            "messages": history,
            "stream": stream
        })
        headers = {
            'Content-Type': 'application/json'
        }
        return self.ERNIE_url + access_token, headers, payload

    def get_answer_stream_iter(self):
        """Stream the model's answer.

        :return: generator yielding the accumulated answer text after each
            received chunk, or a single error message on failure
        """
        request = self._build_chat_request(stream=True)
        if request is None:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
            return
        url, headers, payload = request

        # The context manager releases the connection even when the consumer
        # stops iterating early.
        with requests.request("POST", url, headers=headers, data=payload, stream=True) as response:
            if response.status_code != 200:
                yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
                return

            partial_text = ""
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                if raw_line.startswith(b"data:"):
                    chunk = json.loads(raw_line[5:])
                else:
                    # A line without the "data:" prefix is either a plain JSON
                    # body or some other non-JSON line; only the former is
                    # meaningful here.
                    try:
                        chunk = json.loads(raw_line)
                    except ValueError:
                        continue
                if not isinstance(chunk, dict) or "result" not in chunk:
                    # A payload without "result" cannot extend the answer.
                    yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
                    return
                partial_text += chunk['result']
                yield partial_text

    def get_answer_at_once(self):
        """Fetch the model's complete answer in one non-streaming request.

        :return: ``(answer, length)`` where ``length`` is ``len()`` of the
            returned result, or ``("获取资源错误", 0)`` on failure
        """
        # stream=False: the body must be a single JSON document for
        # response.json() to parse it.
        request = self._build_chat_request(stream=False)
        if request is None:
            return "获取资源错误", 0
        url, headers, payload = request

        response = requests.request("POST", url, headers=headers, data=payload)
        if response.status_code != 200:
            return "获取资源错误", 0
        try:
            body = response.json()
        except ValueError:
            return "获取资源错误", 0
        if not isinstance(body, dict) or "result" not in body:
            return "获取资源错误", 0
        result = body["result"]
        return str(result), len(result)