|
|
import { |
|
|
Autocomplete, |
|
|
Box, |
|
|
CircularProgress, |
|
|
FormControl, |
|
|
FormHelperText, |
|
|
InputLabel, |
|
|
MenuItem, |
|
|
Select, |
|
|
Stack, |
|
|
TextField, |
|
|
Tooltip, |
|
|
Typography, |
|
|
alpha, |
|
|
useTheme, |
|
|
} from "@mui/material"; |
|
|
import React, { useCallback, useState } from "react"; |
|
|
import { useModelList } from "../hooks/useModelList"; |
|
|
import { PROVIDERS } from "../utils/constants"; |
|
|
import { TooltipIcon } from "./TooltipIcon"; |
|
|
|
|
|
/** CSS transition shared by the option cards and their collapsible sections. */
const TRANSITION = "all 0.3s ease-in-out";

/** A pre-configured provider/model pairing offered by the "suggestions" variant. */
interface SuggestedModel {
  /** Human-readable name shown as the first line of the menu item. */
  label: string;
  /** Provider identifier, shown alongside `model` in the menu item. */
  provider: string;
  /** Model identifier used together with `provider`. */
  model: string;
}

/**
 * Props for `ModelProviderSelector`.
 *
 * The core props drive the "form" and "toolbar" variants; the optional
 * suggestion/custom props are only read when `variant` is "suggestions".
 */
interface ModelProviderSelectorProps {
  // --- Core selection ---
  /** Currently selected model id. */
  model: string;
  /** Currently selected provider. */
  provider: string;
  /** Receives the new model id. */
  onModelChange: (value: string) => void;
  /** Receives the new provider. */
  onProviderChange: (value: string) => void;

  // --- Presentation ---
  /** Field size; the toolbar variant always renders "small" fields. */
  size?: "small" | "medium";
  /** Which layout to render. */
  variant?: "form" | "toolbar" | "suggestions";
  /** When true, empty fields are flagged as required. */
  showValidation?: boolean;
  /** Disables every input rendered by the selector. */
  isDisabled?: boolean;

  // --- "suggestions" variant only ---
  /** Pairings listed in the "Suggested models" card. */
  suggestedModels?: SuggestedModel[];
  /** Index into `suggestedModels`, as a string (the Select's value). */
  selectedSuggestion?: string;
  /** Receives the selected index string plus that entry's provider and model. */
  onSuggestionChange?: (value: string, provider: string, model: string) => void;
  /** Provider bound to the nested selector in the "Custom configuration" card. */
  customProvider?: string;
  /** Model bound to the nested selector in the "Custom configuration" card. */
  customModel?: string;
  /** Receives provider changes from the custom card. */
  onCustomProviderChange?: (provider: string) => void;
  /** Receives model changes from the custom card. */
  onCustomModelChange?: (model: string) => void;
  /** Which card is currently expanded. */
  selectedOption?: "suggested" | "custom";
  /** Called when the user clicks one of the two cards. */
  onOptionChange?: (option: "suggested" | "custom") => void;
}
|
|
|
|
|
interface ModelSelectProps {
  /** Currently selected (or typed) model id. */
  model: string;
  /** Selected provider; when empty the field is shown disabled. */
  provider: string;
  /** Model ids offered as options for `provider`. */
  availableModels: string[];
  /** True while the model list is being fetched. */
  loading: boolean;
  /** Fetch error; when set, a plain text field is shown instead of the picker. */
  error: string | null;
  /** Receives the new model id. */
  onModelChange: (value: string) => void;
  size?: "small" | "medium";
  variant?: "form" | "toolbar";
  /** When true, an empty `model` is flagged as required. */
  showValidation?: boolean;
  fullWidth?: boolean;
  isDisabled?: boolean;
}

/**
 * Model picker for a single provider.
 *
 * Renders one of four states:
 * - no provider: a disabled placeholder asking for a provider first;
 * - loading: a disabled placeholder with a spinner;
 * - error: a plain text field so a model id can still be typed by hand;
 * - otherwise: a searchable free-solo Autocomplete over `availableModels`,
 *   which also accepts ids that are not in the list.
 */
