import json
import logging
import os
import pathlib
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
import pydantic
import yaml
from packaging.version import Version
from pydantic import ConfigDict, ValidationError
from pydantic.json import pydantic_encoder
from mlflow.exceptions import MlflowException
from mlflow.gateway.base_models import ConfigModel, LimitModel, ResponseModel
from mlflow.gateway.constants import (
MLFLOW_AI_GATEWAY_MOSAICML_CHAT_SUPPORTED_MODEL_PREFIXES,
MLFLOW_GATEWAY_ROUTE_BASE,
MLFLOW_QUERY_SUFFIX,
)
from mlflow.gateway.utils import (
check_configuration_deprecated_fields,
check_configuration_route_name_collisions,
is_valid_ai21labs_model,
is_valid_endpoint_name,
is_valid_mosiacml_chat_model,
)
from mlflow.utils.pydantic_utils import IS_PYDANTIC_V2_OR_NEWER, field_validator, model_validator
# Module-level logger named after this module.
_logger = logging.getLogger(__name__)

# SerializeAsAny is only imported when running under pydantic v2 or newer.
if IS_PYDANTIC_V2_OR_NEWER:
    from pydantic import SerializeAsAny

# Imported for static type checking only; runtime users import Endpoint lazily.
if TYPE_CHECKING:
    from mlflow.deployments.server.config import Endpoint
