From a39f874ad56c3eb3a29306ec75957e05cac99210 Mon Sep 17 00:00:00 2001
From: Thomas Sileo
Date: Thu, 14 Jul 2022 12:13:23 +0200
Subject: [PATCH] Switch to raw ASGI middleware

---
Review note: the Content-Security-Policy value added in app/main.py now
separates its two directives with ';' ("default-src 'self'; style-src ...").
Without the separator the header is parsed as a single default-src directive,
so 'unsafe-inline' would also apply to scripts and no style-src directive
would exist. Hunk line counts are unchanged by this fix.

 app/database.py |   5 +-
 app/main.py     | 159 ++++++++++++++++++++++++++++++------------------
 poetry.lock     |   2 +-
 pyproject.toml  |   1 +
 4 files changed, 107 insertions(+), 60 deletions(-)

diff --git a/app/database.py b/app/database.py
index 1e2e31d..beba7f0 100644
--- a/app/database.py
+++ b/app/database.py
@@ -29,4 +29,7 @@ def now() -> datetime.datetime:
 
 async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
     async with async_session() as session:
-        yield session
+        try:
+            yield session
+        finally:
+            await session.close()
diff --git a/app/main.py b/app/main.py
index 0f94967..476b5db 100644
--- a/app/main.py
+++ b/app/main.py
@@ -9,6 +9,10 @@ from typing import MutableMapping
 from typing import Type
 
 import httpx
+from asgiref.typing import ASGI3Application
+from asgiref.typing import ASGIReceiveCallable
+from asgiref.typing import ASGISendCallable
+from asgiref.typing import Scope
 from cachetools import LFUCache
 from fastapi import Depends
 from fastapi import FastAPI
@@ -28,7 +32,9 @@ from sqlalchemy import func
 from sqlalchemy import select
 from sqlalchemy.orm import joinedload
 from starlette.background import BackgroundTask
+from starlette.datastructures import MutableHeaders
 from starlette.responses import JSONResponse
+from starlette.types import Message
 
 from app import activitypub as ap
 from app import admin
@@ -82,12 +88,107 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
 # - [ ] Dockerization
 # - [ ] cleanup tasks
 
+
+class CustomMiddleware:
+    def __init__(
+        self,
+        app: "ASGI3Application",
+    ) -> None:
+        self.app = app
+
+    async def __call__(
+        self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
+    ) -> None:
+        """
+        if scope["type"] in ("http", "websocket"):
+            scope = cast(HTTPScope | WebSocketScope, scope)
+            client_addr: tuple[str, int] | None = scope.get("client")
+            client_host = client_addr[0] if client_addr else None
+
+            if self.always_trust or client_host in self.trusted_hosts:
+                headers = dict(scope["headers"])
+
+                if b"x-forwarded-proto" in headers:
+                    # Determine if the incoming request was http or https based on
+                    # the X-Forwarded-Proto header.
+                    x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1")
+                    scope["scheme"] = x_forwarded_proto.strip()  # type: ignore[index]
+
+                if b"x-forwarded-for" in headers:
+                    # Determine the client address from the last trusted IP in the
+                    # X-Forwarded-For header. We've lost the connecting client's port
+                    # information by now, so only include the host.
+                    x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
+                    x_forwarded_for_hosts = [
+                        item.strip() for item in x_forwarded_for.split(",")
+                    ]
+                    host = self.get_trusted_client_host(x_forwarded_for_hosts)
+                    port = 0
+                    scope["client"] = (host, port)  # type: ignore[arg-type]
+        """
+
+        if scope["type"] != "http":
+            await self.app(scope, receive, send)
+            return
+
+        instance = {"http_status_code": None}
+
+        start_time = time.perf_counter()
+        request_id = os.urandom(8).hex()
+
+        async def send_wrapper(message: Message) -> None:
+            if message["type"] == "http.response.start":
+                instance["http_status_code"] = message["status"]
+
+                headers = MutableHeaders(scope=message)
+                headers["X-Request-ID"] = request_id
+                headers["Server"] = "microblogpub"
+                headers[
+                    "referrer-policy"
+                ] = "no-referrer, strict-origin-when-cross-origin"
+                headers["x-content-type-options"] = "nosniff"
+                headers["x-xss-protection"] = "1; mode=block"
+                headers["x-frame-options"] = "SAMEORIGIN"
+                # TODO(ts): disallow inline CSS?
+                headers["content-security-policy"] = (
+                    "default-src 'self';" + " style-src 'self' 'unsafe-inline';"
+                )
+                if not DEBUG:
+                    headers[
+                        "strict-transport-security"
+                    ] = "max-age=63072000; includeSubdomains"
+
+            await send(message)  # type: ignore
+
+        with logger.contextualize(request_id=request_id):
+            client_host, client_port = scope["client"]  # type: ignore
+            scheme = scope["scheme"]
+            server_host, server_port = scope["server"]  # type: ignore
+            request_method = scope["method"]
+            request_path = scope["path"]
+            logger.info(
+                f"{client_host}:{client_port} - "
+                f"{request_method} {scheme}://{server_host}:{server_port}{request_path}"
+            )
+            try:
+                await self.app(scope, receive, send_wrapper)  # type: ignore
+            finally:
+                elapsed_time = time.perf_counter() - start_time
+                logger.info(
+                    f"status_code={instance['http_status_code']} "
+                    f"{elapsed_time=:.2f}s"
+                )
+
+        return None
+
+
 app = FastAPI(docs_url=None, redoc_url=None)
 app.mount("/static", StaticFiles(directory="app/static"), name="static")
 app.include_router(admin.router, prefix="/admin")
 app.include_router(admin.unauthenticated_router, prefix="/admin")
 app.include_router(indieauth.router)
 app.include_router(webmentions.router)
