import os
from typing import TYPE_CHECKING, Literal, Optional

from langchain_core.language_models.chat_models import BaseChatModel

if TYPE_CHECKING:
    from histopath.config import HistoPathConfig
SourceType = Literal["OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", "Groq", "HuggingFace", "Custom"]
ALLOWED_SOURCES: set[str] = set(SourceType.__args__)
def get_llm(
model: str | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
source: SourceType | None = None,
base_url: str | None = None,
api_key: str | None = None,
config: Optional["HistoPathConfig"] = None,
) -> BaseChatModel:
"""
Get a language model instance based on the specified model name and source.
This function supports models from OpenAI, Azure OpenAI, Anthropic, Ollama, Gemini, Bedrock, and custom model serving.
Args:
model (str): The model name to use
temperature (float): Temperature setting for generation
stop_sequences (list): Sequences that will stop generation
source (str): Source provider: "OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", or "Custom"
If None, will attempt to auto-detect from model name
base_url (str): The base URL for custom model serving (e.g., "http://localhost:8000/v1"), default is None
api_key (str): The API key for the custom llm
config (BiomniConfig): Optional configuration object. If provided, unspecified parameters will use config values
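
    Example (illustrative; each call assumes the matching provider credentials
    are available in the environment):
        >>> llm = get_llm("claude-3-5-sonnet-20241022", temperature=0.2)
        >>> llm = get_llm("my-model", source="Custom",
        ...               base_url="http://localhost:8000/v1", api_key="EMPTY")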
"""
# Use config values for any unspecified parameters
if config is not None:
if model is None:
model = config.llm_model
if temperature is None:
temperature = config.temperature
if source is None:
source = config.source
if base_url is None:
base_url = config.base_url
if api_key is None:
api_key = config.api_key or "EMPTY"
# Use defaults if still not specified
if model is None:
model = "claude-3-5-sonnet-20241022"
if temperature is None:
temperature = 0.7
if api_key is None:
api_key = "EMPTY"
# Auto-detect source from model name if not specified
if source is None:
env_source = os.getenv("LLM_SOURCE")
if env_source in ALLOWED_SOURCES:
source = env_source
else:
            if model.startswith("claude-"):
                source = "Anthropic"
            elif model.startswith("gpt-oss"):
                source = "Ollama"
            elif model.startswith("gpt-"):
                source = "OpenAI"
            elif model.startswith("azure-"):
                source = "AzureOpenAI"
            elif model.startswith("gemini-"):
                source = "Gemini"
elif "groq" in model.lower():
source = "Groq"
elif base_url is not None:
source = "Custom"
elif "/" in model or any(
name in model.lower()
for name in [
"llama",
"mistral",
"qwen",
"gemma",
"phi",
"dolphin",
"orca",
"vicuna",
"deepseek",
]
):
source = "Ollama"
            elif model.startswith(
                ("anthropic.claude-", "amazon.titan-", "meta.llama", "mistral.", "cohere.", "ai21.", "us.")
            ):
                source = "Bedrock"
else:
raise ValueError("Unable to determine model source. Please specify 'source' parameter.")
# Create appropriate model based on source
if source == "OpenAI":
try:
from langchain_openai import ChatOpenAI
except ImportError:
raise ImportError( # noqa: B904
"langchain-openai package is required for OpenAI models. Install with: pip install langchain-openai"
)
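        # ChatOpenAI picks up OPENAI_API_KEY from the environment by default.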
return ChatOpenAI(model=model, temperature=temperature, stop_sequences=stop_sequences)
elif source == "AzureOpenAI":
try:
from langchain_openai import AzureChatOpenAI
except ImportError:
raise ImportError( # noqa: B904
"langchain-openai package is required for Azure OpenAI models. Install with: pip install langchain-openai"
)
API_VERSION = "2024-12-01-preview"
model = model.replace("azure-", "")
return AzureChatOpenAI(
openai_api_key=os.getenv("OPENAI_API_KEY"),
azure_endpoint=os.getenv("OPENAI_ENDPOINT"),
azure_deployment=model,
openai_api_version=API_VERSION,
temperature=temperature,
)
elif source == "Anthropic":
try:
from langchain_anthropic import ChatAnthropic
except ImportError:
raise ImportError( # noqa: B904
"langchain-anthropic package is required for Anthropic models. Install with: pip install langchain-anthropic"
)
return ChatAnthropic(
model=model,
temperature=temperature,
max_tokens=8192,
stop_sequences=stop_sequences,
)
elif source == "Gemini":
# If you want to use ChatGoogleGenerativeAI, you need to pass the stop sequences upon invoking the model.
# return ChatGoogleGenerativeAI(
# model=model,
# temperature=temperature,
# google_api_key=api_key,
# )
try:
from langchain_openai import ChatOpenAI
except ImportError:
raise ImportError( # noqa: B904
"langchain-openai package is required for Gemini models. Install with: pip install langchain-openai"
)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=os.getenv("GEMINI_API_KEY"),
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
stop_sequences=stop_sequences,
)
elif source == "Groq":
try:
from langchain_openai import ChatOpenAI
except ImportError:
raise ImportError( # noqa: B904
"langchain-openai package is required for Groq models. Install with: pip install langchain-openai"
)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=os.getenv("GROQ_API_KEY"),
base_url="https://api.groq.com/openai/v1",
stop_sequences=stop_sequences,
)
elif source == "Ollama":
try:
from langchain_ollama import ChatOllama
except ImportError:
raise ImportError( # noqa: B904
"langchain-ollama package is required for Ollama models. Install with: pip install langchain-ollama"
)
return ChatOllama(
model=model,
temperature=temperature,
)
elif source == "Bedrock":
try:
from langchain_aws import ChatBedrock
except ImportError:
raise ImportError( # noqa: B904
"langchain-aws package is required for Bedrock models. Install with: pip install langchain-aws"
)
return ChatBedrock(
model=model,
temperature=temperature,
stop_sequences=stop_sequences,
region_name=os.getenv("AWS_REGION", "us-east-1"),
)
elif source == "HuggingFace":
try:
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
except ImportError:
raise ImportError(
"langchain-huggingface package is required for HuggingFace models. Install with: pip install langchain-huggingface"
)
llm = HuggingFaceEndpoint(
repo_id="openai/gpt-oss-120b",
temperature=temperature,
stop_sequences=stop_sequences,
huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY")
)
return ChatHuggingFace(llm=llm)
elif source == "Custom":
try:
from langchain_openai import ChatOpenAI
except ImportError:
raise ImportError( # noqa: B904
"langchain-openai package is required for custom models. Install with: pip install langchain-openai"
)
        # Custom LLM serving such as SGLang; must expose an OpenAI-compatible API.
        if base_url is None:
            raise ValueError("base_url must be provided for custom-served LLMs")
llm = ChatOpenAI(
model=model,
temperature=temperature,
max_tokens=8192,
stop_sequences=stop_sequences,
base_url=base_url,
api_key=api_key,
)
return llm
    else:
        raise ValueError(
            f"Invalid source: {source}. Valid options are: {', '.join(sorted(ALLOWED_SOURCES))}"
        )
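

# A minimal smoke-test sketch, not part of the module API: the model name below
# is illustrative, and each branch assumes the matching provider credentials are
# set in the environment (e.g. an Anthropic key for the default Claude model).
if __name__ == "__main__":
    llm = get_llm(model="claude-3-5-sonnet-20241022", temperature=0.7)
    print(type(llm).__name__)  # e.g. "ChatAnthropic"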