# credproxy.config - configuration data models and loaders

#  SPDX-License-Identifier: MPL-2.0
#  Copyright 2025-present John Mille <john@ews-network.net>

from __future__ import annotations

import json
import threading
from typing import Any
from pathlib import Path
from dataclasses import field, dataclass

import yaml
import jsonschema

from credproxy.logger import LOG
from credproxy.metrics import update_active_services
from credproxy.settings import NAMESPACE, get_config_file
from credproxy.sanitizer import (
    register_sensitive_dict,
    register_sensitive_value,
    sanitize_exception_message,
)

# Variable-substitution helper, applied to raw config data before schema validation
from credproxy.substitutions import substitute_variables


def keyisset(key: str, data: dict) -> Any:
    """Return the value stored under a required configuration key.

    Args:
        key: Name of the key that must be present.
        data: Mapping to read from.

    Returns:
        ``data[key]``; this may be falsy or ``None`` if that is what was stored.

    Raises:
        KeyError: If ``key`` is absent, with a message naming the key.
    """
    if key in data:
        return data[key]
    raise KeyError(f"Required key '{key}' not found in configuration")
def set_else_none(key: str, data: dict, default: Any) -> Any:
    """Read an optional configuration key, falling back to ``default``.

    ``default`` is used only when ``key`` is missing. A key that is present
    but explicitly set to ``None`` yields ``None``.
    """
    return data[key] if key in data else default
@dataclass
class IAMProfileAuthConfig:
    """Source credentials taken from a named AWS profile.

    Attributes:
        profile_name: Name of the profile to use.
        config_file: Optional path to the AWS config file; ``None`` when unset.
    """

    profile_name: str
    config_file: str | None = None
@dataclass
class IAMKeysAuthConfig:
    """Source credentials given as explicit IAM access keys.

    The secret key and session token are excluded from the generated
    ``repr()``. This keeps them out of logs or tracebacks that format a
    config object. They still take part in ``__init__`` and equality as
    before.

    Attributes:
        aws_access_key_id: Access key ID.
        aws_secret_access_key: Secret access key (hidden from ``repr``).
        session_token: Optional session token for temporary credentials
            (hidden from ``repr``).
    """

    aws_access_key_id: str
    aws_secret_access_key: str = field(repr=False)
    session_token: str | None = field(default=None, repr=False)
@dataclass
class SourceCredentialsConfig:
    """AWS credentials used as the source identity before any role assumption.

    When neither ``iam_profile`` nor ``iam_keys`` is set, the loader falls back
    to the default SDK credential behaviour.

    Attributes:
        region: AWS region name, or ``None`` when unset.
        iam_profile: Profile-based authentication, if configured.
        iam_keys: Access-key-based authentication, if configured.
    """

    region: str | None = None
    iam_profile: IAMProfileAuthConfig | None = None
    iam_keys: IAMKeysAuthConfig | None = None
@dataclass
class AssumedRoleConfig:
    """Parameters for the role that credproxy assumes on a service's behalf.

    The PascalCase field names appear to mirror the AWS STS AssumeRole request
    parameters (TODO confirm against the caller that builds the STS request).
    ``ExternalId`` is excluded from ``repr()``. ``Config.from_dict`` registers
    it with the sanitizer as a sensitive value, so it should not leak through
    a formatted config object either.

    Attributes:
        RoleArn: ARN of the role to assume (required).
        RoleSessionName: Session name; defaults to ``"credproxy"``.
        DurationSeconds: Requested session duration; defaults to 900.
        ExternalId: Optional external ID (hidden from ``repr``).
        PolicyArns: Optional list of managed-policy descriptors.
        Policy: Optional inline session policy.
        Tags: Optional list of session tag descriptors.
        TransitiveTagKeys: Optional list of tag keys to make transitive.
        SerialNumber: Optional MFA device identifier.
        TokenCode: Optional MFA token code.
        SourceIdentity: Optional source identity.
    """

    RoleArn: str
    RoleSessionName: str = "credproxy"
    DurationSeconds: int = 900
    ExternalId: str | None = field(default=None, repr=False)
    PolicyArns: list[dict] | None = None
    Policy: str | None = None
    Tags: list[dict] | None = None
    TransitiveTagKeys: list[str] | None = None
    SerialNumber: str | None = None
    TokenCode: str | None = None
    SourceIdentity: str | None = None