class Provider(str, Enum):
    """
    String-valued enumeration of the provider identifiers known to the gateway.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    COHERE = "cohere"
    AI21LABS = "ai21labs"
    MLFLOW_MODEL_SERVING = "mlflow-model-serving"
    MOSAICML = "mosaicml"
    HUGGINGFACE_TEXT_GENERATION_INFERENCE = "huggingface-text-generation-inference"
    PALM = "palm"
    GEMINI = "gemini"
    BEDROCK = "bedrock"
    AMAZON_BEDROCK = "amazon-bedrock"  # an alias for bedrock
    # Note: The following providers are only supported on Databricks
    DATABRICKS_MODEL_SERVING = "databricks-model-serving"
    DATABRICKS = "databricks"
    MISTRAL = "mistral"
    TOGETHERAI = "togetherai"

    @classmethod
    def values(cls):
        """
        Return the set of string values of every member of this enum.
        """
        return {entry.value for entry in cls.__members__.values()}
class TogetherAIConfig(ConfigModel):
    """
    Provider configuration carrying the TogetherAI API key.
    """

    togetherai_api_key: str

    @field_validator("togetherai_api_key", mode="before")
    def validate_togetherai_api_key(cls, value):
        """
        Resolve the supplied key input through ``_resolve_api_key_from_input``.
        """
        resolved_key = _resolve_api_key_from_input(value)
        return resolved_key
class EndpointType(str, Enum):
    """
    String-valued enumeration of the endpoint (route) types accepted in the configuration.
    """

    LLM_V1_COMPLETIONS = "llm/v1/completions"
    LLM_V1_CHAT = "llm/v1/chat"
    LLM_V1_EMBEDDINGS = "llm/v1/embeddings"
class CohereConfig(ConfigModel):
    """
    Provider configuration carrying the Cohere API key.
    """

    cohere_api_key: str

    @field_validator("cohere_api_key", mode="before")
    def validate_cohere_api_key(cls, value):
        """
        Resolve the supplied key input through ``_resolve_api_key_from_input``.
        """
        resolved_key = _resolve_api_key_from_input(value)
        return resolved_key
class AI21LabsConfig(ConfigModel):
    """
    Provider configuration carrying the AI21 Labs API key.
    """

    ai21labs_api_key: str

    @field_validator("ai21labs_api_key", mode="before")
    def validate_ai21labs_api_key(cls, value):
        """
        Resolve the supplied key input through ``_resolve_api_key_from_input``.
        """
        resolved_key = _resolve_api_key_from_input(value)
        return resolved_key
class MosaicMLConfig(ConfigModel):
    """
    Provider configuration carrying the MosaicML API key and an optional API base.
    """

    mosaicml_api_key: str
    mosaicml_api_base: str | None = None

    @field_validator("mosaicml_api_key", mode="before")
    def validate_mosaicml_api_key(cls, value):
        """
        Resolve the supplied key input through ``_resolve_api_key_from_input``.
        """
        resolved_key = _resolve_api_key_from_input(value)
        return resolved_key
class OpenAIAPIType(str, Enum):
    """
    String-valued enumeration of the OpenAI API flavours; lookup by value is case-insensitive.
    """

    OPENAI = "openai"
    AZURE = "azure"
    AZUREAD = "azuread"

    @classmethod
    def _missing_(cls, value):
        """
        Implements case-insensitive matching of API type strings.

        Called by ``Enum`` when ``value`` does not match any member exactly.

        Raises:
            MlflowException: if ``value`` is not a string or does not match any member
                when compared case-insensitively.
        """
        # Only strings can be lower-cased; anything else falls through to the
        # invalid-parameter error instead of failing with an AttributeError.
        if isinstance(value, str):
            normalized = value.lower()
            for api_type in cls:
                if api_type.value == normalized:
                    return api_type
        raise MlflowException.invalid_parameter_value(f"Invalid OpenAI API type '{value}'")
class OpenAIConfig(ConfigModel):
    """
    Provider configuration for OpenAI and Azure OpenAI.

    Which optional fields are allowed or required depends on ``openai_api_type``;
    the rules are enforced in ``_validate_field_compatibility``.
    """

    openai_api_key: str
    openai_api_type: OpenAIAPIType = OpenAIAPIType.OPENAI
    openai_api_base: str | None = None
    openai_api_version: str | None = None
    openai_deployment_name: str | None = None
    openai_organization: str | None = None

    @field_validator("openai_api_key", mode="before")
    def validate_openai_api_key(cls, value):
        """
        Resolve the supplied key input through ``_resolve_api_key_from_input``.
        """
        return _resolve_api_key_from_input(value)

    @classmethod
    def _validate_field_compatibility(cls, info: dict[str, Any]):
        """
        Check that the raw field values are consistent with the selected API type.

        Args:
            info: Raw input for the model. Anything that is not a ``dict`` is returned
                untouched.

        Returns:
            ``info``; for the ``openai`` API type a missing ``openai_api_base`` is
            filled in with ``https://api.openai.com/v1``.

        Raises:
            MlflowException: if a field is set that the API type does not permit, if an
                Azure type is missing base URL, deployment name or API version, or if
                the API type is not recognized.
        """
        if not isinstance(info, dict):
            return info

        # Render enum members through ``.value`` so the messages always show the plain
        # string ("azure"); formatting a str-mixin Enum member directly in an f-string
        # yields "OpenAIAPIType.AZURE" on newer Python versions.
        openai_name = OpenAIAPIType.OPENAI.value
        azure_name = OpenAIAPIType.AZURE.value
        azuread_name = OpenAIAPIType.AZUREAD.value

        api_type = (info.get("openai_api_type") or OpenAIAPIType.OPENAI).lower()
        if api_type == OpenAIAPIType.OPENAI:
            if info.get("openai_deployment_name") is not None:
                raise MlflowException.invalid_parameter_value(
                    f"OpenAI route configuration can only specify a value for "
                    f"'openai_deployment_name' if 'openai_api_type' is '{azure_name}' "
                    f"or '{azuread_name}'. Found type: '{api_type}'"
                )
            if info.get("openai_api_base") is None:
                info["openai_api_base"] = "https://api.openai.com/v1"
        elif api_type in (OpenAIAPIType.AZURE, OpenAIAPIType.AZUREAD):
            if info.get("openai_organization") is not None:
                raise MlflowException.invalid_parameter_value(
                    f"OpenAI route configuration can only specify a value for "
                    f"'openai_organization' if 'openai_api_type' is '{openai_name}'"
                )
            base_url = info.get("openai_api_base")
            deployment_name = info.get("openai_deployment_name")
            api_version = info.get("openai_api_version")
            # All three are mandatory for the Azure API types.
            if (base_url, deployment_name, api_version).count(None) > 0:
                raise MlflowException.invalid_parameter_value(
                    f"OpenAI route configuration must specify 'openai_api_base', "
                    f"'openai_deployment_name', and 'openai_api_version' if 'openai_api_type' is "
                    f"'{azure_name}' or '{azuread_name}'."
                )
        else:
            raise MlflowException.invalid_parameter_value(f"Invalid OpenAI API type '{api_type}'")
        return info

    @model_validator(mode="before")
    def validate_field_compatibility(cls, info: dict[str, Any]):
        """
        Model-level "before" validator delegating to ``_validate_field_compatibility``.
        """
        return cls._validate_field_compatibility(info)
class AnthropicConfig(ConfigModel):
    """
    Provider configuration carrying the Anthropic API key and API version string.
    """

    anthropic_api_key: str
    anthropic_version: str = "2023-06-01"

    @field_validator("anthropic_api_key", mode="before")
    def validate_anthropic_api_key(cls, value):
        """
        Resolve the supplied key input through ``_resolve_api_key_from_input``.
        """
        resolved_key = _resolve_api_key_from_input(value)
        return resolved_key
class PaLMConfig(ConfigModel):
    """
    Provider configuration carrying the PaLM API key.
    """

    palm_api_key: str

    @field_validator("palm_api_key", mode="before")
    def validate_palm_api_key(cls, value):
        """
        Resolve the supplied key input through ``_resolve_api_key_from_input``.
        """
        resolved_key = _resolve_api_key_from_input(value)
        return resolved_key
class GeminiConfig(ConfigModel):
    """
    Provider configuration carrying the Gemini API key.
    """

    gemini_api_key: str

    @field_validator("gemini_api_key", mode="before")
    def validate_gemini_api_key(cls, value):
        """
        Resolve the supplied key input through ``_resolve_api_key_from_input``.
        """
        resolved_key = _resolve_api_key_from_input(value)
        return resolved_key
class MlflowModelServingConfig(ConfigModel):
    """
    Provider configuration holding the URL of an MLflow model server.
    """

    model_server_url: str

    # Pydantic warns about field names that begin with "model_"; clearing the protected
    # namespaces silences that warning.
    # https://github.com/mlflow/mlflow/issues/10335
    model_config = pydantic.ConfigDict(protected_namespaces=())
class HuggingFaceTextGenerationInferenceConfig(ConfigModel):
    """
    Provider configuration holding the URL of a Hugging Face text-generation-inference server.
    """

    hf_server_url: str
class AWSBaseConfig(pydantic.BaseModel):
    """
    Base AWS settings shared by the AWS credential models: an optional region.
    """

    aws_region: str | None = None
class AWSRole(AWSBaseConfig):
    """
    AWS settings identified by a role ARN, with a session length in seconds.
    """

    aws_role_arn: str
    # Default session length: 15 minutes, expressed in seconds.
    session_length_seconds: int = 15 * 60
class AWSIdAndKey(AWSBaseConfig):
    """
    AWS settings made of an access key id, a secret access key and an optional session token.
    """

    aws_access_key_id: str
    aws_secret_access_key: str
    aws_session_token: str | None = None
class AmazonBedrockConfig(ConfigModel):
    """
    Provider configuration for Amazon Bedrock, wrapping one of the AWS settings models.
    """

    # The order of the union members matters, at least for pydantic<2.
    aws_config: AWSRole | AWSIdAndKey | AWSBaseConfig
class MistralConfig(ConfigModel):
    """
    Provider configuration carrying the Mistral API key.
    """

    mistral_api_key: str

    @field_validator("mistral_api_key", mode="before")
    def validate_mistral_api_key(cls, value):
        """
        Resolve the supplied key input through ``_resolve_api_key_from_input``.
        """
        resolved_key = _resolve_api_key_from_input(value)
        return resolved_key
class ModelInfo(ResponseModel):
    """
    Response model describing a model by optional name and ``Provider``.
    """

    name: str | None = None
    provider: Provider
def _resolve_api_key_from_input(api_key_input):
"""
Resolves the provided API key.
Input formats accepted:
- Path to a file as a string which will have the key loaded from it
- environment variable name that stores the api key
- the api key itself
"""
if not isinstance(api_key_input, str):
raise MlflowException.invalid_parameter_value(
"The api key provided is not a string. Please provide either an environment "
"variable key, a path to a file containing the api key, or the api key itself"
)
# try reading as an environment variable
if api_key_input.startswith("$"):
env_var_name = api_key_input[1:]
if env_var := os.getenv(env_var_name):
return env_var
else:
raise MlflowException.invalid_parameter_value(
f"Environment variable {env_var_name!r} is not set"
)
# try reading from a local path
file = pathlib.Path(api_key_input)
try:
if file.is_file():
return file.read_text()
except OSError:
# `is_file` throws an OSError if `api_key_input` exceeds the maximum filename length
# (e.g., 255 characters on Unix).
pass
# if the key itself is passed, return
return api_key_input
class Model(ConfigModel):
    """
    Model section of an endpoint configuration: optional name, provider and provider config.
    """

    name: str | None = None
    provider: str | Provider
    # The annotation of ``config`` differs between pydantic major versions.
    if IS_PYDANTIC_V2_OR_NEWER:
        config: SerializeAsAny[ConfigModel] | None = None
    else:
        config: ConfigModel | None = None

    @field_validator("provider", mode="before")
    def validate_provider(cls, value):
        """
        Map the raw provider onto a ``Provider`` member, or accept a registered provider key.

        Raises ``MlflowException`` when the provider is neither.
        """
        from mlflow.gateway.provider_registry import provider_registry

        if isinstance(value, Provider):
            return value

        # Enum member names are the upper-cased values with hyphens turned into underscores.
        member_name = value.replace("-", "_").upper()
        if member_name in Provider.__members__:
            return Provider[member_name]

        if value not in provider_registry.keys():
            raise MlflowException.invalid_parameter_value(
                f"The provider '{value}' is not supported."
            )
        return value

    @classmethod
    def _validate_config(cls, val, context):
        """
        Build the provider-specific config object from ``val`` when it is a dict.

        The config class is looked up as ``CONFIG_TYPE`` on the registry entry of the
        already-validated provider. Raises ``MlflowException`` if no provider is available.
        """
        from mlflow.gateway.provider_registry import provider_registry

        # For Pydantic v2: 'context' is a ValidationInfo object with a 'data' attribute.
        # For Pydantic v1: 'context' is dict-like 'values'.
        if IS_PYDANTIC_V2_OR_NEWER:
            provider = context.data.get("provider")
        elif context:
            provider = context.get("provider")
        else:
            provider = None

        if not provider:
            raise MlflowException.invalid_parameter_value(
                "A provider must be provided for each gateway route."
            )

        config_type = provider_registry.get(provider).CONFIG_TYPE
        if isinstance(val, dict):
            return config_type(**val)
        return val

    @field_validator("config", mode="before")
    def validate_config(cls, info, values):
        """
        Field-level "before" validator delegating to ``_validate_config``.
        """
        return cls._validate_config(info, values)
class AliasedConfigModel(ConfigModel):
    """
    Configuration model base that permits population by field name as well as by alias,
    kept for backwards compatibility.
    """

    # The setting is spelled differently in pydantic 1.x and 2.x.
    if Version(pydantic.__version__) < Version("2.0"):

        class Config:
            allow_population_by_field_name = True

    else:
        model_config = ConfigDict(populate_by_name=True)
class Limit(LimitModel):
    """
    Rate limit definition: number of calls, optional key and renewal period.
    """

    calls: int
    key: str | None = None
    renewal_period: str
class LimitsConfig(ConfigModel):
    """
    Container for a list of ``Limit`` entries; defaults to an empty list.
    """

    limits: list[Limit] | None = []
class EndpointConfig(AliasedConfigModel):
    """
    Configuration of a single gateway endpoint: its name, type, model and optional limit.
    """

    name: str
    endpoint_type: EndpointType
    model: Model
    limit: Limit | None = None

    @field_validator("name")
    def validate_endpoint_name(cls, route_name):
        """
        Reject names that ``is_valid_endpoint_name`` does not accept.
        """
        if not is_valid_endpoint_name(route_name):
            raise MlflowException.invalid_parameter_value(
                "The route name provided contains disallowed characters for a url endpoint. "
                f"'{route_name}' is invalid. Names cannot contain spaces or any non "
                "alphanumeric characters other than hyphen and underscore."
            )
        return route_name

    @field_validator("model", mode="before")
    def validate_model(cls, model):
        """
        Require a ``config`` entry whenever the model uses one of the built-in providers.

        The raw ``model`` input is returned unchanged; the ``Model`` instance built here
        is only used for the check.
        """
        if model:
            model_instance = Model(**model)
            if model_instance.provider in Provider.values() and model_instance.config is None:
                # Show the plain provider string; formatting the enum member directly
                # can render as "Provider.X" depending on the Python version.
                provider_name = Provider(model_instance.provider).value
                raise MlflowException.invalid_parameter_value(
                    "A config must be supplied when setting a provider. The provider entry for "
                    f"{provider_name} is incorrect."
                )
        return model

    @model_validator(mode="after", skip_on_failure=True)
    def validate_route_type_and_model_name(cls, values):
        """
        Cross-field check of the model name against provider-specific rules.

        MosaicML chat routes and AI21Labs routes only accept model names approved by
        ``is_valid_mosiacml_chat_model`` / ``is_valid_ai21labs_model`` respectively.
        """
        # Pydantic v2 passes the model instance; v1 passes a dict of field values.
        if IS_PYDANTIC_V2_OR_NEWER:
            route_type = values.endpoint_type
            model = values.model
        else:
            route_type = values.get("endpoint_type")
            model = values.get("model")
        if (
            model
            and model.provider == "mosaicml"
            and route_type == EndpointType.LLM_V1_CHAT
            and not is_valid_mosiacml_chat_model(model.name)
        ):
            raise MlflowException.invalid_parameter_value(
                f"An invalid model has been specified for the chat route. '{model.name}'. "
                f"Ensure the model selected starts with one of: "
                f"{MLFLOW_AI_GATEWAY_MOSAICML_CHAT_SUPPORTED_MODEL_PREFIXES}"
            )
        if model and model.provider == "ai21labs" and not is_valid_ai21labs_model(model.name):
            raise MlflowException.invalid_parameter_value(
                f"An Unsupported AI21Labs model has been specified: '{model.name}'. "
                f"Please see documentation for supported models."
            )
        return values

    @field_validator("endpoint_type", mode="before")
    def validate_route_type(cls, value):
        """
        Accept only values that are members of ``EndpointType``.
        """
        if value in EndpointType._value2member_map_:
            return value
        raise MlflowException.invalid_parameter_value(f"The route_type '{value}' is not supported.")

    @field_validator("limit", mode="before")
    def validate_limit(cls, value):
        """
        Check that the limit can be parsed by ``limits.parse`` as ``"<calls>/<renewal_period>"``.

        The raw ``value`` is returned unchanged.
        """
        from limits import parse

        if value:
            limit = Limit(**value)
            try:
                parse(f"{limit.calls}/{limit.renewal_period}")
            except ValueError as e:
                # Each literal ends with a space so the adjacent strings join into
                # readable sentences.
                raise MlflowException.invalid_parameter_value(
                    "Failed to parse the rate limit configuration. "
                    "Please make sure limit.calls is a positive number and "
                    "limit.renewal_period is a right granularity"
                ) from e
        return value

    def _to_legacy_route(self) -> "_LegacyRoute":
        """
        Convert this endpoint configuration into a ``_LegacyRoute``.
        """
        return _LegacyRoute(
            name=self.name,
            route_type=self.endpoint_type,
            model=EndpointModelInfo(
                name=self.model.name,
                provider=self.model.provider,
            ),
            route_url=f"{MLFLOW_GATEWAY_ROUTE_BASE}{self.name}{MLFLOW_QUERY_SUFFIX}",
            limit=self.limit,
        )

    def to_endpoint(self) -> "Endpoint":
        """
        Convert this endpoint configuration into a deployments-server ``Endpoint``.
        """
        from mlflow.deployments.server.config import Endpoint

        return Endpoint(
            name=self.name,
            endpoint_type=self.endpoint_type,
            model=EndpointModelInfo(
                name=self.model.name,
                provider=self.model.provider,
            ),
            endpoint_url=f"{MLFLOW_GATEWAY_ROUTE_BASE}{self.name}{MLFLOW_QUERY_SUFFIX}",
            limit=self.limit,
        )
class RouteDestinationConfig(ConfigModel):
    """
    One destination of a traffic route: a name and its integer traffic percentage.
    """

    name: str
    traffic_percentage: int
class TrafficRouteConfig(ConfigModel):
    """
    Traffic route definition: name, task type, destinations and routing strategy.

    ``"TRAFFIC_SPLIT"`` is the only routing strategy the annotation permits.
    """

    name: str
    task_type: EndpointType
    destinations: list[RouteDestinationConfig]
    routing_strategy: Literal["TRAFFIC_SPLIT"] = "TRAFFIC_SPLIT"
class EndpointModelInfo(ResponseModel):
    """
    Response model describing an endpoint's model by optional name and provider string.
    """

    name: str | None = None
    # Typed as plain `str` rather than the `Provider` enum so gateway backends such as
    # Databricks can support new providers without breaking the gateway client.
    provider: str
# Example payload attached to the JSON schema of `_LegacyRoute` via its `Config` class.
_ROUTE_EXTRA_SCHEMA = {
    "example": {
        "name": "openai-completions",
        "route_type": "llm/v1/completions",
        "model": {
            "name": "gpt-4o-mini",
            "provider": "openai",
        },
        "route_url": "/gateway/routes/completions/invocations",
    }
}
class _LegacyRoute(ConfigModel):
    """
    Legacy route representation that can be converted into an ``Endpoint``.
    """

    name: str
    route_type: str
    model: EndpointModelInfo
    route_url: str
    limit: Limit | None = None

    class Config:
        # The schema-extra option is named differently in pydantic 1.x and 2.x.
        if not IS_PYDANTIC_V2_OR_NEWER:
            schema_extra = _ROUTE_EXTRA_SCHEMA
        else:
            json_schema_extra = _ROUTE_EXTRA_SCHEMA

    def to_endpoint(self):
        """
        Build the ``Endpoint`` equivalent of this route, field by field.
        """
        from mlflow.deployments.server.config import Endpoint

        endpoint_fields = {
            "name": self.name,
            "endpoint_type": self.route_type,
            "model": self.model,
            "endpoint_url": self.route_url,
            "limit": self.limit,
        }
        return Endpoint(**endpoint_fields)
class GatewayConfig(AliasedConfigModel):
    """
    Top-level gateway configuration: a list of endpoints and an optional list of traffic routes.
    """

    endpoints: list[EndpointConfig]
    routes: list[TrafficRouteConfig] | None = None
def _load_gateway_config(path: str | Path) -> GatewayConfig:
    """
    Read the gateway configuration YAML file at ``path`` and return it as a ``GatewayConfig``.

    Raises ``MlflowException`` if the file cannot be read or parsed as YAML, or if the
    parsed content fails ``GatewayConfig`` validation.
    """
    config_path = Path(path) if isinstance(path, str) else path

    try:
        raw_text = config_path.read_text()
        configuration = yaml.safe_load(raw_text)
    except Exception as e:
        raise MlflowException.invalid_parameter_value(
            f"The file at {config_path} is not a valid yaml file"
        ) from e

    # Pre-validation checks on the raw parsed content.
    check_configuration_deprecated_fields(configuration)
    check_configuration_route_name_collisions(configuration)

    try:
        gateway_config = GatewayConfig(**configuration)
    except ValidationError as e:
        raise MlflowException.invalid_parameter_value(
            f"The gateway configuration is invalid: {e}"
        ) from e
    return gateway_config
def _save_route_config(config: GatewayConfig, path: str | Path) -> None:
    """
    Write ``config`` to ``path`` as YAML.

    The config is first round-tripped through JSON (using ``pydantic_encoder``) so that
    only plain JSON-compatible values are handed to ``yaml.safe_dump``.
    """
    target = Path(path) if isinstance(path, str) else path
    encoded = json.dumps(config.dict(), default=pydantic_encoder)
    plain_data = json.loads(encoded)
    target.write_text(yaml.safe_dump(plain_data))
def _validate_config(config_path: str) -> GatewayConfig:
    """
    Load the gateway configuration at ``config_path``, checking first that the path exists.

    Raises ``MlflowException`` if the path does not exist or if loading fails for any reason.
    """
    config_exists = os.path.exists(config_path)
    if not config_exists:
        raise MlflowException.invalid_parameter_value(f"{config_path} does not exist")

    try:
        loaded_config = _load_gateway_config(config_path)
    except Exception as e:
        raise MlflowException.invalid_parameter_value(f"Invalid gateway configuration: {e}") from e
    else:
        return loaded_config