import logging
import os
import random
import threading
import json
import dashscope
from typing import Optional, Literal, Any
import time
import backoff
import dspy
import litellm
import requests
from dashscope import Generation

try:
    from anthropic import RateLimitError
except ImportError:
    RateLimitError = None
MAX_API_RETRY = 3
LLM_MIT_RETRY_SLEEP = 5
SUPPORT_ARGS = {"model", "messages", "frequency_penalty", "logit_bias", "logprobs", "top_logprobs", "max_tokens",
                "n", "presence_penalty", "response_format", "seed", "stop", "stream", "temperature", "top_p",
                "tools", "tool_choice", "user", "function_call", "functions", "tenant", "max_completion_tokens"}
def truncate_long_strings(d):
    """Recursively truncate every string longer than 100 characters in a nested dict/list structure."""
    if isinstance(d, dict):
        return {k: truncate_long_strings(v) for k, v in d.items()}
    elif isinstance(d, list):
        return [truncate_long_strings(item) for item in d]
    elif isinstance(d, str) and len(d) > 100:
        return d[:100] + '...'
    else:
        return d
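# Illustrative usage (not in the original): truncate_long_strings keeps debug logs of
# large request/response payloads readable, e.g.
#   logging.debug(json.dumps(truncate_long_strings(payload), ensure_ascii=False))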
class QwenModel(dspy.OpenAI):
    """A wrapper around dspy.OpenAI that routes requests to Qwen models via the DashScope SDK."""

    def __init__(
        self,
        model: str = "qwen-max-allinone",
        api_key: Optional[str] = None,
        **kwargs
    ):
        super().__init__(model=model, api_key=api_key, **kwargs)
        self.model = model
        self.api_key = api_key
        # The DashScope SDK reads its credential from the module-level `dashscope.api_key`
        # (or the DASHSCOPE_API_KEY environment variable), so propagate the key if given.
        if api_key:
            dashscope.api_key = api_key
        self._token_usage_lock = threading.Lock()
        self.prompt_tokens = 0
        self.completion_tokens = 0
    def log_usage(self, response):
        """Log the total tokens from the DashScope API response."""
        usage_data = response.get('usage')
        if usage_data:
            with self._token_usage_lock:
                self.prompt_tokens += usage_data.get('input_tokens', 0)
                self.completion_tokens += usage_data.get('output_tokens', 0)
    def get_usage_and_reset(self):
        """Get the total tokens used and reset the token usage."""
        # Read and reset under the lock so a concurrent log_usage() call cannot
        # slip between the snapshot and the reset.
        with self._token_usage_lock:
            usage = {
                self.kwargs.get('model') or self.kwargs.get('engine'):
                    {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
            }
            self.prompt_tokens = 0
            self.completion_tokens = 0
        return usage
    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[str]:
        """Copied from dspy/dsp/modules/gpt3.py with the addition of tracking token usage."""
        assert only_completed, "for now"
        assert return_sorted is False, "for now"
        messages = [{'role': 'user', 'content': prompt}]
        max_retries = 3
        attempt = 0
        while attempt < max_retries:
            try:
                response = Generation.call(
                    model=self.model,
                    messages=messages,
                    result_format='message',
                )
                choices = response["output"]["choices"]
                break
            except Exception as e:
                print(f"Request failed: {e}. Retrying...")
                delay = random.uniform(0, 3)
                print(f"Waiting {delay:.2f} seconds before retrying...")
                time.sleep(delay)
                attempt += 1
        else:
            # All retries failed: `response` and `choices` were never assigned, so fail
            # loudly here instead of raising a confusing NameError below.
            raise RuntimeError(f"DashScope Generation.call failed after {max_retries} attempts.")
        self.log_usage(response)
        # Drop choices truncated by the token limit, unless that would leave nothing.
        completed_choices = [c for c in choices if c["finish_reason"] != "length"]
        if only_completed and len(completed_choices):
            choices = completed_choices
        completions = [c['message']['content'] for c in choices]
        return completions
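
# Minimal usage sketch (illustrative, not part of the original module). Assumes a valid
# DashScope key in the DASHSCOPE_API_KEY environment variable and network access.
if __name__ == "__main__":
    lm = QwenModel(model="qwen-max-allinone", api_key=os.environ.get("DASHSCOPE_API_KEY"))
    outputs = lm("Say hello in one short sentence.")
    print(outputs[0])
    # Per-model token accounting accumulated by log_usage(); this also resets the counters.
    print(lm.get_usage_and_reset())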