@dataclass
class ServerConfig:
    """Settings for the credential-serving HTTP listener.

    Attributes:
        host: Bind address; defaults to ``"0.0.0.0"``.
        port: Listen port; defaults to 1338.
        debug: Debug mode flag; off by default.
        log_health_checks: Whether health-check requests are logged; off by
            default.
    """

    host: str = "0.0.0.0"
    port: int = 1338
    debug: bool = False
    log_health_checks: bool = False
@dataclass
class CredentialsConfig:
    """Tuning knobs for credential refresh and STS requests.

    Attributes:
        refresh_buffer_seconds: Lead time in seconds before expiry; default 300.
        retry_delay: Delay between retries; default 60 (presumably seconds).
        request_timeout: Request timeout; default 30 (presumably seconds).
    """

    refresh_buffer_seconds: int = 300
    retry_delay: int = 60
    request_timeout: int = 30
@dataclass
class DirectoryConfig:
    """One directory watched for dynamic service definitions.

    Attributes:
        path: Directory path.
        include_patterns: Patterns a file must match to be considered; each
            instance gets its own empty list by default.
        exclude_patterns: Patterns that exclude a file; each instance gets its
            own empty list by default.
    """

    path: str
    include_patterns: list[str] = field(default_factory=list)
    exclude_patterns: list[str] = field(default_factory=list)
@dataclass
class DynamicServicesConfig:
    """Settings for loading service definitions from watched directories.

    Attributes:
        enabled: Whether dynamic services are active; off by default.
        directories: Directories to monitor; by default a single
            ``/credproxy/dynamic`` entry (a fresh list per instance).
        reload_interval: Reload interval; default 5 (presumably seconds).
        watcher_stop_timeout: Seconds to wait when stopping the file watcher.
    """

    enabled: bool = False
    directories: list[DirectoryConfig] = field(
        default_factory=lambda: [DirectoryConfig(path="/credproxy/dynamic")]
    )
    reload_interval: int = 5
    watcher_stop_timeout: int = 5
@dataclass
class PrometheusConfig:
    """Prometheus metrics endpoint settings.

    Attributes:
        enabled: Whether the endpoint is enabled; on by default.
        host: Bind address; defaults to ``"0.0.0.0"``.
        port: Listen port; defaults to 9090.
    """

    enabled: bool = True
    host: str = "0.0.0.0"
    port: int = 9090
@dataclass
class MetricsConfig:
    """Container for metrics/telemetry settings.

    Attributes:
        prometheus: Prometheus endpoint settings; a default-constructed
            PrometheusConfig per instance unless one is supplied.
    """

    prometheus: PrometheusConfig = field(default_factory=PrometheusConfig)
@dataclass
class ServiceConfig:
    """Everything needed to serve credentials for one service.

    ``auth_token`` is the secret clients present to identify the service.
    ``Config.from_dict`` registers it with the sanitizer as sensitive, so it is
    also excluded from the generated ``repr()``. Equality and ``__init__`` are
    unchanged.

    Attributes:
        auth_token: Client authorisation token (hidden from ``repr``).
        source_credentials: Credentials used before role assumption.
        assumed_role: Role-assumption parameters.
        source_file: Origin of this definition (a file path, or a marker such
            as ``"static_config"``); ``None`` when not tracked.
    """

    auth_token: str = field(repr=False)
    source_credentials: SourceCredentialsConfig
    assumed_role: AssumedRoleConfig
    source_file: str | None = None
