Source code for mlflow.gateway.config

import json
import logging
import os
import pathlib
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import pydantic
import yaml
from packaging.version import Version
from pydantic import ConfigDict, ValidationError
from pydantic.json import pydantic_encoder

from mlflow.exceptions import MlflowException
from mlflow.gateway.base_models import ConfigModel, LimitModel, ResponseModel
from mlflow.gateway.constants import (
    MLFLOW_AI_GATEWAY_MOSAICML_CHAT_SUPPORTED_MODEL_PREFIXES,
    MLFLOW_GATEWAY_ROUTE_BASE,
    MLFLOW_QUERY_SUFFIX,
)
from mlflow.gateway.utils import (
    check_configuration_deprecated_fields,
    check_configuration_route_name_collisions,
    is_valid_ai21labs_model,
    is_valid_endpoint_name,
    is_valid_mosiacml_chat_model,
)
from mlflow.utils.pydantic_utils import IS_PYDANTIC_V2_OR_NEWER, field_validator, model_validator

_logger = logging.getLogger(__name__)

if IS_PYDANTIC_V2_OR_NEWER:
    from pydantic import SerializeAsAny

if TYPE_CHECKING:
    from mlflow.deployments.server.config import Endpoint


class Provider(str, Enum):
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    COHERE = "cohere"
    AI21LABS = "ai21labs"
    MLFLOW_MODEL_SERVING = "mlflow-model-serving"
    MOSAICML = "mosaicml"
    HUGGINGFACE_TEXT_GENERATION_INFERENCE = "huggingface-text-generation-inference"
    PALM = "palm"
    GEMINI = "gemini"
    BEDROCK = "bedrock"
    AMAZON_BEDROCK = "amazon-bedrock"  # an alias for bedrock
    # Note: The following providers are only supported on Databricks
    DATABRICKS_MODEL_SERVING = "databricks-model-serving"
    DATABRICKS = "databricks"
    MISTRAL = "mistral"
    TOGETHERAI = "togetherai"

    @classmethod
    def values(cls):
        return {p.value for p in cls}


class TogetherAIConfig(ConfigModel):
    togetherai_api_key: str

    @field_validator("togetherai_api_key", mode="before")
    def validate_togetherai_api_key(cls, value):
        return _resolve_api_key_from_input(value)


class EndpointType(str, Enum):
    LLM_V1_COMPLETIONS = "llm/v1/completions"
    LLM_V1_CHAT = "llm/v1/chat"
    LLM_V1_EMBEDDINGS = "llm/v1/embeddings"


class CohereConfig(ConfigModel):
    cohere_api_key: str

    @field_validator("cohere_api_key", mode="before")
    def validate_cohere_api_key(cls, value):
        return _resolve_api_key_from_input(value)


class AI21LabsConfig(ConfigModel):
    ai21labs_api_key: str

    @field_validator("ai21labs_api_key", mode="before")
    def validate_ai21labs_api_key(cls, value):
        return _resolve_api_key_from_input(value)


class MosaicMLConfig(ConfigModel):
    mosaicml_api_key: str
    mosaicml_api_base: str | None = None

    @field_validator("mosaicml_api_key", mode="before")
    def validate_mosaicml_api_key(cls, value):
        return _resolve_api_key_from_input(value)


class OpenAIAPIType(str, Enum):
    OPENAI = "openai"
    AZURE = "azure"
    AZUREAD = "azuread"

    @classmethod
    def _missing_(cls, value):
        """
        Implements case-insensitive matching of API type strings
        """
        for api_type in cls:
            if api_type.value == value.lower():
                return api_type
        raise MlflowException.invalid_parameter_value(f"Invalid OpenAI API type '{value}'")


class OpenAIConfig(ConfigModel):
    openai_api_key: str
    openai_api_type: OpenAIAPIType = OpenAIAPIType.OPENAI
    openai_api_base: str | None = None
    openai_api_version: str | None = None
    openai_deployment_name: str | None = None
    openai_organization: str | None = None

    @field_validator("openai_api_key", mode="before")
    def validate_openai_api_key(cls, value):
        return _resolve_api_key_from_input(value)

    @classmethod
    def _validate_field_compatibility(cls, info: dict[str, Any]):
        if not isinstance(info, dict):
            return info
        api_type = (info.get("openai_api_type") or OpenAIAPIType.OPENAI).lower()
        if api_type == OpenAIAPIType.OPENAI:
            if info.get("openai_deployment_name") is not None:
                raise MlflowException.invalid_parameter_value(
                    f"OpenAI route configuration can only specify a value for "
                    f"'openai_deployment_name' if 'openai_api_type' is '{OpenAIAPIType.AZURE}' "
                    f"or '{OpenAIAPIType.AZUREAD}'. Found type: '{api_type}'"
                )
            if info.get("openai_api_base") is None:
                info["openai_api_base"] = "https://api.openai.com/v1"
        elif api_type in (OpenAIAPIType.AZURE, OpenAIAPIType.AZUREAD):
            if info.get("openai_organization") is not None:
                raise MlflowException.invalid_parameter_value(
                    f"OpenAI route configuration can only specify a value for "
                    f"'openai_organization' if 'openai_api_type' is '{OpenAIAPIType.OPENAI}'"
                )
            base_url = info.get("openai_api_base")
            deployment_name = info.get("openai_deployment_name")
            api_version = info.get("openai_api_version")
            if (base_url, deployment_name, api_version).count(None) > 0:
                raise MlflowException.invalid_parameter_value(
                    f"OpenAI route configuration must specify 'openai_api_base', "
                    f"'openai_deployment_name', and 'openai_api_version' if 'openai_api_type' is "
                    f"'{OpenAIAPIType.AZURE}' or '{OpenAIAPIType.AZUREAD}'."
                )
        else:
            raise MlflowException.invalid_parameter_value(f"Invalid OpenAI API type '{api_type}'")
        return info

    @model_validator(mode="before")
    def validate_field_compatibility(cls, info: dict[str, Any]):
        return cls._validate_field_compatibility(info)


class AnthropicConfig(ConfigModel):
    anthropic_api_key: str
    anthropic_version: str = "2023-06-01"

    @field_validator("anthropic_api_key", mode="before")
    def validate_anthropic_api_key(cls, value):
        return _resolve_api_key_from_input(value)


class PaLMConfig(ConfigModel):
    palm_api_key: str

    @field_validator("palm_api_key", mode="before")
    def validate_palm_api_key(cls, value):
        return _resolve_api_key_from_input(value)


class GeminiConfig(ConfigModel):
    gemini_api_key: str

    @field_validator("gemini_api_key", mode="before")
    def validate_gemini_api_key(cls, value):
        return _resolve_api_key_from_input(value)


class MlflowModelServingConfig(ConfigModel):
    model_server_url: str

    # Workaround to suppress warning that Pydantic raises when a field name starts with "model_".
    # https://github.com/mlflow/mlflow/issues/10335
    model_config = pydantic.ConfigDict(protected_namespaces=())


class HuggingFaceTextGenerationInferenceConfig(ConfigModel):
    hf_server_url: str


class AWSBaseConfig(pydantic.BaseModel):
    aws_region: str | None = None


class AWSRole(AWSBaseConfig):
    aws_role_arn: str
    session_length_seconds: int = 15 * 60


class AWSIdAndKey(AWSBaseConfig):
    aws_access_key_id: str
    aws_secret_access_key: str
    aws_session_token: str | None = None


class AmazonBedrockConfig(ConfigModel):
    # order here is important, at least for pydantic<2
    aws_config: AWSRole | AWSIdAndKey | AWSBaseConfig


class MistralConfig(ConfigModel):
    mistral_api_key: str

    @field_validator("mistral_api_key", mode="before")
    def validate_mistral_api_key(cls, value):
        return _resolve_api_key_from_input(value)


class ModelInfo(ResponseModel):
    name: str | None = None
    provider: Provider


def _resolve_api_key_from_input(api_key_input):
    """
    Resolves the provided API key.

    Input formats accepted:

    - Path to a file as a string which will have the key loaded from it
    - environment variable name that stores the api key
    - the api key itself
    """
    if not isinstance(api_key_input, str):
        raise MlflowException.invalid_parameter_value(
            "The api key provided is not a string. Please provide either an environment "
            "variable key, a path to a file containing the api key, or the api key itself"
        )

    # try reading as an environment variable
    if api_key_input.startswith("$"):
        env_var_name = api_key_input[1:]
        if env_var := os.getenv(env_var_name):
            return env_var
        else:
            raise MlflowException.invalid_parameter_value(
                f"Environment variable {env_var_name!r} is not set"
            )

    # try reading from a local path
    file = pathlib.Path(api_key_input)
    try:
        if file.is_file():
            return file.read_text()
    except OSError:
        # `is_file` throws an OSError if `api_key_input` exceeds the maximum filename length
        # (e.g., 255 characters on Unix).
        pass

    # if the key itself is passed, return
    return api_key_input


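# Illustrative behavior of the helper above (not part of the original module; values are
# placeholders). A "$"-prefixed string is resolved from the environment, an existing file
# path is read from disk, and anything else is returned unchanged as the key itself.
#
#   >>> os.environ["EXAMPLE_API_KEY"] = "sk-example"
#   >>> _resolve_api_key_from_input("$EXAMPLE_API_KEY")
#   'sk-example'
#   >>> _resolve_api_key_from_input("sk-literal-key")
#   'sk-literal-key'

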
class Model(ConfigModel):
    name: str | None = None
    provider: str | Provider
    if IS_PYDANTIC_V2_OR_NEWER:
        config: SerializeAsAny[ConfigModel] | None = None
    else:
        config: ConfigModel | None = None

    @field_validator("provider", mode="before")
    def validate_provider(cls, value):
        from mlflow.gateway.provider_registry import provider_registry

        if isinstance(value, Provider):
            return value
        formatted_value = value.replace("-", "_").upper()
        if formatted_value in Provider.__members__:
            return Provider[formatted_value]
        if value in provider_registry.keys():
            return value
        raise MlflowException.invalid_parameter_value(f"The provider '{value}' is not supported.")

    @classmethod
    def _validate_config(cls, val, context):
        from mlflow.gateway.provider_registry import provider_registry

        # For Pydantic v2: 'context' is a ValidationInfo object with a 'data' attribute.
        # For Pydantic v1: 'context' is dict-like 'values'.
        if IS_PYDANTIC_V2_OR_NEWER:
            provider = context.data.get("provider")
        else:
            provider = context.get("provider") if context else None

        if provider:
            config_type = provider_registry.get(provider).CONFIG_TYPE
            return config_type(**val) if isinstance(val, dict) else val

        raise MlflowException.invalid_parameter_value(
            "A provider must be provided for each gateway route."
        )

    @field_validator("config", mode="before")
    def validate_config(cls, info, values):
        return cls._validate_config(info, values)


class AliasedConfigModel(ConfigModel):
    """
    Enables use of field aliases in a configuration model for backwards compatibility
    """

    if Version(pydantic.__version__) >= Version("2.0"):
        model_config = ConfigDict(populate_by_name=True)
    else:

        class Config:
            allow_population_by_field_name = True


class Limit(LimitModel):
    calls: int
    key: str | None = None
    renewal_period: str


class LimitsConfig(ConfigModel):
    limits: list[Limit] | None = []


class EndpointConfig(AliasedConfigModel):
    name: str
    endpoint_type: EndpointType
    model: Model
    limit: Limit | None = None

    @field_validator("name")
    def validate_endpoint_name(cls, route_name):
        if not is_valid_endpoint_name(route_name):
            raise MlflowException.invalid_parameter_value(
                "The route name provided contains disallowed characters for a url endpoint. "
                f"'{route_name}' is invalid. Names cannot contain spaces or any non "
                "alphanumeric characters other than hyphen and underscore."
            )
        return route_name

    @field_validator("model", mode="before")
    def validate_model(cls, model):
        if model:
            model_instance = Model(**model)
            if model_instance.provider in Provider.values() and model_instance.config is None:
                raise MlflowException.invalid_parameter_value(
                    "A config must be supplied when setting a provider. The provider entry for "
                    f"{model_instance.provider} is incorrect."
                )
        return model

@model_validator(mode="after", skip_on_failure=True) def validate_route_type_and_model_name(cls, values): if IS_PYDANTIC_V2_OR_NEWER: route_type = values.endpoint_type model = values.model else: route_type = values.get("endpoint_type") model = values.get("model") if ( model and model.provider == "mosaicml" and route_type == EndpointType.LLM_V1_CHAT and not is_valid_mosiacml_chat_model(model.name) ): raise MlflowException.invalid_parameter_value( f"An invalid model has been specified for the chat route. '{model.name}'. " f"Ensure the model selected starts with one of: " f"{MLFLOW_AI_GATEWAY_MOSAICML_CHAT_SUPPORTED_MODEL_PREFIXES}" ) if model and model.provider == "ai21labs" and not is_valid_ai21labs_model(model.name): raise MlflowException.invalid_parameter_value( f"An Unsupported AI21Labs model has been specified: '{model.name}'. " f"Please see documentation for supported models." ) return values
[docs] @field_validator("endpoint_type", mode="before") def validate_route_type(cls, value): if value in EndpointType._value2member_map_: return value raise MlflowException.invalid_parameter_value(f"The route_type '{value}' is not supported.")
[docs] @field_validator("limit", mode="before") def validate_limit(cls, value): from limits import parse if value: limit = Limit(**value) try: parse(f"{limit.calls}/{limit.renewal_period}") except ValueError: raise MlflowException.invalid_parameter_value( "Failed to parse the rate limit configuration." "Please make sure limit.calls is a positive number and" "limit.renewal_period is a right granularity" ) return value
def _to_legacy_route(self) -> "_LegacyRoute": return _LegacyRoute( name=self.name, route_type=self.endpoint_type, model=EndpointModelInfo( name=self.model.name, provider=self.model.provider, ), route_url=f"{MLFLOW_GATEWAY_ROUTE_BASE}{self.name}{MLFLOW_QUERY_SUFFIX}", limit=self.limit, )
[docs] def to_endpoint(self) -> "Endpoint": from mlflow.deployments.server.config import Endpoint return Endpoint( name=self.name, endpoint_type=self.endpoint_type, model=EndpointModelInfo( name=self.model.name, provider=self.model.provider, ), endpoint_url=f"{MLFLOW_GATEWAY_ROUTE_BASE}{self.name}{MLFLOW_QUERY_SUFFIX}", limit=self.limit, )
class RouteDestinationConfig(ConfigModel):
    name: str
    traffic_percentage: int


class TrafficRouteConfig(ConfigModel):
    name: str
    task_type: EndpointType
    destinations: list[RouteDestinationConfig]
    routing_strategy: Literal["TRAFFIC_SPLIT"] = "TRAFFIC_SPLIT"


class EndpointModelInfo(ResponseModel):
    name: str | None = None
    # Use `str` instead of `Provider` enum to allow gateway backends such as Databricks to
    # support new providers without breaking the gateway client.
    provider: str


_ROUTE_EXTRA_SCHEMA = {
    "example": {
        "name": "openai-completions",
        "route_type": "llm/v1/completions",
        "model": {
            "name": "gpt-4o-mini",
            "provider": "openai",
        },
        "route_url": "/gateway/routes/completions/invocations",
    }
}


class _LegacyRoute(ConfigModel):
    name: str
    route_type: str
    model: EndpointModelInfo
    route_url: str
    limit: Limit | None = None

    class Config:
        if IS_PYDANTIC_V2_OR_NEWER:
            json_schema_extra = _ROUTE_EXTRA_SCHEMA
        else:
            schema_extra = _ROUTE_EXTRA_SCHEMA

    def to_endpoint(self):
        from mlflow.deployments.server.config import Endpoint

        return Endpoint(
            name=self.name,
            endpoint_type=self.route_type,
            model=self.model,
            endpoint_url=self.route_url,
            limit=self.limit,
        )


class GatewayConfig(AliasedConfigModel):
    endpoints: list[EndpointConfig]
    routes: list[TrafficRouteConfig] | None = None


def _load_gateway_config(path: str | Path) -> GatewayConfig:
    """
    Reads the gateway configuration yaml file from the storage location and returns an instance
    of the GatewayConfig configuration class.
    """
    if isinstance(path, str):
        path = Path(path)
    try:
        configuration = yaml.safe_load(path.read_text())
    except Exception as e:
        raise MlflowException.invalid_parameter_value(
            f"The file at {path} is not a valid yaml file"
        ) from e
    check_configuration_deprecated_fields(configuration)
    check_configuration_route_name_collisions(configuration)
    try:
        return GatewayConfig(**configuration)
    except ValidationError as e:
        raise MlflowException.invalid_parameter_value(
            f"The gateway configuration is invalid: {e}"
        ) from e


def _save_route_config(config: GatewayConfig, path: str | Path) -> None:
    if isinstance(path, str):
        path = Path(path)
    path.write_text(
        yaml.safe_dump(json.loads(json.dumps(config.dict(), default=pydantic_encoder)))
    )


def _validate_config(config_path: str) -> GatewayConfig:
    if not os.path.exists(config_path):
        raise MlflowException.invalid_parameter_value(f"{config_path} does not exist")

    try:
        return _load_gateway_config(config_path)
    except Exception as e:
        raise MlflowException.invalid_parameter_value(f"Invalid gateway configuration: {e}") from e


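# Illustrative end-to-end sketch (not part of the original module; the path, model name, and
# API key reference are placeholders, and the referenced environment variable must be set for
# key resolution to succeed). A minimal gateway YAML file such as:
#
#   endpoints:
#     - name: chat
#       endpoint_type: llm/v1/chat
#       model:
#         name: gpt-4o-mini
#         provider: openai
#         config:
#           openai_api_key: $OPENAI_API_KEY
#
# can then be parsed and validated with:
#
#   >>> config = _load_gateway_config("/path/to/config.yaml")
#   >>> [e.name for e in config.endpoints]
#   ['chat']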