mirror of
https://github.com/correl/openapi-core.git
synced 2024-12-28 03:00:11 +00:00
Security providers and security models retouch
This commit is contained in:
parent
4ae5a085a3
commit
fde1b6bdc5
11 changed files with 131 additions and 64 deletions
|
@ -12,7 +12,4 @@ class SecurityRequirementsGenerator(object):
|
|||
def generate(self, security_spec):
|
||||
security_deref = self.dereferencer.dereference(security_spec)
|
||||
for security_requirement_spec in security_deref:
|
||||
name = next(iter(security_requirement_spec))
|
||||
scope_names = security_requirement_spec[name]
|
||||
|
||||
yield SecurityRequirement(name, scope_names=scope_names)
|
||||
yield SecurityRequirement(security_requirement_spec)
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
"""OpenAPI core security requirements models module"""
|
||||
|
||||
|
||||
class SecurityRequirement(object):
|
||||
class SecurityRequirement(dict):
|
||||
"""Represents an OpenAPI Security Requirement."""
|
||||
|
||||
def __init__(self, name, scope_names=None):
|
||||
self.name = name
|
||||
self.scope_names = scope_names or []
|
||||
pass
|
||||
|
|
0
openapi_core/security/__init__.py
Normal file
0
openapi_core/security/__init__.py
Normal file
5
openapi_core/security/exceptions.py
Normal file
5
openapi_core/security/exceptions.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
from openapi_core.exceptions import OpenAPIError
|
||||
|
||||
|
||||
class SecurityError(OpenAPIError):
|
||||
pass
|
19
openapi_core/security/factories.py
Normal file
19
openapi_core/security/factories.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
from openapi_core.schema.security_schemes.enums import SecuritySchemeType
|
||||
from openapi_core.security.providers import (
|
||||
ApiKeyProvider, HttpProvider, UnsupportedProvider,
|
||||
)
|
||||
|
||||
|
||||
class SecurityProviderFactory(object):
|
||||
|
||||
PROVIDERS = {
|
||||
SecuritySchemeType.API_KEY: ApiKeyProvider,
|
||||
SecuritySchemeType.HTTP: HttpProvider,
|
||||
}
|
||||
|
||||
def create(self, scheme):
|
||||
if scheme.type == SecuritySchemeType.API_KEY:
|
||||
return ApiKeyProvider(scheme)
|
||||
elif scheme.type == SecuritySchemeType.HTTP:
|
||||
return HttpProvider(scheme)
|
||||
return UnsupportedProvider(scheme)
|
47
openapi_core/security/providers.py
Normal file
47
openapi_core/security/providers.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
import base64
|
||||
import binascii
|
||||
import warnings
|
||||
|
||||
from openapi_core.security.exceptions import SecurityError
|
||||
|
||||
|
||||
class BaseProvider(object):
|
||||
|
||||
def __init__(self, scheme):
|
||||
self.scheme = scheme
|
||||
|
||||
|
||||
class UnsupportedProvider(BaseProvider):
|
||||
|
||||
def __call__(self, request):
|
||||
warnings.warn("Unsupported scheme type")
|
||||
|
||||
|
||||
class ApiKeyProvider(BaseProvider):
|
||||
|
||||
def __call__(self, request):
|
||||
source = getattr(request.parameters, self.scheme.apikey_in.value)
|
||||
if self.scheme.name not in source:
|
||||
raise SecurityError("Missing api key parameter.")
|
||||
return source.get(self.scheme.name)
|
||||
|
||||
|
||||
class HttpProvider(BaseProvider):
|
||||
|
||||
def __call__(self, request):
|
||||
if 'Authorization' not in request.parameters.header:
|
||||
raise SecurityError('Missing authorization header.')
|
||||
auth_header = request.parameters.header['Authorization']
|
||||
try:
|
||||
auth_type, encoded_credentials = auth_header.split(' ', 1)
|
||||
except ValueError:
|
||||
raise SecurityError('Could not parse authorization header.')
|
||||
|
||||
if auth_type.lower() != self.scheme.scheme.value:
|
||||
raise SecurityError(
|
||||
'Unknown authorization method %s' % auth_type)
|
||||
try:
|
||||
return base64.b64decode(
|
||||
encoded_credentials.encode('ascii')).decode('latin1')
|
||||
except binascii.Error:
|
||||
raise SecurityError('Invalid base64 encoding.')
|
15
openapi_core/validation/exceptions.py
Normal file
15
openapi_core/validation/exceptions.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
"""OpenAPI core validation exceptions module"""
|
||||
import attr
|
||||
|
||||
from openapi_core.exceptions import OpenAPIError
|
||||
|
||||
|
||||
class ValidationError(OpenAPIError):
|
||||
pass
|
||||
|
||||
|
||||
@attr.s(hash=True)
|
||||
class InvalidSecurity(ValidationError):
|
||||
|
||||
def __str__(self):
|
||||
return "Security not valid for any requirement"
|
|
@ -1,9 +1,6 @@
|
|||
"""OpenAPI core validation request validators module"""
|
||||
import base64
|
||||
import binascii
|
||||
from itertools import chain
|
||||
from six import iteritems
|
||||
import warnings
|
||||
|
||||
from openapi_core.casting.schemas.exceptions import CastError
|
||||
from openapi_core.deserializing.exceptions import DeserializeError
|
||||
|
@ -14,11 +11,12 @@ from openapi_core.schema.parameters.exceptions import (
|
|||
)
|
||||
from openapi_core.schema.paths.exceptions import InvalidPath
|
||||
from openapi_core.schema.request_bodies.exceptions import MissingRequestBody
|
||||
from openapi_core.schema.security_schemes.enums import SecuritySchemeType
|
||||
from openapi_core.schema.servers.exceptions import InvalidServer
|
||||
from openapi_core.security.exceptions import SecurityError
|
||||
from openapi_core.unmarshalling.schemas.exceptions import (
|
||||
UnmarshalError, ValidateError,
|
||||
)
|
||||
from openapi_core.validation.exceptions import InvalidSecurity
|
||||
from openapi_core.validation.request.datatypes import (
|
||||
RequestParameters, RequestValidationResult,
|
||||
)
|
||||
|
@ -45,8 +43,7 @@ class RequestValidator(object):
|
|||
|
||||
try:
|
||||
security = self._get_security(request, operation)
|
||||
# TODO narrow exceptions
|
||||
except Exception as exc:
|
||||
except InvalidSecurity as exc:
|
||||
return RequestValidationResult([exc, ], None, None, None)
|
||||
|
||||
params, params_errors = self._get_parameters(
|
||||
|
@ -108,14 +105,16 @@ class RequestValidator(object):
|
|||
return {}
|
||||
|
||||
for security_requirement in security:
|
||||
data = {
|
||||
security_requirement.name: self._get_security_value(
|
||||
security_requirement.name, request)
|
||||
}
|
||||
if all(value for value in data.values()):
|
||||
return data
|
||||
try:
|
||||
return {
|
||||
scheme_name: self._get_security_value(
|
||||
scheme_name, request)
|
||||
for scheme_name in security_requirement
|
||||
}
|
||||
except SecurityError:
|
||||
continue
|
||||
|
||||
return {}
|
||||
raise InvalidSecurity()
|
||||
|
||||
def _get_parameters(self, request, params):
|
||||
errors = []
|
||||
|
@ -196,27 +195,10 @@ class RequestValidator(object):
|
|||
if not scheme:
|
||||
return
|
||||
|
||||
if scheme.type == SecuritySchemeType.API_KEY:
|
||||
source = getattr(request.parameters, scheme.apikey_in.value)
|
||||
return source.get(scheme.name)
|
||||
elif scheme.type == SecuritySchemeType.HTTP:
|
||||
auth_header = request.parameters.header.get('Authorization')
|
||||
try:
|
||||
auth_type, encoded_credentials = auth_header.split(' ', 1)
|
||||
except ValueError:
|
||||
raise ValueError('Could not parse authorization header.')
|
||||
|
||||
if auth_type.lower() != scheme.scheme.value:
|
||||
raise ValueError(
|
||||
'Unknown authorization method %s' % auth_type)
|
||||
try:
|
||||
return base64.b64decode(
|
||||
encoded_credentials.encode('ascii'), validate=True
|
||||
).decode('latin1')
|
||||
except binascii.Error:
|
||||
raise ValueError('Invalid base64 encoding.')
|
||||
|
||||
warnings.warn("Only api key security scheme type supported")
|
||||
from openapi_core.security.factories import SecurityProviderFactory
|
||||
security_provider_factory = SecurityProviderFactory()
|
||||
security_provider = security_provider_factory.create(scheme)
|
||||
return security_provider(request)
|
||||
|
||||
def _get_parameter_value(self, param, request):
|
||||
location = request.parameters[param.location.value]
|
||||
|
|
|
@ -13,6 +13,7 @@ info:
|
|||
url: https://opensource.org/licenses/MIT
|
||||
security:
|
||||
- api_key: []
|
||||
- {}
|
||||
servers:
|
||||
- url: http://petstore.swagger.io/{version}
|
||||
variables:
|
||||
|
|
|
@ -71,9 +71,8 @@ class TestPetstore(object):
|
|||
assert type(security_req) == SecurityRequirement
|
||||
|
||||
security_req_spec = security_spec[idx]
|
||||
name = next(iter(security_req_spec))
|
||||
assert security_req.name == name
|
||||
assert security_req.scope_names == security_req_spec[name]
|
||||
for scheme_name in security_req:
|
||||
security_req[scheme_name] == security_req_spec[scheme_name]
|
||||
|
||||
assert spec.get_server_url() == url
|
||||
|
||||
|
@ -115,15 +114,6 @@ class TestPetstore(object):
|
|||
assert ext_docs.description == ext_docs_spec.get(
|
||||
'description')
|
||||
|
||||
security_spec = operation_spec.get('security')
|
||||
if security_spec:
|
||||
for idx, sec_req in enumerate(operation.security):
|
||||
assert type(sec_req) == SecurityRequirement
|
||||
sec_req_spec = security_spec[idx]
|
||||
sec_req_nam = next(iter(sec_req_spec))
|
||||
assert sec_req.name == sec_req_nam
|
||||
assert sec_req.scope_names == sec_req_spec[sec_req_nam]
|
||||
|
||||
servers_spec = operation_spec.get('servers', [])
|
||||
for idx, server in enumerate(operation.servers):
|
||||
assert type(server) == Server
|
||||
|
@ -146,9 +136,9 @@ class TestPetstore(object):
|
|||
assert type(security_req) == SecurityRequirement
|
||||
|
||||
security_req_spec = security_spec[idx]
|
||||
name = next(iter(security_req_spec))
|
||||
assert security_req.name == name
|
||||
assert security_req.scope_names == security_req_spec[name]
|
||||
for scheme_name in security_req:
|
||||
security_req[scheme_name] == security_req_spec[
|
||||
scheme_name]
|
||||
|
||||
responses_spec = operation_spec.get('responses')
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from openapi_core.schema.servers.exceptions import InvalidServer
|
|||
from openapi_core.shortcuts import create_spec
|
||||
from openapi_core.testing import MockRequest, MockResponse
|
||||
from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue
|
||||
from openapi_core.validation.exceptions import InvalidSecurity
|
||||
from openapi_core.validation.request.datatypes import RequestParameters
|
||||
from openapi_core.validation.request.validators import RequestValidator
|
||||
from openapi_core.validation.response.validators import ResponseValidator
|
||||
|
@ -37,15 +38,15 @@ class TestRequestValidator(object):
|
|||
api_key_bytes_enc = b64encode(api_key_bytes)
|
||||
return text_type(api_key_bytes_enc, 'utf8')
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope='session')
|
||||
def spec_dict(self, factory):
|
||||
return factory.spec_from_file("data/v3.0/petstore.yaml")
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope='session')
|
||||
def spec(self, spec_dict):
|
||||
return create_spec(spec_dict)
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope='session')
|
||||
def validator(self, spec):
|
||||
return RequestValidator(spec)
|
||||
|
||||
|
@ -248,6 +249,19 @@ class TestRequestValidator(object):
|
|||
assert result.body.address.street == pet_street
|
||||
assert result.body.address.city == pet_city
|
||||
|
||||
def test_get_pet_unauthorized(self, validator):
|
||||
request = MockRequest(
|
||||
self.host_url, 'get', '/v1/pets/1',
|
||||
path_pattern='/v1/pets/{petId}', view_args={'petId': '1'},
|
||||
)
|
||||
|
||||
result = validator.validate(request)
|
||||
|
||||
assert result.errors == [InvalidSecurity(), ]
|
||||
assert result.body is None
|
||||
assert result.parameters is None
|
||||
assert result.security is None
|
||||
|
||||
def test_get_pet(self, validator):
|
||||
authorization = 'Basic ' + self.api_key_encoded
|
||||
headers = {
|
||||
|
@ -275,7 +289,7 @@ class TestRequestValidator(object):
|
|||
|
||||
class TestPathItemParamsValidator(object):
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope='session')
|
||||
def spec_dict(self):
|
||||
return {
|
||||
"openapi": "3.0.0",
|
||||
|
@ -306,11 +320,11 @@ class TestPathItemParamsValidator(object):
|
|||
}
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope='session')
|
||||
def spec(self, spec_dict):
|
||||
return create_spec(spec_dict)
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope='session')
|
||||
def validator(self, spec):
|
||||
return RequestValidator(spec)
|
||||
|
||||
|
|
Loading…
Reference in a new issue