def _parse_directory_configs(
    directories_data: list | str, dynamic_services_data: dict
) -> list[DirectoryConfig]:
    """Normalise the ``dynamic_services.directories`` value into DirectoryConfig objects.

    Accepted forms:

    * a single path string;
    * a list of path strings (legacy format). Each entry inherits
      ``include_patterns`` / ``exclude_patterns`` from the enclosing
      ``dynamic_services`` section;
    * a list of mappings with a ``path`` key and optional per-directory
      ``include_patterns`` / ``exclude_patterns``.

    String and mapping entries may be mixed in one list. Each entry is
    interpreted by its own type rather than by the type of the first entry.

    Args:
        directories_data: Raw ``directories`` value from the configuration.
        dynamic_services_data: The enclosing ``dynamic_services`` section. It
            supplies the patterns for plain-string entries.

    Returns:
        One DirectoryConfig per entry. If no directories were given (empty
        list or string, or an unsupported type), a single DirectoryConfig for
        ``/credproxy/dynamic`` is returned.

    Raises:
        KeyError: If a mapping entry has no ``path`` key.
    """
    if isinstance(directories_data, str):
        # Without this, a lone path string would fall through to the default
        # directory and the configured path would be silently ignored.
        directories_data = [directories_data] if directories_data else []

    if not isinstance(directories_data, list) or not directories_data:
        return [DirectoryConfig(path="/credproxy/dynamic")]

    parsed: list[DirectoryConfig] = []
    for entry in directories_data:
        if isinstance(entry, str):
            # Legacy format: patterns come from the parent section.
            path = entry
            pattern_source = dynamic_services_data
        else:
            path = entry["path"]
            pattern_source = entry
        parsed.append(
            DirectoryConfig(
                path=path,
                include_patterns=set_else_none("include_patterns", pattern_source, []),
                exclude_patterns=set_else_none("exclude_patterns", pattern_source, []),
            )
        )
    return parsed
def merge_aws_config(defaults: dict, overrides: dict) -> dict:
    """Merge service-specific source-credential settings over the defaults.

    The merge is shallow. Each top-level key in ``overrides`` replaces the
    same key from ``defaults`` wholesale, even when the override value is
    ``None``.

    ``iam_profile`` and ``iam_keys`` are alternative authentication methods.
    When ``overrides`` sets either one to a non-``None`` value, both auth keys
    are dropped from the defaults before the overrides are applied. This makes
    the service's chosen method the only one present. Otherwise a default
    ``iam_profile`` would sit next to a service's ``iam_keys``, and the
    credentials builder, which checks ``iam_profile`` first, would silently
    use the default profile.

    Args:
        defaults: Default settings; may be empty or ``None``. Not modified.
        overrides: Service-specific settings; may be empty or ``None``.

    Returns:
        A new dict with the merged settings.
    """
    auth_method_keys = ("iam_profile", "iam_keys")

    merged = dict(defaults) if defaults else {}
    overrides = overrides or {}

    if any(overrides.get(key) is not None for key in auth_method_keys):
        for key in auth_method_keys:
            merged.pop(key, None)

    # Apply service-specific overrides (explicit None values included).
    merged.update(overrides)
    return merged