const ModelSelect: React.FC<ModelSelectProps> = ({
  model,
  provider,
  availableModels,
  loading,
  error,
  onModelChange,
  size = "medium",
  variant = "form",
  showValidation = false,
  fullWidth = false,
  isDisabled = false,
}) => {
  const isToolbar = variant === "toolbar";
  const toolbarSx = isToolbar ? { minWidth: 250, height: 40 } : undefined;
  const isModelMissing = showValidation && !model.trim();

  if (!provider) {
    return (
      <Autocomplete
        options={[]}
        value=""
        disabled
        fullWidth={fullWidth}
        size={size}
        renderInput={(params) => (
          <TextField
            {...params}
            label={isToolbar ? "Model" : null}
            placeholder="Select a provider first"
            helperText="Please select a provider first"
            sx={toolbarSx}
          />
        )}
        noOptionsText="Select a provider first"
      />
    );
  }

  if (loading) {
    return (
      <Autocomplete
        options={[]}
        value=""
        disabled
        loading={loading}
        fullWidth={fullWidth}
        size={size}
        renderInput={(params) => (
          <TextField
            {...params}
            label={isToolbar ? "Model" : null}
            placeholder="Loading models..."
            sx={toolbarSx}
            slotProps={{
              input: {
                ...params.InputProps,
                startAdornment: <CircularProgress size={16} sx={{ mx: 1 }} />,
              },
            }}
          />
        )}
        noOptionsText="Loading models..."
      />
    );
  }

  if (error) {
    // The model list could not be fetched: fall back to free text entry.
    return (
      <TextField
        fullWidth={fullWidth}
        size={size}
        label={isToolbar ? "Model" : null}
        value={model}
        onChange={(e) => onModelChange(e.target.value)}
        error={isModelMissing}
        helperText={error || (isModelMissing ? "Model ID is required" : "")}
        sx={toolbarSx}
        variant="outlined"
        disabled={isDisabled}
      />
    );
  }

  return (
    <Autocomplete
      options={availableModels}
      value={model || null}
      onChange={(_, newValue) => {
        onModelChange(newValue || "");
      }}
      // With freeSolo, onChange only fires on Enter or option selection, so
      // text typed and then left (e.g. by clicking elsewhere) would stay in the
      // input without ever reaching the parent. Commit typed text as it is
      // entered, matching the plain TextField used in the error state.
      onInputChange={(_, newInputValue, reason) => {
        if (reason === "input") {
          onModelChange(newInputValue);
        }
      }}
      disabled={isDisabled}
      fullWidth={fullWidth}
      size={size}
      freeSolo
      autoHighlight
      filterOptions={(options, { inputValue }) => {
        const needle = inputValue.toLowerCase();
        return options.filter((option) => option.toLowerCase().includes(needle));
      }}
      renderInput={(params) => (
        <TextField
          {...params}
          label={"Model"}
          placeholder={
            availableModels.length === 0
              ? "No models available"
              : "Search models..."
          }
          error={isModelMissing}
          // Only render helper text when there is something to say, so an
          // empty FormHelperText is not left in the DOM.
          helperText={
            isModelMissing ? (
              <Box
                component={"span"}
                sx={
                  isToolbar
                    ? {
                        // Presumably gives the message an opaque backdrop where
                        // it overflows the fixed-height toolbar — confirm.
                        bgcolor: (theme) => theme.palette.background.default,
                        p: 1,
                        ml: -1,
                        borderRadius: 1,
                      }
                    : {}
                }
              >
                Model is required
              </Box>
            ) : undefined
          }
          sx={toolbarSx}
        />
      )}
      noOptionsText={
        availableModels.length === 0
          ? "No models available"
          : "No matching models"
      }
    />
  );
};
|
|
|
|
|
interface ProviderSelectProps {
  /** Currently selected provider. */
  provider: string;
  /** Receives the newly selected provider. */
  onProviderChange: (value: string) => void;
  size?: "small" | "medium";
  variant?: "form" | "toolbar";
  /** When true, an empty `provider` is flagged as required. */
  showValidation?: boolean;
  fullWidth?: boolean;
  isDisabled?: boolean;
}

