FastAPI Mastery
Topic 8 / 22
Middleware
Middleware sits between the client and your route handlers — it intercepts every request before it reaches your routes, and every response before it leaves your app. Think of it as a security checkpoint, logger, or transformer that wraps your entire application.
💡
Real-world analogy: Middleware is like the reception desk at an office. Every visitor (request) passes through reception first — they can be checked, logged, redirected, or turned away — before reaching the actual office (route). On the way out, reception can also stamp the visitor's pass (add response headers).
8.1 Middleware Fundamentals
🌊
Request Flow
Every HTTP request passes through a middleware stack — a chain of middleware functions — before reaching your route. The stack executes in order on the way in, and in reverse order on the way out.
Incoming Request
       │
       ▼
┌──────────────────────────────┐
│ Middleware 1 (CORS)          │  ← runs first on request
│  ┌────────────────────────┐  │
│  │ Middleware 2 (Log)     │  │  ← runs second
│  │  ┌──────────────────┐  │  │
│  │  │    Your Route    │  │  │  ← business logic
│  │  └──────────────────┘  │  │
│  │     ↑ response here    │  │
│  └────────────────────────┘  │  ← Middleware 2 wraps response
└──────────────────────────────┘  ← Middleware 1 wraps response
       │
       ▼
Outgoing Response

Key: each middleware calls `await call_next(request)` to pass control to the next layer
The middleware lifecycle — 3 phases:
python
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()

class LifecycleMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):

        # ── PHASE 1: Before the route runs ──────────────────
        print(f"→ Incoming: {request.method} {request.url.path}")
        # You can: inspect headers, validate tokens, log, block requests

        # ── PHASE 2: Call the actual route ──────────────────
        response = await call_next(request)
        # Everything ABOVE this line runs BEFORE the route
        # Everything BELOW this line runs AFTER the route

        # ── PHASE 3: After the route runs ───────────────────
        print(f"← Outgoing: status={response.status_code}")
        # You can: add response headers, log timing, compress

        return response  # must return the response!

app.add_middleware(LifecycleMiddleware)

@app.get("/hello")
def hello():
    return {"message": "Hello!"}

# Console output when hitting GET /hello:
# → Incoming: GET /hello
# ← Outgoing: status=200
↩️
Response Flow
After await call_next(request) returns, you hold the response object. You can inspect or mutate it (add headers, change the status code) or return a different Response object entirely before it goes back to the client.
Example — adding custom headers to every response:
python
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import time, uuid

app = FastAPI()

class ResponseEnricherMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        start = time.perf_counter()
        request_id = str(uuid.uuid4())[:8]

        # Pass request_id downstream (other middleware/routes can read it)
        request.state.request_id = request_id

        response = await call_next(request)

        # Mutate the response — add headers
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = f"{elapsed_ms:.2f}ms"
        response.headers["X-Powered-By"] = "FastAPI"

        return response

app.add_middleware(ResponseEnricherMiddleware)

@app.get("/ping")
def ping(request: Request):
    return {
        "pong": True,
        "request_id": request.state.request_id  # set by middleware!
    }

# Response headers will contain:
# X-Request-ID: a1b2c3d4
# X-Process-Time: 1.23ms
# X-Powered-By: FastAPI
request.state is a scratchpad you can use to pass data from middleware to your route handlers — like a request-scoped context object.
Middleware order matters — stacking example:
python
# Middleware is applied in REVERSE order of add_middleware calls!
# Last added = outermost layer (runs first)

app.add_middleware(MiddlewareA)   # runs SECOND (inner)
app.add_middleware(MiddlewareB)   # runs FIRST (outer)

# Request flow:  MiddlewareB → MiddlewareA → Route
# Response flow: Route → MiddlewareA → MiddlewareB
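Why does the last call end up outermost? In effect, every add_middleware() call wraps the stack built from the calls before it, so the most recent middleware surrounds all the others. The sketch below imitates that wrapping with plain ASGI callables and only the standard library; tracing_middleware, build_stack, and the events list are hypothetical names for illustration, not FastAPI or Starlette internals.

```python
import asyncio

events = []  # records the order in which each layer sees the request/response

def tracing_middleware(name):
    """Hypothetical factory: returns an ASGI middleware class that logs enter/exit."""
    class Tracing:
        def __init__(self, app):
            self.app = app

        async def __call__(self, scope, receive, send):
            events.append(f"{name} in")
            await self.app(scope, receive, send)
            events.append(f"{name} out")
    return Tracing

async def route(scope, receive, send):
    # Minimal ASGI "endpoint": sends an empty 200 response
    events.append("route")
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b""})

def build_stack(app, added):
    # Each middleware wraps the stack built so far,
    # so the LAST one added ends up outermost
    for middleware_cls in added:
        app = middleware_cls(app)
    return app

# Same order as the add_middleware() calls above: A first, then B
added = [tracing_middleware("MiddlewareA"), tracing_middleware("MiddlewareB")]
stack = build_stack(route, added)

async def main():
    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        pass  # discard response messages

    await stack({"type": "http", "method": "GET", "path": "/"}, receive, send)

asyncio.run(main())
print(events)
# ['MiddlewareB in', 'MiddlewareA in', 'route', 'MiddlewareA out', 'MiddlewareB out']
```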
8.2 Built-in Middleware
🌐
CORS — Cross-Origin Resource Sharing
CORS controls which origins are allowed to call your API from browser JavaScript. Without CORS middleware, the browser would block your React/Vue frontend on localhost:3000 from reading responses from your API on localhost:8000 (a different port is a different origin).
💡
Why does CORS exist? Browsers enforce the Same-Origin Policy: JavaScript can only read responses from the same origin (scheme, host, and port) unless the API explicitly allows cross-origin requests via CORS headers.
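Browser (localhost:3000)                 FastAPI (localhost:8000)
  ──► OPTIONS /api/users ──────────────────►   (preflight)
  ◄── Access-Control-Allow-Origin: * ◄──────
  ──► GET /api/users ──────────────────────►   (real request)
  ◄── 200 OK ◄──────────────────────────────

For a non-simple request (for example, one carrying an Authorization header or a JSON body), the browser first sends a "preflight" OPTIONS request. The CORS middleware replies with permission headers, and only then does the browser send the real request. Simple requests skip the preflight, but their responses still need the Access-Control-Allow-Origin header.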
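Which requests get preflighted is decided by the browser, following the Fetch standard's "simple request" rules: GET, HEAD, and POST requests that only use CORS-safelisted headers, with a form-style or text/plain Content-Type, go straight through. The helper below is a simplified sketch of that check (it ignores details such as header value length limits and the Range header); needs_preflight is a hypothetical name, not part of FastAPI.

```python
SIMPLE_METHODS = {"GET", "HEAD", "POST"}
SAFELISTED_HEADERS = {"accept", "accept-language", "content-language", "content-type"}
SIMPLE_CONTENT_TYPES = {
    "application/x-www-form-urlencoded",
    "multipart/form-data",
    "text/plain",
}

def needs_preflight(method: str, headers: dict) -> bool:
    """Simplified version of the browser's 'simple request' check."""
    if method.upper() not in SIMPLE_METHODS:
        return True  # e.g. PUT, DELETE, PATCH
    for name, value in headers.items():
        name = name.lower()
        if name not in SAFELISTED_HEADERS:
            return True  # e.g. Authorization or X-API-Key
        if name == "content-type":
            # Compare only the MIME type, ignoring parameters like charset
            mime = value.split(";")[0].strip().lower()
            if mime not in SIMPLE_CONTENT_TYPES:
                return True  # e.g. application/json
    return False

print(needs_preflight("GET", {"Accept": "application/json"}))         # False
print(needs_preflight("POST", {"Content-Type": "application/json"}))  # True
print(needs_preflight("GET", {"Authorization": "Bearer abc"}))         # True
print(needs_preflight("DELETE", {}))                                   # True
```

This is why a typical SPA calling a JSON API with a bearer token sees an OPTIONS request before almost every call, and why max_age (preflight caching) matters.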
Example 1 — Development setup (allow all origins):
python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# ⚠️  Development only — never use allow_origins=["*"] in production!
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],        # any domain can access
    allow_credentials=False,   # must be False with allow_origins=["*"]
    allow_methods=["*"],        # GET, POST, PUT, DELETE, etc.
    allow_headers=["*"],        # any request header
)

@app.get("/public-data")
def public_data():
    return {"data": "accessible from any origin"}
Example 2 — Production setup (specific origins):
python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

ALLOWED_ORIGINS = [
    "https://myapp.com",
    "https://www.myapp.com",
    "https://admin.myapp.com",
    "http://localhost:3000",  # local dev frontend
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,        # allow cookies/auth headers
    allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    allow_headers=["Authorization", "Content-Type", "X-API-Key"],
    expose_headers=["X-Request-ID", "X-Process-Time"],
    max_age=3600,  # preflight cache: 1 hour (reduces OPTIONS requests)
)

@app.get("/protected-data")
def protected_data():
    return {"data": "only accessible from allowed origins"}
Parameter | What it controls | Common value
--- | --- | ---
allow_origins | Which domains can access the API | ["https://myapp.com"]
allow_credentials | Allow cookies & auth headers | True (with specific origins)
allow_methods | Which HTTP methods are allowed | ["GET","POST","PUT","DELETE"]
allow_headers | Which request headers are allowed | ["Authorization","Content-Type"]
expose_headers | Which response headers JS can read | ["X-Request-ID"]
max_age | Preflight cache duration (seconds) | 3600
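The allow_origins and allow_credentials settings interact. Under the CORS protocol, a response to a credentialed request must name the exact origin rather than *, and must also send Access-Control-Allow-Credentials: true. The function below sketches that decision for a single request; cors_headers is a hypothetical helper for illustration, not how Starlette's CORSMiddleware is written internally.

```python
from typing import Optional

def cors_headers(origin: Optional[str], allow_origins: list,
                 allow_credentials: bool) -> dict:
    """Hypothetical sketch: which CORS headers to attach for one request's Origin."""
    if origin is None:
        return {}  # no Origin header: not a CORS request, nothing to add
    allow_all = "*" in allow_origins
    if not (allow_all or origin in allow_origins):
        return {}  # no Allow-Origin header, so the browser hides the response from JS
    headers = {}
    if allow_all and not allow_credentials:
        headers["Access-Control-Allow-Origin"] = "*"
    else:
        # Echo the specific origin: a wildcard is not valid with credentials
        headers["Access-Control-Allow-Origin"] = origin
        headers["Vary"] = "Origin"  # caches must not reuse this for other origins
    if allow_credentials:
        headers["Access-Control-Allow-Credentials"] = "true"
    return headers

prod = ["https://myapp.com", "http://localhost:3000"]
print(cors_headers("https://myapp.com", prod, allow_credentials=True))
print(cors_headers("https://evil.example", prod, allow_credentials=True))          # {}
print(cors_headers("https://anything.example", ["*"], allow_credentials=False))
```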
🗜️
GZip — Response Compression
GZip middleware compresses responses when the client advertises support for it (sends Accept-Encoding: gzip). Repetitive JSON usually compresses very well, which makes a real difference for large API responses; the exact savings depend on the data.
python
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware

app = FastAPI()

app.add_middleware(
    GZipMiddleware,
    minimum_size=1000  # only compress bodies of at least 1000 bytes (default: 500);
                        # small responses aren't worth compressing
)

@app.get("/large-data")
def large_data():
    # Compressed only if the client sent Accept-Encoding: gzip
    return {"items": [f"item-{i}" for i in range(1000)]}

# Without GZip: ~10,900 bytes of compact JSON
# With GZip:    a fraction of that (the exact ratio depends on the data)
GZip middleware only activates if the client sends Accept-Encoding: gzip. Modern browsers and httpx/requests send this by default.
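To measure the effect on your own payloads, you can reproduce roughly what the middleware does with the standard library: serialize the body with compact separators (as Starlette's JSONResponse does), apply the same minimum_size cut-off, and gzip it. This is a measurement sketch only; maybe_gzip is a hypothetical helper, and the real middleware also handles streamed bodies and sets headers such as Content-Encoding.

```python
import gzip
import json

def maybe_gzip(body: bytes, minimum_size: int = 1000) -> bytes:
    """Hypothetical helper: compress only bodies of at least minimum_size bytes."""
    if len(body) < minimum_size:
        return body  # too small to be worth the CPU and header overhead
    return gzip.compress(body)

payload = {"items": [f"item-{i}" for i in range(1000)]}
raw = json.dumps(payload, separators=(",", ":")).encode("utf-8")
compressed = maybe_gzip(raw)

print(len(raw))  # 10901
print(len(compressed))  # much smaller; exact size varies with the compression level
print(f"{100 * (1 - len(compressed) / len(raw)):.0f}% smaller")

tiny = json.dumps({"ok": True}, separators=(",", ":")).encode("utf-8")
print(maybe_gzip(tiny) == tiny)  # True: below the threshold, sent as-is
```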
🛡️
Trusted Host — Host Header Validation
TrustedHostMiddleware rejects requests whose Host header doesn't match your allowed domains. This protects against Host header injection attacks, where an attacker sends a forged Host header so that your app builds links or cache entries pointing at the attacker's domain (password-reset emails are a classic target).
python
from fastapi import FastAPI
from starlette.middleware.trustedhost import TrustedHostMiddleware

app = FastAPI()

app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=[
        "myapp.com",
        "*.myapp.com",   # wildcard subdomains
        "localhost",
        "127.0.0.1",
    ]
)

@app.get("/data")
def data():
    return {"ok": True}

# Request with Host: myapp.com       → 200 OK
# Request with Host: evil-site.com   → 400 Bad Request (rejected!)
# Request with Host: api.myapp.com   → 200 OK (wildcard match)
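The wildcard rule is worth spelling out: "*.myapp.com" matches any subdomain but not the bare "myapp.com", which is why the example lists both. The function below is a simplified re-implementation of exact-or-wildcard matching for illustration (host_is_allowed is a hypothetical name); it drops the port before comparing, because allowed_hosts entries are bare hostnames. One practical consequence: Starlette's TestClient sends Host: testserver by default, so tests against this app need that name allowed or a matching base_url.

```python
def host_is_allowed(host_header: str, allowed_hosts: list) -> bool:
    """Hypothetical sketch of exact-or-wildcard Host matching."""
    host = host_header.split(":")[0]  # "api.myapp.com:8000" -> "api.myapp.com"
    for pattern in allowed_hosts:
        if pattern == "*" or pattern == host:
            return True
        # "*.myapp.com" -> host must end with ".myapp.com" (bare "myapp.com" fails)
        if pattern.startswith("*.") and host.endswith(pattern[1:]):
            return True
    return False

allowed = ["myapp.com", "*.myapp.com", "localhost", "127.0.0.1"]
for h in ["myapp.com", "api.myapp.com:8000", "evil-site.com", "evilmyapp.com"]:
    print(h, host_is_allowed(h, allowed))
# myapp.com True
# api.myapp.com:8000 True
# evil-site.com False
# evilmyapp.com False
```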
🔒
HTTPS Redirect
HTTPSRedirectMiddleware automatically redirects all HTTP requests to HTTPS. Any request coming in on plain HTTP gets a 307 Temporary Redirect to the same URL but with https://.
python
import os

from fastapi import FastAPI
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware

app = FastAPI()

# Only add in production: it will break local development!
if os.getenv("ENV") == "production":
    app.add_middleware(HTTPSRedirectMiddleware)

@app.get("/secure")
def secure():
    return {"message": "Secure!"}

# http://myapp.com/secure  →  307 redirect to  https://myapp.com/secure
⚠️
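Don't use in development! Your local server most likely has no TLS listener, so the redirect sends clients to an https:// URL that fails to connect. Behind a TLS-terminating proxy, the app sees every request as plain HTTP and can redirect in a loop unless the server trusts the proxy's X-Forwarded-Proto header (for Uvicorn, see --proxy-headers and --forwarded-allow-ips). Guard it behind an environment check.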
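The redirect itself is a small URL rewrite: switch the scheme to https and drop an explicit port 80 or 443 so the new URL falls back to the default HTTPS port. Here is a sketch of that rewrite with urllib.parse; https_redirect_url is a hypothetical helper, and the real middleware additionally handles WebSocket URLs and sends the result with a 307 status.

```python
from urllib.parse import urlsplit, urlunsplit

def https_redirect_url(url: str) -> str:
    """Hypothetical sketch: the URL an HTTP -> HTTPS redirect would point to."""
    parts = urlsplit(url)
    # Drop ports 80/443 so the https URL uses the default port 443;
    # any other port (e.g. a dev server's :8000) is kept as-is
    netloc = parts.hostname if parts.port in (80, 443) else parts.netloc
    return urlunsplit(("https", netloc, parts.path, parts.query, parts.fragment))

print(https_redirect_url("http://myapp.com/secure"))         # https://myapp.com/secure
print(https_redirect_url("http://myapp.com:80/secure?x=1"))  # https://myapp.com/secure?x=1
print(https_redirect_url("http://localhost:8000/secure"))    # https://localhost:8000/secure
```

The last line shows the local-development failure mode: the client is sent to https://localhost:8000, where a plain-HTTP dev server cannot complete a TLS handshake.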
8.3 Custom Middleware
📝
Logging Middleware
A logging middleware records every request and response — method, path, status code, timing. This gives you full visibility into traffic without cluttering your route handlers.
Example 1 — Basic request/response logger:
python
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import time, logging

# Set up structured logger
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger("api")

app = FastAPI()

class RequestLoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        start = time.perf_counter()

        # Log request (before route runs)
        logger.info(
            f"REQUEST  {request.method} {request.url.path}"
            f" | client={request.client.host if request.client else 'unknown'}"
            f" | ua={request.headers.get('user-agent', 'unknown')[:40]}"
        )

        try:
            response = await call_next(request)
            elapsed = (time.perf_counter() - start) * 1000

            # Log response (after route runs)
            logger.info(
                f"RESPONSE {request.method} {request.url.path}"
                f" | status={response.status_code}"
                f" | time={elapsed:.1f}ms"
            )
            return response

        except Exception as e:
            elapsed = (time.perf_counter() - start) * 1000
            logger.error(
                f"ERROR    {request.method} {request.url.path}"
                f" | error={e!r} | time={elapsed:.1f}ms"
            )
            raise  # re-raise so FastAPI handles it

app.add_middleware(RequestLoggingMiddleware)

@app.get("/items")
def get_items():
    return ["item1", "item2"]

# Console output (timestamps shortened):
# 2024-01-15 | INFO | REQUEST  GET /items | client=127.0.0.1 | ua=curl/7.68
# 2024-01-15 | INFO | RESPONSE GET /items | status=200 | time=2.1ms
Example 2 — Structured JSON logging (production-grade):
python
import json, time, uuid, logging
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()

class StructuredLogMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        request_id = str(uuid.uuid4())
        start = time.perf_counter()

        response = await call_next(request)
        elapsed_ms = (time.perf_counter() - start) * 1000

        # Structured log as JSON — parseable by tools like Datadog, ELK
        log_entry = {
            "request_id": request_id,
            "method": request.method,
            "path": request.url.path,
            "query": str(request.query_params),
            "status": response.status_code,
            "duration_ms": round(elapsed_ms, 2),
            # request.client can be None (e.g. in some test or proxy setups)
            "client_ip": request.client.host if request.client else None,
        }
        print(json.dumps(log_entry))  # or: logger.info(json.dumps(log_entry))

        response.headers["X-Request-ID"] = request_id
        return response

app.add_middleware(StructuredLogMiddleware)
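If you'd rather keep the logging module (levels, handlers, routing to files) than print JSON directly, one standard-library option is a custom logging.Formatter that serializes each record. This is a minimal sketch; JsonFormatter and the "fields" extra key are hypothetical names chosen for illustration.

```python
import io
import json
import logging

class JsonFormatter(logging.Formatter):
    """Hypothetical formatter: one JSON object per log line."""
    def format(self, record: logging.LogRecord) -> str:
        entry = {
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Merge structured fields passed via logger.info(..., extra={"fields": {...}})
        entry.update(getattr(record, "fields", {}))
        return json.dumps(entry)

stream = io.StringIO()  # stand-in for stdout so the output can be inspected
handler = logging.StreamHandler(stream)
handler.setFormatter(JsonFormatter())

logger = logging.getLogger("api.access")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False  # avoid duplicate lines via the root logger

logger.info("request handled",
            extra={"fields": {"method": "GET", "path": "/items", "status": 200}})
print(stream.getvalue().strip())
# {"level": "INFO", "logger": "api.access", "message": "request handled", "method": "GET", "path": "/items", "status": 200}
```

Inside the middleware you would then call logger.info("request handled", extra={"fields": log_entry}) instead of print(json.dumps(log_entry)).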
📊
Metrics Middleware
Metrics middleware tracks aggregate numbers about your API: request counts, latency, and error rates. These power dashboards and alerting. Here's how to build a simple version from scratch (usable with any metrics backend); note that it only tracks average latency, since percentiles need raw samples or histograms.
python
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from collections import defaultdict
import time

app = FastAPI()

# ---- In-memory metrics store ----
# In production: use prometheus_client, statsd, or OpenTelemetry
metrics = {
    "request_count": defaultdict(int),    # {route: count}
    "error_count": defaultdict(int),      # {route: error_count}
    "total_latency": defaultdict(float),  # {route: total_ms}
}

class MetricsMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
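        # Raw paths (/users/1, /users/2, ...) each become a separate key;
        # real systems usually group by route template to keep the key count bounded
        route = f"{request.method} {request.url.path}"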
        start = time.perf_counter()

        response = await call_next(request)

        elapsed_ms = (time.perf_counter() - start) * 1000

        # Record metrics
        metrics["request_count"][route] += 1
        metrics["total_latency"][route] += elapsed_ms
        if response.status_code >= 400:
            metrics["error_count"][route] += 1

        return response

app.add_middleware(MetricsMiddleware)

# ---- Metrics endpoint ----
@app.get("/metrics")
def get_metrics():
    result = {}
    for route, count in metrics["request_count"].items():
        avg_latency = metrics["total_latency"][route] / count
        errors = metrics["error_count"][route]
        result[route] = {
            "requests": count,
            "avg_latency_ms": round(avg_latency, 2),
            "errors": errors,
            "error_rate": f"{(errors/count)*100:.1f}%"
        }
    return result

@app.get("/users")
def get_users():
    return ["Alice", "Bob"]

# After some requests, GET /metrics returns:
# {
#   "GET /users": {"requests": 42, "avg_latency_ms": 1.8, "errors": 0, "error_rate": "0.0%"},
#   "GET /metrics": {"requests": 5, "avg_latency_ms": 0.4, "errors": 0, "error_rate": "0.0%"}
# }
In production, use prometheus_client and expose a /metrics endpoint that Prometheus scrapes. The middleware structure is identical — just replace the dict with Prometheus counters/histograms.
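Averages hide tail latency: a handful of very slow requests barely moves the mean. For a quick local check before reaching for Prometheus histograms, the standard library's statistics.quantiles can compute percentiles from a bounded window of recent samples. The window size and the record / latency_summary names are hypothetical choices for this sketch.

```python
import statistics
from collections import defaultdict, deque

WINDOW = 1000  # keep only the most recent samples per route (bounded memory)
latencies = defaultdict(lambda: deque(maxlen=WINDOW))

def record(route: str, elapsed_ms: float) -> None:
    """Call this from the middleware alongside (or instead of) the running totals."""
    latencies[route].append(elapsed_ms)

def latency_summary(route: str) -> dict:
    samples = list(latencies[route])
    if len(samples) < 2:
        return {"samples": len(samples)}  # quantiles() needs at least two points
    # n=100 returns 99 cut points: index 49 -> p50, 94 -> p95, 98 -> p99
    cuts = statistics.quantiles(samples, n=100)
    return {
        "samples": len(samples),
        "p50_ms": round(cuts[49], 2),
        "p95_ms": round(cuts[94], 2),
        "p99_ms": round(cuts[98], 2),
    }

# Simulated traffic: mostly fast requests plus a few slow outliers
for ms in [2.0] * 95 + [250.0] * 5:
    record("GET /users", ms)

print(statistics.mean(latencies["GET /users"]))  # 14.4: looks healthy on its own
print(latency_summary("GET /users"))             # p95/p99 reveal the slow tail
```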
🔖
Correlation IDs
A Correlation ID is a unique ID assigned to each request that flows through every layer of your system — middleware, routes, services, database calls, logs. When something goes wrong, you can grep all logs by the same correlation ID to trace the full request journey.
Client ──► Middleware assigns ID: abc-123
             ──► Route logs with ID: abc-123
                   ──► Service logs with ID: abc-123
                         ──► DB query logged with ID: abc-123
Client ◄── Response header: X-Correlation-ID: abc-123

Now you can grep logs for "abc-123" and see the full story of one request.
Example — Full correlation ID system:
python
from fastapi import FastAPI, Request, Depends
from starlette.middleware.base import BaseHTTPMiddleware
from contextvars import ContextVar
import uuid, logging

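# Without a configured handler, the INFO messages from log.info() below are dropped
logging.basicConfig(level=logging.INFO)

app = FastAPI()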

# ContextVar: async-safe per-request storage (like threading.local but for async)
correlation_id_var: ContextVar[str] = ContextVar("correlation_id", default="none")

class CorrelationIDMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Accept ID from client (useful for distributed tracing)
        # or generate a new one if not provided
        cid = request.headers.get("X-Correlation-ID") or str(uuid.uuid4())

        # Store in context var — accessible from anywhere in this async task
        correlation_id_var.set(cid)
        request.state.correlation_id = cid

        response = await call_next(request)

        # Return the ID in response so client can reference it
        response.headers["X-Correlation-ID"] = cid
        return response

app.add_middleware(CorrelationIDMiddleware)

# ---- Helper to get current request's correlation ID ----
def get_correlation_id() -> str:
    return correlation_id_var.get()

# ---- Use in logging ----
class CorrelatedLogger:
    def __init__(self, name: str):
        self._logger = logging.getLogger(name)

    def info(self, msg: str):
        cid = get_correlation_id()
        self._logger.info(f"[{cid}] {msg}")

log = CorrelatedLogger("myapp")

# ---- Route — correlation ID is available everywhere ----
@app.get("/process")
def process(request: Request):
    cid = request.state.correlation_id
    log.info("Processing request in route")      # [abc-123] Processing...
    log.info("Calling external service")          # [abc-123] Calling...
    return {"processed": True, "correlation_id": cid}

# Client sends: X-Correlation-ID: my-trace-123
# All logs tagged: [my-trace-123] Processing request in route
# Response header: X-Correlation-ID: my-trace-123
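ContextVar is crucial here. A regular global variable would mix up values between concurrent requests. ContextVar is async-safe: each asyncio task (and therefore each request) works with its own copy of the context, so a value set while handling one request never leaks into another.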
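To see the difference, the sketch below runs two simulated "requests" concurrently with asyncio and only the standard library. Each task sets both a ContextVar and a plain global, then yields; the ContextVar keeps each task's own value, while the global ends up holding whichever value was written last. The handle function and the request IDs are hypothetical.

```python
import asyncio
from contextvars import ContextVar

request_id_var: ContextVar[str] = ContextVar("request_id", default="none")
current_request_id = "none"  # plain global, shared by every task

async def handle(rid: str, delay: float) -> tuple:
    global current_request_id
    request_id_var.set(rid)     # visible only inside this task's context
    current_request_id = rid    # visible to (and overwritten by) every task
    await asyncio.sleep(delay)  # yield so the other "request" runs meanwhile
    return request_id_var.get(), current_request_id

async def main():
    # Both tasks set their IDs before either wakes up; "req-b" writes the global last
    return await asyncio.gather(handle("req-a", 0.02), handle("req-b", 0.01))

results = asyncio.run(main())
print(results)
# [('req-a', 'req-b'), ('req-b', 'req-b')]
```

The first tuple is the bug a global would cause in production: request "req-a" would log with request "req-b"'s ID.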
Pure ASGI Middleware (Advanced)
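BaseHTTPMiddleware is convenient but adds some per-request overhead: it runs the rest of the app in a separate task and relays the response body through an internal stream. For maximum performance and full control over streaming, you can write pure ASGI middleware that works directly with the raw ASGI messages.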
Example — pure ASGI middleware:
python
from fastapi import FastAPI
from starlette.types import ASGIApp, Receive, Scope, Send
import time

class TimingMiddleware:
    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] != "http":
            # Pass WebSocket/lifespan events through unchanged
            await self.app(scope, receive, send)
            return

        start = time.perf_counter()

        # Wrap the send callable to intercept response start
        async def send_with_timing(message):
            if message["type"] == "http.response.start":
                elapsed = (time.perf_counter() - start) * 1000
                # Append the header; keep a list of (name, value) pairs so that
                # duplicate headers (e.g. several Set-Cookie lines) are preserved
                headers = list(message.get("headers", []))
                headers.append((b"x-process-time", f"{elapsed:.2f}ms".encode()))
                message["headers"] = headers
            await send(message)

        await self.app(scope, receive, send_with_timing)

app = FastAPI()
app.add_middleware(TimingMiddleware)

@app.get("/fast")
async def fast():
    return {"fast": True}

# Response includes header: x-process-time: 0.42ms
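Aspect | BaseHTTPMiddleware | Pure ASGI
--- | --- | ---
Complexity | Simple ✅ | More complex
Performance | Some per-request overhead | Minimal overhead ✅
Streaming support | Streams, but reading or rewriting the body is awkward | Full control over every message ✅
Best for | Most use cases | High-throughput APIs, streaming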
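Because pure ASGI middleware is just an async callable, you can exercise it without a server or test client by passing a fake scope and your own receive/send functions. The sketch below repeats the timing idea in dependency-free form (no Starlette imports) and captures the messages a client would receive; hello_app, PlainTimingMiddleware, run_once, and the fake scope are hypothetical test scaffolding.

```python
import asyncio
import time

async def hello_app(scope, receive, send):
    # Tiny ASGI app standing in for FastAPI: returns a fixed JSON body
    await send({"type": "http.response.start", "status": 200,
                "headers": [(b"content-type", b"application/json")]})
    await send({"type": "http.response.body", "body": b'{"fast":true}'})

class PlainTimingMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        start = time.perf_counter()

        async def send_with_timing(message):
            if message["type"] == "http.response.start":
                elapsed = (time.perf_counter() - start) * 1000
                headers = list(message.get("headers", []))
                headers.append((b"x-process-time", f"{elapsed:.2f}ms".encode()))
                message = {**message, "headers": headers}  # copy instead of mutating
            await send(message)

        await self.app(scope, receive, send_with_timing)

async def run_once(app):
    sent = []  # every ASGI message the "server" would transmit

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    scope = {"type": "http", "method": "GET", "path": "/fast", "headers": []}
    await app(scope, receive, send)
    return sent

messages = asyncio.run(run_once(PlainTimingMiddleware(hello_app)))
print(messages[0]["status"])                          # 200
print([name for name, _ in messages[0]["headers"]])   # [b'content-type', b'x-process-time']
print(messages[1]["body"])                            # b'{"fast":true}'
```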
🧩
Combining Multiple Middleware
Real applications stack multiple middleware together. Here's one reasonable pattern combining TrustedHost, CORS, correlation IDs, logging, metrics, and GZip, with the add_middleware() calls in a deliberate order.
python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
# assuming you defined these above:
# from middleware import CorrelationIDMiddleware, RequestLoggingMiddleware, MetricsMiddleware

app = FastAPI()

# ── Add middleware in REVERSE execution order ──────────────────
# Last added = OUTERMOST (runs first on request, last on response)

# 1. GZip: added first = innermost, so it compresses the route's response body
app.add_middleware(GZipMiddleware, minimum_size=500)

# 2. Metrics: counts and times requests to the route
app.add_middleware(MetricsMiddleware)

# 3. Logging: wraps metrics, GZip, and the route
app.add_middleware(RequestLoggingMiddleware)

# 4. Correlation ID: added AFTER logging, so it runs BEFORE logging on the way in
#    and the ID is already set when logging, metrics, and the route run
app.add_middleware(CorrelationIDMiddleware)

# 5. CORS: near the edge, so preflight requests are answered early
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://myapp.com"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 6. Trusted Host: added last = outermost, rejects bad hosts before anything else
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["myapp.com", "localhost"])

# Execution order (request):  TrustedHost → CORS → CorrelationID → Logging → Metrics → GZip → Route
# Execution order (response): Route → GZip → Metrics → Logging → CorrelationID → CORS → TrustedHost
⚠️
Order tip: TrustedHost and CORS should be near the outermost layer, where they can reject bad hosts or answer CORS preflight requests before any expensive work is done. GZip should sit close to the route, wrapping the route's response rather than the entire stack.
📋 Topic 8 Summary
Middleware | Purpose | Source
--- | --- | ---
BaseHTTPMiddleware | Base class for custom middleware | starlette
CORSMiddleware | Allow cross-origin browser requests | Built-in
GZipMiddleware | Compress large responses automatically | Built-in
TrustedHostMiddleware | Reject forged Host headers | Built-in
HTTPSRedirectMiddleware | Force HTTPS in production | Built-in
Logging Middleware | Record every request/response | Custom
Metrics Middleware | Count requests, measure latency | Custom
Correlation ID | Trace one request across all logs | Custom
Pure ASGI | Lowest overhead, full control over ASGI messages | Custom