import os
from typing import TYPE_CHECKING, Literal, Optional

from langchain_core.language_models.chat_models import BaseChatModel

if TYPE_CHECKING:
    from histopath.config import HistoPathConfig

SourceType = Literal[
    "OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", "Groq", "HuggingFace", "Custom"
]
ALLOWED_SOURCES: set[str] = set(SourceType.__args__)


def get_llm(
    model: str | None = None,
    temperature: float | None = None,
    stop_sequences: list[str] | None = None,
    source: SourceType | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    config: Optional["HistoPathConfig"] = None,
) -> BaseChatModel:
""" |
|
|
Get a language model instance based on the specified model name and source. |
|
|
This function supports models from OpenAI, Azure OpenAI, Anthropic, Ollama, Gemini, Bedrock, and custom model serving. |
|
|
Args: |
|
|
model (str): The model name to use |
|
|
temperature (float): Temperature setting for generation |
|
|
stop_sequences (list): Sequences that will stop generation |
|
|
source (str): Source provider: "OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", or "Custom" |
|
|
If None, will attempt to auto-detect from model name |
|
|
base_url (str): The base URL for custom model serving (e.g., "http://localhost:8000/v1"), default is None |
|
|
api_key (str): The API key for the custom llm |
|
|
config (BiomniConfig): Optional configuration object. If provided, unspecified parameters will use config values |
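
    Example (illustrative model names; the matching API key must be set in the
    environment):
        >>> llm = get_llm(model="gpt-4o", source="OpenAI")
        >>> reply = llm.invoke("Hello")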
""" |
    if config is not None:
        if model is None:
            model = config.llm_model
        if temperature is None:
            temperature = config.temperature
        if source is None:
            source = config.source
        if base_url is None:
            base_url = config.base_url
        if api_key is None:
            api_key = config.api_key or "EMPTY"
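
    # Apply built-in defaults for anything still unset.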
    if model is None:
        model = "claude-3-5-sonnet-20241022"
    if temperature is None:
        temperature = 0.7
    if api_key is None:
        api_key = "EMPTY"
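
    # Resolve the provider: the LLM_SOURCE environment variable wins, then
    # heuristics on the model name.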
    if source is None:
        env_source = os.getenv("LLM_SOURCE")
        if env_source in ALLOWED_SOURCES:
            source = env_source
        else:
            if model.startswith("claude-"):
                source = "Anthropic"
            elif model.startswith("gpt-oss"):
                source = "Ollama"
            elif model.startswith("gpt-"):
                source = "OpenAI"
            elif model.startswith("azure-"):
                source = "AzureOpenAI"
            elif model.startswith("gemini-"):
                source = "Gemini"
            elif "groq" in model.lower():
                source = "Groq"
            elif base_url is not None:
                source = "Custom"
            # Registry-style names ("org/model") and common open-model families
            # default to Ollama.
            elif "/" in model or any(
                name in model.lower()
                for name in [
                    "llama",
                    "mistral",
                    "qwen",
                    "gemma",
                    "phi",
                    "dolphin",
                    "orca",
                    "vicuna",
                    "deepseek",
                ]
            ):
                source = "Ollama"
            elif model.startswith(
                ("anthropic.claude-", "amazon.titan-", "meta.llama", "mistral.", "cohere.", "ai21.", "us.")
            ):
                source = "Bedrock"
            else:
                raise ValueError("Unable to determine model source. Please specify the 'source' parameter.")
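
    # Instantiate the provider-specific chat model. Integrations are imported
    # lazily so only the packages for the chosen provider must be installed.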
    if source == "OpenAI":
        try:
            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "langchain-openai package is required for OpenAI models. Install with: pip install langchain-openai"
            )
        return ChatOpenAI(model=model, temperature=temperature, stop_sequences=stop_sequences)
    elif source == "AzureOpenAI":
        try:
            from langchain_openai import AzureChatOpenAI
        except ImportError:
            raise ImportError(
                "langchain-openai package is required for Azure OpenAI models. Install with: pip install langchain-openai"
            )
        API_VERSION = "2024-12-01-preview"
        # Strip the "azure-" routing prefix; the remainder is the deployment name.
        model = model.replace("azure-", "")
        return AzureChatOpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            azure_endpoint=os.getenv("OPENAI_ENDPOINT"),
            azure_deployment=model,
            openai_api_version=API_VERSION,
            temperature=temperature,
        )
    elif source == "Anthropic":
        try:
            from langchain_anthropic import ChatAnthropic
        except ImportError:
            raise ImportError(
                "langchain-anthropic package is required for Anthropic models. Install with: pip install langchain-anthropic"
            )
        return ChatAnthropic(
            model=model,
            temperature=temperature,
            max_tokens=8192,
            stop_sequences=stop_sequences,
        )
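
    # Gemini is reached through Google's OpenAI-compatible endpoint, so it
    # reuses ChatOpenAI rather than a dedicated integration.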
    elif source == "Gemini":
        try:
            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "langchain-openai package is required for Gemini models. Install with: pip install langchain-openai"
            )
        return ChatOpenAI(
            model=model,
            temperature=temperature,
            api_key=os.getenv("GEMINI_API_KEY"),
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
            stop_sequences=stop_sequences,
        )
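
    # Groq likewise exposes an OpenAI-compatible API.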
    elif source == "Groq":
        try:
            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "langchain-openai package is required for Groq models. Install with: pip install langchain-openai"
            )
        return ChatOpenAI(
            model=model,
            temperature=temperature,
            api_key=os.getenv("GROQ_API_KEY"),
            base_url="https://api.groq.com/openai/v1",
            stop_sequences=stop_sequences,
        )
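
    # Ollama talks to a locally running daemon, so no API key is required.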
    elif source == "Ollama":
        try:
            from langchain_ollama import ChatOllama
        except ImportError:
            raise ImportError(
                "langchain-ollama package is required for Ollama models. Install with: pip install langchain-ollama"
            )
        return ChatOllama(
            model=model,
            temperature=temperature,
        )
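
    # Bedrock uses the standard AWS credential chain; the region falls back to us-east-1.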
    elif source == "Bedrock":
        try:
            from langchain_aws import ChatBedrock
        except ImportError:
            raise ImportError(
                "langchain-aws package is required for Bedrock models. Install with: pip install langchain-aws"
            )
        return ChatBedrock(
            model=model,
            temperature=temperature,
            stop_sequences=stop_sequences,
            region_name=os.getenv("AWS_REGION", "us-east-1"),
        )
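
    # Hugging Face models are served through HuggingFaceEndpoint and wrapped in a
    # chat interface.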
    elif source == "HuggingFace":
        try:
            from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
        except ImportError:
            raise ImportError(
                "langchain-huggingface package is required for HuggingFace models. Install with: pip install langchain-huggingface"
            )
        llm = HuggingFaceEndpoint(
            repo_id=model,
            temperature=temperature,
            stop_sequences=stop_sequences,
            huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY"),
        )
        return ChatHuggingFace(llm=llm)
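
    # "Custom" targets any OpenAI-compatible server reachable at base_url.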
    elif source == "Custom":
        try:
            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "langchain-openai package is required for custom models. Install with: pip install langchain-openai"
            )
        assert base_url is not None, "base_url must be provided for custom-served LLMs"
        llm = ChatOpenAI(
            model=model,
            temperature=temperature,
            max_tokens=8192,
            stop_sequences=stop_sequences,
            base_url=base_url,
            api_key=api_key,
        )
        return llm
    else:
        raise ValueError(
            f"Invalid source: {source}. Valid options are: {', '.join(sorted(ALLOWED_SOURCES))}"
        )
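

# Minimal usage sketch (illustrative, not part of the library API). Model names
# are examples; this assumes ANTHROPIC_API_KEY is set and an Ollama daemon is
# running locally. Swap in whatever provider you actually have access to.
if __name__ == "__main__":
    hosted = get_llm(model="claude-3-5-sonnet-20241022", temperature=0.2)
    print(hosted.invoke("Name one common histology stain.").content)

    local = get_llm(model="llama3.1", source="Ollama")
    print(local.invoke("Name one common histology stain.").content)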