Skip to content

Commit

Permalink
Fix CORS headers not set on exceptions (spec-first#1821)
Browse files Browse the repository at this point in the history
Fixes spec-first#1820.
Correct error handling in response to CORS.


Changes proposed in this pull request:

- Add a MiddlewarePosition before Exception handling so CORS is always
returned
- Add ServerError Middleware to handle unhandled errors between the
ServerError- and ExceptionMiddleware
 - Update corresponding docs

---------

Co-authored-by: Robbe Sneyders <robbe.sneyders@ml6.eu>
  • Loading branch information
nielsbox and RobbeSneyders committed Nov 30, 2023
1 parent 0857710 commit 0082d7a
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 4 deletions.
17 changes: 17 additions & 0 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from connexion.middleware.response_validation import ResponseValidationMiddleware
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.server_error import ServerErrorMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
Expand Down Expand Up @@ -93,6 +94,21 @@ def replace(self, **changes) -> "_Options":
class MiddlewarePosition(enum.Enum):
"""Positions to insert a middleware"""

BEFORE_EXCEPTION = ExceptionMiddleware
"""Add before the :class:`ExceptionMiddleware`. This is useful if you want your changes to
affect the way exceptions are handled, such as a custom error handler.
Be mindful that security has not yet been applied at this stage.
Additionally, the inserted middleware is positioned before the RoutingMiddleware, so you cannot
leverage any routing information yet and should implement your middleware to work globally
instead of on an operation level.
Useful for middleware which should also be applied to error responses. Note that errors
raised here will not be handled by the exception handlers and will always result in an
internal server error response.
:meta hide-value:
"""
BEFORE_SWAGGER = SwaggerUIMiddleware
"""Add before the :class:`SwaggerUIMiddleware`. This is useful if you want your changes to
affect the Swagger UI, such as a path altering middleware that should also alter the paths
Expand Down Expand Up @@ -165,6 +181,7 @@ class ConnexionMiddleware:
provided application."""

default_middlewares = [
ServerErrorMiddleware,
ExceptionMiddleware,
SwaggerUIMiddleware,
RoutingMiddleware,
Expand Down
36 changes: 36 additions & 0 deletions connexion/middleware/server_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import logging
import typing as t

from starlette.middleware.errors import (
ServerErrorMiddleware as StarletteServerErrorMiddleware,
)
from starlette.types import ASGIApp

from connexion.exceptions import InternalServerError
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.exceptions import connexion_wrapper
from connexion.types import MaybeAwaitable

logger = logging.getLogger(__name__)


class ServerErrorMiddleware(StarletteServerErrorMiddleware):
"""Subclass of starlette ServerErrorMiddleware to change handling of Unhandled Server
exceptions to existing connexion behavior."""

def __init__(
self,
next_app: ASGIApp,
handler: t.Optional[
t.Callable[[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]]
] = None,
):
handler = connexion_wrapper(handler) if handler else None
super().__init__(next_app, handler=handler)

@staticmethod
@connexion_wrapper
def error_response(_request: ConnexionRequest, exc: Exception) -> ConnexionResponse:
"""Default handler for any unhandled Exception"""
logger.error("%r", exc, exc_info=exc)
return InternalServerError().to_problem()
6 changes: 3 additions & 3 deletions docs/cookbook.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down Expand Up @@ -62,7 +62,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down Expand Up @@ -96,7 +96,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down
2 changes: 2 additions & 0 deletions docs/middleware.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ following order:
.. csv-table::
:widths: 30, 70

**ServerErrorMiddleware**, "Returns server errors for any exceptions not caught by the
ExceptionMiddleware"
**ExceptionMiddleware**, Handles exceptions raised by the middleware stack or application
**SwaggerUIMiddleware**, Adds a Swagger UI to your application
**RoutingMiddleware**, "Routes incoming requests to the right operation defined in the
Expand Down
22 changes: 21 additions & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging

import pytest
from connexion.middleware import MiddlewarePosition
from starlette.middleware.cors import CORSMiddleware
from starlette.types import Receive, Scope, Send

from conftest import FIXTURES_FOLDER, OPENAPI3_SPEC, build_app_from_fixture
from conftest import OPENAPI3_SPEC, build_app_from_fixture


@pytest.fixture(scope="session")
Expand All @@ -20,6 +22,24 @@ def simple_openapi_app(app_class):
)


@pytest.fixture(scope="session")
def cors_openapi_app(app_class):
app = build_app_from_fixture(
"simple",
app_class=app_class,
spec_file=OPENAPI3_SPEC,
validate_responses=True,
)

app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["http://localhost"],
)

return app


@pytest.fixture(scope="session")
def reverse_proxied_app(spec, app_class):
class ReverseProxied:
Expand Down
44 changes: 44 additions & 0 deletions tests/api/test_cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json


def test_cors_valid(cors_openapi_app):
app_client = cors_openapi_app.test_client()
origin = "http://localhost"
response = app_client.post("/v1.0/goodday/dan", data={}, headers={"Origin": origin})
assert response.status_code == 201
assert "Access-Control-Allow-Origin" in response.headers
assert origin == response.headers["Access-Control-Allow-Origin"]


def test_cors_invalid(cors_openapi_app):
app_client = cors_openapi_app.test_client()
response = app_client.options(
"/v1.0/goodday/dan",
headers={"Origin": "http://0.0.0.0", "Access-Control-Request-Method": "POST"},
)
assert response.status_code == 400
assert "Access-Control-Allow-Origin" not in response.headers


def test_cors_validation_error(cors_openapi_app):
app_client = cors_openapi_app.test_client()
origin = "http://localhost"
response = app_client.post(
"/v1.0/body-not-allowed-additional-properties",
data={},
headers={"Origin": origin},
)
assert response.status_code == 400
assert "Access-Control-Allow-Origin" in response.headers
assert origin == response.headers["Access-Control-Allow-Origin"]


def test_cors_server_error(cors_openapi_app):
app_client = cors_openapi_app.test_client()
origin = "http://localhost"
response = app_client.post(
"/v1.0/goodday/noheader", data={}, headers={"Origin": origin}
)
assert response.status_code == 500
assert "Access-Control-Allow-Origin" in response.headers
assert origin == response.headers["Access-Control-Allow-Origin"]

0 comments on commit 0082d7a

Please sign in to comment.