/**
 * Dropdown listing every entry of `PROVIDERS`, with an optional
 * "Provider is required" message when `showValidation` is set.
 */
const ProviderSelect: React.FC<ProviderSelectProps> = ({
  provider,
  onProviderChange,
  size = "medium",
  variant = "form",
  showValidation = false,
  fullWidth = false,
  isDisabled = false,
}: ProviderSelectProps) => {
  const isToolbar = variant === "toolbar";
  const isProviderMissing = showValidation && !provider.trim();

  return (
    <FormControl
      fullWidth={fullWidth}
      size={size}
      sx={isToolbar ? { minWidth: 120, height: 40 } : undefined}
      error={isProviderMissing}
      // Disable at the FormControl level so the label and helper text grey out
      // together with the Select (which reads `disabled` from this context).
      disabled={isDisabled}
    >
      <InputLabel>Provider</InputLabel>
      <Select
        value={provider}
        label={"Provider"}
        onChange={(e) => onProviderChange(e.target.value)}
        sx={isToolbar ? { height: 40 } : undefined}
      >
        {PROVIDERS.map((providerOption) => (
          <MenuItem key={providerOption} value={providerOption}>
            {providerOption}
          </MenuItem>
        ))}
      </Select>
      {isProviderMissing && (
        <FormHelperText>Provider is required</FormHelperText>
      )}
    </FormControl>
  );
};
|
|
|
|
|
/**
 * Provider + model picker, rendered in one of three layouts.
 *
 * - "form" (default): provider dropdown stacked above the model picker.
 * - "toolbar": the same two fields at "small" size, each wrapped in a hover
 *   tooltip and returned as a fragment so the parent controls their layout.
 * - "suggestions": two selectable cards. One holds a dropdown of
 *   `suggestedModels`; the other embeds a nested form-variant selector bound
 *   to `customProvider` / `customModel`.
 *
 * The model list for `provider` is requested through `useModelList` on every
 * render, whichever variant is shown.
 */