@dataclass
class Config:
    """Main credproxy configuration.

    Aggregates server, credential-refresh, metrics and dynamic-service settings
    together with the ``services`` registry. A token -> service-name index is
    kept alongside ``services`` so an auth token resolves with a single dict
    lookup. ``add_service``, ``remove_service`` and ``update_service`` serialise
    on ``_services_lock`` and rebuild that index.
    """

    server: ServerConfig = field(default_factory=ServerConfig)
    credentials: CredentialsConfig = field(default_factory=CredentialsConfig)
    aws_defaults: SourceCredentialsConfig | None = None
    services: dict[str, ServiceConfig] = field(default_factory=dict)
    dynamic_services: DynamicServicesConfig | None = None
    metrics: MetricsConfig = field(default_factory=MetricsConfig)
    # Token-to-service mapping for instant lookup. It is always replaced by a
    # freshly built dict (never cleared in place), so lock-free readers see
    # either the old or the new mapping, never a half-built one.
    _token_to_service: dict[str, str] = field(
        default_factory=dict, init=False, repr=False
    )
    # Re-entrant lock serialising service add/remove/update operations.
    _services_lock: threading.RLock = field(
        default_factory=threading.RLock, init=False, repr=False
    )

    def __post_init__(self):
        """Build the token-to-service mapping after dataclass initialisation."""
        self._build_token_mapping()

    def _build_token_mapping(self):
        """Rebuild the token -> service-name index from ``self.services``.

        The mapping is built in a local dict and assigned in one step. A
        concurrent ``get_service_name_by_token`` (which takes no lock)
        therefore never sees an empty mapping mid-rebuild. If two services
        share a token, the later one in ``services`` iteration order wins and
        a warning is logged, since the earlier service becomes unreachable by
        token.
        """
        LOG.info("Building token mapping for %d services", len(self.services))
        mapping: dict[str, str] = {}
        for service_name, service_config in self.services.items():
            token = service_config.auth_token
            previous = mapping.get(token)
            if previous is not None:
                LOG.warning(
                    "Services '%s' and '%s' share the same auth token; "
                    "'%s' will not be reachable by token",
                    previous,
                    service_name,
                    previous,
                )
            mapping[token] = service_name
            LOG.debug("Mapped token for service %s", service_name)
        self._token_to_service = mapping
        LOG.info(
            "Token mapping built successfully with %d services",
            len(mapping),
        )

    def get_service_name_by_token(self, token: str) -> str | None:
        """Return the name of the service registered for ``token``, or None.

        No part of the presented or registered tokens is logged. Auth tokens
        are registered as sensitive values in ``from_dict``, and lookups are
        presumably driven by incoming requests, so logging token prefixes on a
        miss would let any caller push secret material into the logs.
        """
        # Read the attribute once so the lookup and the count refer to the
        # same mapping even if it is swapped concurrently.
        mapping = self._token_to_service
        service_name = mapping.get(token)
        LOG.debug("Token lookup resolved to service: %s", service_name)
        LOG.debug("Token registry contains %d tokens", len(mapping))
        LOG.debug("Available services: %s", list(self.services.keys()))
        if not service_name:
            LOG.warning("Token not found in registry")
            LOG.info("Available services in registry: %s", list(self.services.keys()))
        return service_name

    def add_service(self, service_name: str, service_config: ServiceConfig) -> bool:
        """Register a new service at runtime.

        Returns:
            True if the service was added. False (with a warning) if a service
            with the same name already exists; the existing definition is kept.
            On success the token index is rebuilt and the active-services
            metric is updated.
        """
        with self._services_lock:
            if service_name in self.services:
                LOG.warning(
                    "Service '%s' already defined. Ignoring service from %s",
                    service_name,
                    service_config.source_file or "unknown",
                )
                return False

            self.services[service_name] = service_config
            self._build_token_mapping()
            LOG.info(
                "Added dynamic service '%s' from %s",
                service_name,
                service_config.source_file or "unknown",
            )
            update_active_services(len(self.services))
            return True

    def remove_service(self, service_name: str) -> bool:
        """Remove a service by name.

        Returns:
            True if it was removed (token index rebuilt, active-services
            metric updated). False (with a warning) if no such service exists.
        """
        with self._services_lock:
            if service_name not in self.services:
                LOG.warning("Service '%s' not found for removal", service_name)
                return False

            del self.services[service_name]
            self._build_token_mapping()
            LOG.info("Removed dynamic service '%s'", service_name)
            update_active_services(len(self.services))
            return True

    def update_service(self, service_name: str, service_config: ServiceConfig) -> bool:
        """Replace the definition of an existing service.

        Returns:
            True if it was replaced (token index rebuilt). False (with a
            warning) if no such service exists. The service count does not
            change, so no metric is updated.
        """
        with self._services_lock:
            if service_name not in self.services:
                LOG.warning("Service '%s' not found for update", service_name)
                return False

            self.services[service_name] = service_config
            self._build_token_mapping()
            LOG.info(
                "Updated dynamic service '%s' from %s",
                service_name,
                service_config.source_file or "unknown",
            )
            return True

    @classmethod
    def from_file(cls, config_path: str | None = None) -> Config:
        """Load configuration from a YAML or JSON file.

        Path resolution: ``get_config_file(NAMESPACE)`` is used when
        ``config_path`` is None. It also overrides an explicit ``config_path``
        whenever it differs from ``/credproxy/config.yaml``. This presumably
        reflects the CREDPROXY_CONFIG_FILE environment variable; confirm in
        ``credproxy.settings``.

        The file is parsed with ``yaml.safe_load`` first, then ``json.load``
        as a fallback. An empty file is treated as an empty mapping, so
        validation reports a meaningful error.

        Raises:
            FileNotFoundError: If the resolved file does not exist.
            ValueError: If the file is neither valid YAML nor JSON, if its top
                level is not a mapping, or if validation fails.
        """
        env_config_path = get_config_file(NAMESPACE)
        default_path = "/credproxy/config.yaml"
        if config_path is None or env_config_path != default_path:
            config_path = env_config_path

        config_file = Path(config_path)
        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        try:
            with open(config_file, encoding="utf-8") as f:
                config_data = yaml.safe_load(f)
            LOG.info("Loaded configuration from %s as YAML", config_path)
        except yaml.YAMLError as error:
            LOG.debug("YAML parsing failed, trying JSON")
            try:
                with open(config_file, encoding="utf-8") as f:
                    config_data = json.load(f)
                LOG.info("Loaded configuration from %s as JSON", config_path)
            except json.JSONDecodeError as json_error:
                raise ValueError(
                    f"File is not valid YAML or JSON. YAML error: {error}, "
                    f"JSON error: {json_error}"
                ) from error

        if config_data is None:
            # An empty YAML document parses as None.
            config_data = {}
        if not isinstance(config_data, dict):
            raise ValueError(
                f"Configuration file {config_path} must contain a mapping at the "
                f"top level, got {type(config_data).__name__}"
            )
        return cls.from_dict(config_data, config_path)

    @classmethod
    def from_dict(cls, config_data: dict, config_path: str | None = None) -> Config:
        """Build a Config from an already-parsed configuration mapping.

        Steps:

        1. Apply ``substitute_variables``.
        2. Validate against the JSON schema.
        3. Build the typed sections.
        4. Merge ``aws_defaults`` into each service's ``source_credentials``.
        5. Register secrets with the sanitizer.
        6. Check the services with ``_validate_services``.

        Args:
            config_data: Parsed configuration mapping.
            config_path: File the data came from. Recorded (resolved) as each
                service's ``source_file``, or ``"static_config"`` when None.

        Returns:
            The populated Config.

        Raises:
            ValueError: On schema or service validation failure.
            KeyError: If a service lacks ``auth_token`` or ``assumed_role.RoleArn``.
        """
        # Apply variable substitution, then validate the substituted data.
        config_data = substitute_variables(config_data)
        cls.validate_schema(config_data)

        # ``or {}`` also covers sections that are present but empty in YAML:
        # they parse as None and would otherwise break the ``.get`` calls below.
        server_data = config_data.get("server") or {}
        creds_data = config_data.get("credentials") or {}
        aws_defaults_data = config_data.get("aws_defaults") or {}
        services_data = config_data.get("services") or {}
        dynamic_services_data = config_data.get("dynamic_services") or {}
        metrics_data = config_data.get("metrics") or {}

        aws_defaults = None
        if aws_defaults_data:
            aws_defaults = cls._create_source_credentials_config(aws_defaults_data)

        # Built before the services because _validate_services needs it.
        dynamic_services = None
        if dynamic_services_data:
            directories_data = set_else_none(
                "directories", dynamic_services_data, ["/credproxy/dynamic"]
            )
            dynamic_services = DynamicServicesConfig(
                enabled=set_else_none("enabled", dynamic_services_data, False),
                directories=_parse_directory_configs(
                    directories_data, dynamic_services_data
                ),
                reload_interval=set_else_none(
                    "reload_interval", dynamic_services_data, 5
                ),
                watcher_stop_timeout=set_else_none(
                    "watcher_stop_timeout", dynamic_services_data, 5
                ),
            )

        # Every static service shares the same origin; compute it once.
        source_file = (
            str(Path(config_path).resolve()) if config_path else "static_config"
        )
        services: dict[str, ServiceConfig] = {}
        for service_name, service_config in services_data.items():
            source_creds_data = service_config.get("source_credentials") or {}
            assumed_role_data = service_config.get("assumed_role") or {}

            # Service-specific source credentials override aws_defaults. A
            # fresh defaults dict per service avoids sharing nested dicts.
            # _source_credentials_config_to_dict(None) returns {}.
            merged_source_creds_data = merge_aws_config(
                cls._source_credentials_config_to_dict(aws_defaults),
                source_creds_data,
            )

            # Register secrets with the sanitizer: the auth token, the merged
            # source credentials, and ExternalId when present.
            auth_token = keyisset("auth_token", service_config)
            register_sensitive_value(auth_token)
            register_sensitive_dict(merged_source_creds_data)
            if "ExternalId" in assumed_role_data:
                register_sensitive_value(assumed_role_data["ExternalId"])

            services[service_name] = ServiceConfig(
                auth_token=auth_token,
                source_credentials=cls._create_source_credentials_config(
                    merged_source_creds_data
                ),
                assumed_role=cls._create_assumed_role_config(assumed_role_data),
                source_file=source_file,
            )

        # Validate service configurations after inheritance.
        cls._validate_services(services, dynamic_services)

        prometheus_data = metrics_data.get("prometheus") or {}
        metrics = MetricsConfig(
            prometheus=PrometheusConfig(
                enabled=set_else_none("enabled", prometheus_data, True),
                host=set_else_none("host", prometheus_data, "0.0.0.0"),
                port=set_else_none("port", prometheus_data, 9090),
            )
        )

        # Imported here for env var support (as in the original code).
        from credproxy.settings import LOG_HEALTH_CHECKS

        # Either the config file setting or the environment setting enables it.
        log_health_checks_config = set_else_none(
            "log_health_checks", server_data, False
        )
        log_health_checks = log_health_checks_config or LOG_HEALTH_CHECKS

        return cls(
            server=ServerConfig(
                # NOTE(review): this fallback ("localhost") differs from
                # ServerConfig's own default ("0.0.0.0"); confirm which is
                # intended before changing either.
                host=set_else_none("host", server_data, "localhost"),
                port=set_else_none("port", server_data, 1338),
                debug=set_else_none("debug", server_data, False),
                log_health_checks=log_health_checks,
            ),
            credentials=CredentialsConfig(
                refresh_buffer_seconds=set_else_none(
                    "refresh_buffer_seconds", creds_data, 300
                ),
                retry_delay=set_else_none("retry_delay", creds_data, 60),
                request_timeout=set_else_none("request_timeout", creds_data, 30),
            ),
            aws_defaults=aws_defaults,
            services=services,
            dynamic_services=dynamic_services,
            metrics=metrics,
        )

    @classmethod
    def _create_source_credentials_config(cls, data: dict) -> SourceCredentialsConfig:
        """Create a SourceCredentialsConfig from a source-credentials mapping.

        The auth method is auto-detected. ``iam_profile`` is checked first,
        then ``iam_keys``. A key explicitly set to ``None`` counts as absent,
        so an override can null out an inherited method instead of crashing.
        With neither method present, default SDK behaviour applies.
        """
        iam_profile_config = None
        iam_keys_config = None

        profile_data = data.get("iam_profile")
        keys_data = data.get("iam_keys")
        if profile_data is not None:
            iam_profile_config = IAMProfileAuthConfig(
                profile_name=set_else_none("profile_name", profile_data, None),
                config_file=set_else_none("config_file", profile_data, None),
            )
        elif keys_data is not None:
            iam_keys_config = IAMKeysAuthConfig(
                aws_access_key_id=set_else_none("aws_access_key_id", keys_data, None),
                aws_secret_access_key=set_else_none(
                    "aws_secret_access_key", keys_data, None
                ),
                session_token=set_else_none("session_token", keys_data, None),
            )

        return SourceCredentialsConfig(
            region=set_else_none("region", data, None),
            iam_profile=iam_profile_config,
            iam_keys=iam_keys_config,
        )

    @classmethod
    def _create_assumed_role_config(cls, data: dict) -> AssumedRoleConfig:
        """Create an AssumedRoleConfig; ``RoleArn`` is required (KeyError if missing)."""
        return AssumedRoleConfig(
            RoleArn=keyisset("RoleArn", data),
            RoleSessionName=set_else_none("RoleSessionName", data, "credproxy"),
            DurationSeconds=set_else_none("DurationSeconds", data, 900),
            ExternalId=set_else_none("ExternalId", data, None),
            PolicyArns=set_else_none("PolicyArns", data, None),
            Policy=set_else_none("Policy", data, None),
            Tags=set_else_none("Tags", data, None),
            TransitiveTagKeys=set_else_none("TransitiveTagKeys", data, None),
            SerialNumber=set_else_none("SerialNumber", data, None),
            TokenCode=set_else_none("TokenCode", data, None),
            SourceIdentity=set_else_none("SourceIdentity", data, None),
        )

    @classmethod
    def _source_credentials_config_to_dict(
        cls, source_config: SourceCredentialsConfig | None
    ) -> dict:
        """Convert a SourceCredentialsConfig back to a plain dict for merging.

        Returns ``{}`` for None. Otherwise returns ``region`` plus at most one
        of ``iam_profile`` / ``iam_keys`` (profile preferred).
        """
        if not source_config:
            return {}

        result: dict = {
            "region": source_config.region,
        }
        if source_config.iam_profile:
            result["iam_profile"] = {
                "profile_name": source_config.iam_profile.profile_name,
                "config_file": source_config.iam_profile.config_file,
            }
        elif source_config.iam_keys:
            result["iam_keys"] = {
                "aws_access_key_id": source_config.iam_keys.aws_access_key_id,
                "aws_secret_access_key": source_config.iam_keys.aws_secret_access_key,
                "session_token": source_config.iam_keys.session_token,
            }
        return result

    @classmethod
    def validate_schema(cls, config_data: dict) -> None:
        """Validate configuration data against ``config-schema.json``.

        The schema is read from the file next to this module. If it is missing,
        a warning is logged and validation is skipped.

        Raises:
            ValueError: Wrapping any validation, schema or loading error. The
                validation message is passed through the sanitizer first.
        """
        schema_path = Path(__file__).parent / "config-schema.json"

        if not schema_path.exists():
            LOG.warning("JSON schema file not found at %s", schema_path)
            return

        try:
            with open(schema_path, encoding="utf-8") as f:
                schema = json.load(f)

            jsonschema.validate(config_data, schema)
            LOG.debug("Configuration validation against JSON schema passed")

        except jsonschema.ValidationError as error:
            error_path = (
                " -> ".join(str(p) for p in error.absolute_path)
                if error.absolute_path
                else "root"
            )
            # The raw message may echo config values (including secrets), so
            # pass it through the sanitizer before logging or re-raising.
            sanitized_message = sanitize_exception_message(str(error))
            LOG.error(
                "Configuration validation failed at %s: %s",
                error_path,
                sanitized_message,
            )
            raise ValueError(
                f"Configuration validation failed at {error_path}: {sanitized_message}"
            ) from error
        except jsonschema.SchemaError as error:
            LOG.error("JSON schema error: %s", error.message)
            raise ValueError(f"Invalid JSON schema: {error.message}") from error
        except Exception as error:
            LOG.error("Error validating configuration against schema: %s", str(error))
            raise ValueError(f"Schema validation error: {str(error)}") from error

    @classmethod
    def _validate_services(
        cls,
        services: dict[str, ServiceConfig],
        dynamic_services: DynamicServicesConfig | None = None,
    ) -> None:
        """Check service configurations after defaults have been merged in.

        Raises:
            ValueError: If there are no static services and dynamic services
                are not enabled, or if any service lacks a region or role ARN.
        """
        has_static_services = bool(services)
        has_dynamic_services = bool(dynamic_services and dynamic_services.enabled)

        if not has_static_services and not has_dynamic_services:
            raise ValueError(
                "At least one service must be configured. "
                "Either define static services or enable dynamic_services."
            )

        for service_name, service_config in services.items():
            source_creds = service_config.source_credentials
            assumed_role = service_config.assumed_role

            if not source_creds.region:
                raise ValueError(f"AWS region is required for service '{service_name}'")

            if not assumed_role.RoleArn:
                raise ValueError(
                    f"AWS role ARN is required for service '{service_name}'"
                )

            # Field-level requirements inside the auth sections are left to
            # the JSON schema; the auth method is auto-detected from whether
            # iam_keys or iam_profile is present.

        LOG.info("Configuration validation passed for %d services", len(services))