+app.add_middleware(CustomMiddleware)
 
 logger.configure(extra={"request_id": "no_req_id"})
 logger.remove()
@@ -100,64 +201,6 @@ logger_format = (
 logger.add(sys.stdout, format=logger_format)
 
 
-@app.middleware("http")
-async def request_middleware(request, call_next):
-    start_time = time.perf_counter()
-    request_id = os.urandom(8).hex()
-    with logger.contextualize(request_id=request_id):
-        logger.info(
-            f"{request.client.host}:{request.client.port} - "
-            f"{request.method} {request.url}"
-        )
-        try:
-            response = await call_next(request)
-            response.headers["X-Request-ID"] = request_id
-            response.headers["Server"] = "microblogpub"
-            elapsed_time = time.perf_counter() - start_time
-            logger.info(f"status_code={response.status_code} {elapsed_time=:.2f}s")
-            return response
-        except Exception:
-            logger.exception("Request failed")
-            raise
-
-
-@app.middleware("http")
-async def add_security_headers(request: Request, call_next):
-    try:
-        response = await call_next(request)
-    except RuntimeError as exc:
-        # https://github.com/encode/starlette/discussions/1527#discussioncomment-2234702
-        if await request.is_disconnected() and str(exc) == "No response returned.":
-            return Response(status_code=204)
-
-    response.headers["referrer-policy"] = "no-referrer, strict-origin-when-cross-origin"
-    response.headers["x-content-type-options"] = "nosniff"
-    response.headers["x-xss-protection"] = "1; mode=block"
-    response.headers["x-frame-options"] = "SAMEORIGIN"
-    if request.url.path.startswith("/admin/login") or (
-        is_current_user_admin(request)
-        and not (
-            request.url.path.startswith("/attachments")
-            or request.url.path.startswith("/proxy")
-            or request.url.path.startswith("/static")
-        )
-    ):
-        # Prevent caching (to prevent caching CSRF tokens)
-        response.headers["Cache-Control"] = "private"
-
-    # TODO(ts): disallow inline CSS?
-    if DEBUG:
-        return response
-    response.headers["content-security-policy"] = (
-        "default-src 'self'" + " style-src 'self' 'unsafe-inline';"
-    )
-    if not DEBUG:
-        response.headers[
-            "strict-transport-security"
-        ] = "max-age=63072000; includeSubdomains"
-    return response
-
-
 class ActivityPubResponse(JSONResponse):
     media_type = "application/activity+json"
 
diff --git a/poetry.lock b/poetry.lock
index 04a5747..c7f7540 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1202,7 +1202,7 @@ dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.10"
-content-hash = "7bc5ba65a004438ac015dcd01c27e1d327dbf491f9f881a48a2a790bb0bbf710"
+content-hash = "4353bb98b40254eea5277799de3329b6658e21178a6da44113e78c897c7f140b"
 
 [metadata.files]
 aiosqlite = [
diff --git a/pyproject.toml b/pyproject.toml
index 9e284c3..6e2bd10 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ aiosqlite = "^0.17.0"
 cachetools = "^5.2.0"
 humanize = "^4.2.3"
 tabulate = "^0.8.10"
+asgiref = "^3.5.2"
 
 [tool.poetry.dev-dependencies]
 black = "^22.3.0"