Source code for credproxy.credentials_handler
# SPDX-License-Identifier: MPL-2.0
# Copyright 2025-present John Mille <john@ews-network.net>
from __future__ import annotations
import time
import threading
from typing import TYPE_CHECKING
from dataclasses import asdict, dataclass
import boto3
from botocore.exceptions import ClientError
from credproxy.logger import LOG
if TYPE_CHECKING:
from credproxy.config import Config, ServiceConfig
[docs]
@dataclass
class ServiceCredentialsManager:
"""Service credentials manager with caching and expiry time."""
aws_access_key_id: str
aws_secret_access_key: str
session_token: str
expiry: float
[docs]
def is_expired(self) -> bool:
"""Check if credentials are expired."""
return time.time() > self.expiry
[docs]
def get_sensitive_values(self) -> list[str]:
"""Get list of sensitive values that should be sanitized.
Returns:
List of credential values
"""
return [
self.aws_access_key_id,
self.aws_secret_access_key,
self.session_token,
]
[docs]
def to_dict(self) -> dict:
"""Convert to dictionary format for API response."""
from datetime import datetime, timezone
expiration_iso = datetime.fromtimestamp(self.expiry, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
)
return {
"AccessKeyId": self.aws_access_key_id,
"SecretAccessKey": self.aws_secret_access_key,
"Token": self.session_token,
"Expiration": expiration_iso,
}
[docs]
class CredentialsHandler:
"""Simple credentials handler with caching and expiry."""
[docs]
def __init__(self, config: Config):
self.config = config
self.cache: dict[str, ServiceCredentialsManager] = {}
self._cache_lock = threading.RLock()
self._cleanup_thread: threading.Thread | None = None
self._stop_cleanup = threading.Event()
self._start_cache_cleanup()
def _start_cache_cleanup(self) -> None:
"""Start background thread for periodic cache cleanup."""
def cleanup_expired():
"""Periodically clean up expired credentials from cache."""
while not self._stop_cleanup.is_set():
try:
# Wait for 60 seconds or until stop is signaled
if self._stop_cleanup.wait(timeout=60):
break
# Clean up expired entries
with self._cache_lock:
expired_services = [
service_name
for service_name, creds in self.cache.items()
if creds.is_expired()
]
for service_name in expired_services:
# Unregister sensitive values before removing from cache
from credproxy.sanitizer import unregister_sensitive_value
creds = self.cache[service_name]
for value in creds.get_sensitive_values():
unregister_sensitive_value(value)
del self.cache[service_name]
LOG.debug(
"Removed expired credentials from cache: %s",
service_name,
)
if expired_services:
LOG.info(
"Cache cleanup: removed %d expired credential entries",
len(expired_services),
)
except Exception as error:
LOG.error("Error during cache cleanup")
LOG.exception(error)
self._cleanup_thread = threading.Thread(
target=cleanup_expired, daemon=True, name="cache-cleanup"
)
self._cleanup_thread.start()
LOG.debug("Started background cache cleanup thread")
[docs]
def cleanup(self) -> None:
"""Clean up resources during graceful shutdown."""
# Stop the cleanup thread
if self._cleanup_thread:
LOG.info("Stopping cache cleanup thread")
self._stop_cleanup.set()
self._cleanup_thread.join(timeout=5)
LOG.info("Cache cleanup thread stopped")
with self._cache_lock:
cache_size = len(self.cache)
if cache_size > 0:
LOG.info(
"Cleaning up %d cached credential entries during shutdown",
cache_size,
)
# Unregister all sensitive values
from credproxy.sanitizer import unregister_sensitive_value
for creds in self.cache.values():
for value in creds.get_sensitive_values():
unregister_sensitive_value(value)
self.cache.clear()
LOG.info("Credential cache cleared successfully")
else:
LOG.info("No cached credentials to clean up")
[docs]
def get_credentials(self, service_name: str) -> dict:
"""Get credentials for a service, using cache if not expired."""
with self._cache_lock:
# Check cache first
if service_name in self.cache:
cached = self.cache[service_name]
if not cached.is_expired():
LOG.debug("Using cached credentials for %s", service_name)
return cached.to_dict()
else:
# Cache expired
pass
else:
# Not in cache
pass
# Generate new credentials
LOG.info("Generating new credentials for %s", service_name)
service_config = self.config.services[service_name]
credentials = self._assume_role(service_config)
# Register temporary credentials for sanitization
from credproxy.sanitizer import register_sensitive_value
register_sensitive_value(credentials["AccessKeyId"])
register_sensitive_value(credentials["SecretAccessKey"])
register_sensitive_value(credentials["SessionToken"])
# Use the exact expiration from STS assume role API call
expiry_time = credentials["Expiration"].timestamp()
service_creds = ServiceCredentialsManager(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
session_token=credentials["SessionToken"],
expiry=expiry_time,
)
with self._cache_lock:
self.cache[service_name] = service_creds
return service_creds.to_dict()
def _assume_role(self, service_config: ServiceConfig) -> dict:
"""Assume role for service and return credentials."""
# Get service name for metrics
service_name = next(
(
name
for name, config in self.config.services.items()
if config == service_config
),
"unknown",
)
try:
# Get AWS config for this service
aws_config = self._get_aws_config(service_config)
# Create STS client with profile if specified
profile_name = aws_config.pop("profile_name", None)
if profile_name:
session = boto3.Session(profile_name=profile_name)
sts_client = session.client("sts", **aws_config)
else:
sts_client = boto3.client("sts", **aws_config)
# Convert dataclass to dict and filter out None values for boto3 API call
assumed_role_dict = asdict(service_config.assumed_role)
assume_role_params = {
k: v for k, v in assumed_role_dict.items() if v is not None
}
# Assume role
response = sts_client.assume_role(**assume_role_params)
# AWS operation successful - no detailed metrics needed
return response["Credentials"]
except ClientError as error:
# AWS operation failed - no detailed metrics needed
LOG.error("Failed to assume role for %s: %s", service_name, str(error))
raise
def _get_aws_config(self, service_config: ServiceConfig) -> dict:
"""Get AWS configuration for a service."""
service_creds = service_config.source_credentials
default_creds = self.config.aws_defaults
# Use or operator for clean fallbacks
region = (service_creds and service_creds.region) or (
default_creds and default_creds.region
)
profile_config = (service_creds and service_creds.iam_profile) or (
default_creds and default_creds.iam_profile
)
keys = (service_creds and service_creds.iam_keys) or (
default_creds and default_creds.iam_keys
)
aws_config = {"region_name": region}
# Auto-detect auth method based on presence of config objects
if profile_config and profile_config.profile_name:
# IAM profile authentication
aws_config["profile_name"] = profile_config.profile_name
elif keys:
# IAM keys authentication
aws_config.update(
{
"aws_access_key_id": keys.aws_access_key_id,
"aws_secret_access_key": keys.aws_secret_access_key,
}
)
if hasattr(keys, "session_token") and keys.session_token:
aws_config["aws_session_token"] = keys.session_token
# If neither profile_config nor keys present, use default SDK behavior
return aws_config