Source code for credproxy.app

#  SPDX-License-Identifier: MPL-2.0
#  Copyright 2025-present John Mille <john@ews-network.net>

from __future__ import annotations

import time
from typing import TYPE_CHECKING

from flask import Flask, g, request

from credproxy.config import Config as AppConfig
from credproxy.logger import LOG, setup_json_logging
from credproxy.routes import api_bp, register_metrics_route
from credproxy.metrics import init_metrics, record_request
from credproxy.file_watcher import FileWatcherService
from credproxy.credentials_handler import CredentialsHandler


if TYPE_CHECKING:
    from flask import Flask


def set_service_context() -> None:
    """Set service name in Flask's ``g`` context for access logging.

    Intended to run as a ``before_request`` hook. Only requests whose
    endpoint is ``api.get_credentials`` are inspected. For those, the value
    of the ``Authorization`` header is resolved to a service name through
    ``config.get_service_name_by_token``; on a match ``g.service_name`` is
    set, and ``g.service_source_file`` as well when the matched service
    defines a ``source_file``.

    The function always returns ``None`` so request handling continues
    regardless of whether a service was identified.

    Security note: the presented token is a credential and is never written
    to the log, not even partially. An unmatched token is reported through a
    truncated SHA-256 fingerprint, which still lets operators correlate
    repeated attempts without exposing any of the secret's characters.
    """
    # Only set service context for credential requests.
    if request.endpoint != "api.get_credentials":
        LOG.debug("Skipping service context: endpoint=%s", request.endpoint)
        return

    # Get the authorization token from the request header.
    provided_token = request.headers.get("Authorization")
    if not provided_token:
        LOG.debug("No authorization token found in request")
        return

    # Get config from the Flask current app (kept as a function-local import).
    from flask import current_app

    config = current_app.config.get("credproxy_config")
    if not config:
        return

    # Find the service name for this token.
    service_name = config.get_service_name_by_token(provided_token)
    if not service_name:
        import hashlib

        # Log a one-way fingerprint instead of a token prefix: a prefix leaks
        # real secret material (all of it for tokens of 10 chars or fewer).
        fingerprint = hashlib.sha256(
            provided_token.encode("utf-8", "replace")
        ).hexdigest()[:12]
        LOG.debug("No service found for token: %s", "sha256:" + fingerprint)
        return

    g.service_name = service_name
    LOG.debug("Set service context: service_name=%s", service_name)

    # Also store source_file for logging.
    service = config.services.get(service_name)
    if service and service.source_file:
        g.service_source_file = service.source_file
def init_app(config: AppConfig) -> Flask:
    """Create and configure the Flask app.

    Steps performed, in order:

    1. Instantiate the Flask app, build a ``CredentialsHandler`` and a
       ``FileWatcherService`` from ``config``, and store the settings and
       these objects in ``app.config``.
    2. Set up JSON logging and initialise the Prometheus metrics.
    3. Register the ``before_request`` hooks (request id / start time,
       service context, shutdown check) and the ``after_request`` hook that
       records metrics for ``api.get_credentials`` requests.
    4. Register the metrics route and the API blueprint.
    5. Start the file watcher; a failure to start is logged and otherwise
       ignored.

    Args:
        config: The credproxy configuration object.

    Returns:
        The configured Flask application.
    """
    flask_app = Flask(__name__)

    handler = CredentialsHandler(config)
    watcher = FileWatcherService(config)

    flask_app.config.update(
        {
            # Disable Flask's default logging handler to prevent duplicate logs.
            "LOGGER_HANDLER_POLICY": "never",
            # Suppress Flask development server warning.
            "ENV": "production",
            # Objects made reachable through the app config.
            "credproxy_config": config,
            "credentials_handler": handler,
            "file_watcher": watcher,
        }
    )

    # Setup dynamic JSON logging, then the Prometheus metrics.
    setup_json_logging(flask_app)
    init_metrics()

    # HTTP status code -> metrics result label; anything else is "error".
    result_by_status = {
        200: "success",
        401: "denied_missing_token",
        403: "denied_invalid_token",
    }

    def make_request_id() -> None:
        """Assign a timestamp-based request id and note the start time."""
        # Microsecond resolution is used for uniqueness.
        micros = int(time.time() * 1000000)
        g.request_id = f"{request.remote_addr}-{micros}"
        # Start time is consumed by the metrics hook below.
        g.start_time = time.time()

    def check_shutdown_flag():
        """Answer 503 once a shutdown has been requested."""
        if not flask_app.config.get("_shutdown_requested", False):
            return None
        LOG.info("=== SHUTDOWN IN PROGRESS - Rejecting new requests ===")
        return "Service shutting down", 503

    def record_metrics(response):
        """Record metrics for credential requests; always return response."""
        if request.endpoint != "api.get_credentials":
            LOG.debug("Skipping metrics recording: endpoint=%s", request.endpoint)
            return response

        try:
            finished_at = time.time()
            duration = finished_at - g.get("start_time", time.time())
            service_name = g.get("service_name", "unknown")
            result = result_by_status.get(response.status_code, "error")

            LOG.debug(
                "Recording metrics: endpoint=%s, service_name=%s, result=%s, "
                "status_code=%s",
                request.endpoint,
                service_name,
                result,
                response.status_code,
            )
            record_request(
                result=result, service_name=service_name, duration=duration
            )
        except Exception as error:
            LOG.error("Failed to record request metrics: %s", error)
        return response

    # Registration order is significant: request id first, then service
    # context, then the shutdown check.
    for hook in (make_request_id, set_service_context, check_shutdown_flag):
        flask_app.before_request(hook)
    flask_app.after_request(record_metrics)

    # Register metrics route if enabled, then the API blueprint.
    register_metrics_route(flask_app, config)
    flask_app.register_blueprint(api_bp)

    # Start file watcher service.
    try:
        watcher.start()
    except Exception as error:
        # Continue without file watcher - not critical for basic operation.
        LOG.error("Failed to start file watcher service: %s", error)

    return flask_app