export const ModelProviderSelector: React.FC<ModelProviderSelectorProps> = ({
  model,
  provider,
  onModelChange,
  onProviderChange,
  size = "medium",
  variant = "form",
  showValidation = false,
  isDisabled = false,
  // Suggestions variant props
  suggestedModels = [],
  selectedSuggestion = "",
  onSuggestionChange,
  customProvider = "",
  customModel = "",
  onCustomProviderChange,
  onCustomModelChange,
  selectedOption = "suggested",
  onOptionChange,
}: ModelProviderSelectorProps) => {
  const { availableModels, loading, error } = useModelList(provider);
  const theme = useTheme();
  const isToolbar = variant === "toolbar";
  const isSuggestions = variant === "suggestions";
  const fieldSize = isToolbar ? "small" : size;

  // Tooltips are controlled so they can be dismissed as soon as the user
  // clicks into a field (otherwise they would cover the opened dropdown).
  const [providerTooltipOpen, setProviderTooltipOpen] = useState(false);
  const [modelTooltipOpen, setModelTooltipOpen] = useState(false);

  const handleProviderChange = useCallback(
    (newProvider: string) => {
      onProviderChange(newProvider);
    },
    [onProviderChange],
  );

  /** Maps the Select's index string back to its suggestion and reports it. */
  const handleSuggestionSelectChange = (value: string) => {
    if (!onSuggestionChange) return;
    const selectedModel = suggestedModels.find(
      (_, index) => index.toString() === value,
    );
    if (selectedModel) {
      onSuggestionChange(value, selectedModel.provider, selectedModel.model);
    }
  };

  /**
   * Selects the custom card. When switching into custom mode, pre-fills the
   * custom provider/model from the currently selected suggestion.
   */
  const handleCustomClick = () => {
    if (!onOptionChange) return;

    // Only seed when switching *into* custom mode: a click on the card while it
    // is already selected (e.g. on its heading) must not overwrite values the
    // user has since edited.
    const isSwitchingToCustom = selectedOption !== "custom";
    onOptionChange("custom");

    if (
      !isSwitchingToCustom ||
      !selectedSuggestion ||
      !onCustomProviderChange ||
      !onCustomModelChange
    ) {
      return;
    }

    const selectedModel = suggestedModels[parseInt(selectedSuggestion, 10)];
    if (selectedModel) {
      onCustomProviderChange(selectedModel.provider);
      onCustomModelChange(selectedModel.model);
    }
  };

  /** Border/background styling for an option card: highlighted when selected, greyed when disabled. */
  const optionCardSx = (isSelected: boolean) =>
    ({
      p: 1.5,
      borderRadius: 1.5,
      border: "2px solid",
      borderColor: isDisabled
        ? "divider"
        : isSelected
          ? theme.palette.primary.main
          : alpha(theme.palette.primary.main, 0.2),
      backgroundColor: isDisabled
        ? alpha(theme.palette.action.disabled, 0.05)
        : isSelected
          ? alpha(theme.palette.primary.main, 0.1)
          : alpha(theme.palette.primary.main, 0.03),
      cursor: isDisabled ? "not-allowed" : "pointer",
      opacity: isDisabled ? 0.5 : 1,
      transition: TRANSITION,
      "&:hover": isDisabled
        ? {}
        : {
            borderColor: alpha(theme.palette.primary.main, 0.4),
            backgroundColor: alpha(theme.palette.primary.main, 0.08),
          },
    }) as const;

  /** Expand/collapse styling for the controls inside an option card. */
  const collapsibleSx = (isOpen: boolean, openMaxHeight: string) =>
    ({
      maxHeight: isOpen ? openMaxHeight : "0px",
      opacity: isOpen ? 1 : 0,
      // Collapsed content is squeezed to zero height but not clipped, so
      // without this it would still receive clicks and keyboard focus while
      // invisible (e.g. overlapping the card below it).
      visibility: isOpen ? "visible" : "hidden",
      transition: TRANSITION,
    }) as const;

  if (isSuggestions) {
    const isSuggestedSelected = selectedOption === "suggested";
    const isCustomSelected = selectedOption === "custom";

    return (
      <Box sx={{ mt: 1.5, pb: 1.5 }}>
        <Box
          sx={{
            display: "flex",
            alignItems: "center",
            justifyContent: "space-between",
            mb: 0.5,
          }}
        >
          <Typography>Select a provider and model</Typography>
          <TooltipIcon title="Choose your inference provider and model configuration. You can select from our suggested models which are pre-configured with popular providers, or customize your own provider and model combination." />
        </Box>

        <Stack spacing={1}>
          {/* Suggested Models Box */}
          <Box
            onClick={() => {
              if (!isDisabled) onOptionChange?.("suggested");
            }}
            sx={optionCardSx(isSuggestedSelected)}
          >
            <Typography
              variant="body2"
              sx={{
                fontWeight: 600,
                color: isDisabled ? "text.disabled" : "text.primary",
                mb: 1,
              }}
            >
              Suggested models
            </Typography>
            <Typography
              variant="caption"
              color={isDisabled ? "text.disabled" : "text.secondary"}
              sx={{
                display: "block",
                lineHeight: 1.3,
                mb: isSuggestedSelected ? 2 : 0,
                transition: TRANSITION,
              }}
            >
              Choose from pre-configured provider and model
            </Typography>

            <Box sx={collapsibleSx(isSuggestedSelected, "200px")}>
              <FormControl fullWidth size="small" disabled={isDisabled}>
                <InputLabel>Provider/Model</InputLabel>
                <Select
                  value={selectedSuggestion}
                  label="Provider/Model"
                  onChange={(e) => handleSuggestionSelectChange(e.target.value)}
                  onClick={(e) => e.stopPropagation()}
                >
                  {suggestedModels.map((suggestion, index) => (
                    <MenuItem key={index} value={index.toString()}>
                      <Box>
                        <Typography variant="body2" sx={{ fontWeight: 500 }}>
                          {suggestion.label}
                        </Typography>
                        <Typography
                          variant="caption"
                          sx={{
                            fontFamily: "monospace",
                            fontSize: "0.7rem",
                            color: "text.secondary",
                            mt: 0.5,
                            display: "block",
                            overflow: "hidden",
                            textOverflow: "ellipsis",
                            whiteSpace: "nowrap",
                          }}
                        >
                          {suggestion.provider} • {suggestion.model}
                        </Typography>
                      </Box>
                    </MenuItem>
                  ))}
                </Select>
              </FormControl>
            </Box>
          </Box>

          {/* Custom Box */}
          <Box
            onClick={() => !isDisabled && handleCustomClick()}
            sx={optionCardSx(isCustomSelected)}
          >
            <Typography
              variant="body2"
              sx={{
                fontWeight: 600,
                color: isDisabled ? "text.disabled" : "text.primary",
                mb: 0.5,
              }}
            >
              Custom configuration
            </Typography>
            <Typography
              variant="caption"
              color={isDisabled ? "text.disabled" : "text.secondary"}
              sx={{
                display: "block",
                lineHeight: 1.3,
                mb: isCustomSelected ? 2 : 0,
                transition: TRANSITION,
              }}
            >
              Choose your own inference provider and model
            </Typography>

            <Box sx={collapsibleSx(isCustomSelected, "300px")}>
              {/* Keep clicks inside the nested fields from reaching the card's onClick. */}
              <Box onClick={(e) => e.stopPropagation()}>
                <ModelProviderSelector
                  model={customModel}
                  provider={customProvider}
                  onModelChange={onCustomModelChange || (() => {})}
                  onProviderChange={onCustomProviderChange || (() => {})}
                  variant="form"
                  size="small"
                  showValidation={true}
                  isDisabled={isDisabled}
                />
              </Box>
            </Box>
          </Box>
        </Stack>
      </Box>
    );
  }

  if (isToolbar) {
    return (
      <>
        {/* Provider Select with Tooltip */}
        <Box sx={{ display: "flex", alignItems: "center", gap: 0.5 }}>
          <Tooltip
            title="Select which Hugging Face inference provider to use for running your model"
            placement="left"
            open={providerTooltipOpen}
          >
            <span
              style={{ width: "150px" }}
              onMouseEnter={() => setProviderTooltipOpen(true)}
              onMouseLeave={() => setProviderTooltipOpen(false)}
              onClick={() => setProviderTooltipOpen(false)}
            >
              <ProviderSelect
                provider={provider}
                onProviderChange={handleProviderChange}
                size={fieldSize}
                variant={variant}
                showValidation={showValidation}
                isDisabled={isDisabled}
                fullWidth
              />
            </span>
          </Tooltip>
        </Box>

        {/* Model Select with Tooltip */}
        <Box sx={{ display: "flex", alignItems: "center", gap: 0.5 }}>
          <Tooltip
            title="Select a model from the available options for the chosen provider"
            placement="left"
            open={modelTooltipOpen}
          >
            <span
              style={{ width: "400px" }}
              onMouseEnter={() => setModelTooltipOpen(true)}
              onMouseLeave={() => setModelTooltipOpen(false)}
              onClick={() => setModelTooltipOpen(false)}
            >
              <ModelSelect
                model={model}
                provider={provider}
                availableModels={availableModels}
                loading={loading}
                error={error}
                onModelChange={onModelChange}
                size={fieldSize}
                variant={variant}
                showValidation={showValidation}
                isDisabled={isDisabled}
                fullWidth
              />
            </span>
          </Tooltip>
        </Box>
      </>
    );
  }

  return (
    <Stack spacing={2}>
      {/* Provider Selection */}
      <ProviderSelect
        provider={provider}
        onProviderChange={handleProviderChange}
        size={fieldSize}
        variant={variant}
        showValidation={showValidation}
        fullWidth
        isDisabled={isDisabled}
      />

      {/* Model Selection */}
      <ModelSelect
        model={model}
        provider={provider}
        availableModels={availableModels}
        loading={loading}
        error={error}
        onModelChange={onModelChange}
        size={fieldSize}
        variant={variant}
        showValidation={showValidation}
        fullWidth
        isDisabled={isDisabled}
      />
    </Stack>
  );
};
|
|
|