Merge pull request #195 from p1c2u/feature/security-validation

Security validation
This commit is contained in:
A 2020-02-04 15:54:52 +00:00 committed by GitHub
commit 90bbc558d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 345 additions and 65 deletions

View file

@ -1,6 +1,9 @@
from openapi_core.compat import lru_cache from openapi_core.compat import lru_cache
from openapi_core.schema.components.models import Components from openapi_core.schema.components.models import Components
from openapi_core.schema.schemas.generators import SchemasGenerator from openapi_core.schema.schemas.generators import SchemasGenerator
from openapi_core.schema.security_schemes.generators import (
SecuritySchemesGenerator,
)
class ComponentsFactory(object): class ComponentsFactory(object):
@ -15,15 +18,18 @@ class ComponentsFactory(object):
schemas_spec = components_deref.get('schemas', {}) schemas_spec = components_deref.get('schemas', {})
responses_spec = components_deref.get('responses', {}) responses_spec = components_deref.get('responses', {})
parameters_spec = components_deref.get('parameters', {}) parameters_spec = components_deref.get('parameters', {})
request_bodies_spec = components_deref.get('request_bodies', {}) request_bodies_spec = components_deref.get('requestBodies', {})
security_schemes_spec = components_deref.get('securitySchemes', {})
schemas = self.schemas_generator.generate(schemas_spec) schemas = self.schemas_generator.generate(schemas_spec)
responses = self._generate_response(responses_spec) responses = self._generate_response(responses_spec)
parameters = self._generate_parameters(parameters_spec) parameters = self._generate_parameters(parameters_spec)
request_bodies = self._generate_request_bodies(request_bodies_spec) request_bodies = self._generate_request_bodies(request_bodies_spec)
security_schemes = self._generate_security_schemes(
security_schemes_spec)
return Components( return Components(
schemas=list(schemas), responses=responses, parameters=parameters, schemas=list(schemas), responses=responses, parameters=parameters,
request_bodies=request_bodies, request_bodies=request_bodies, security_schemes=security_schemes,
) )
@property @property
@ -39,3 +45,7 @@ class ComponentsFactory(object):
def _generate_request_bodies(self, request_bodies_spec): def _generate_request_bodies(self, request_bodies_spec):
return request_bodies_spec return request_bodies_spec
def _generate_security_schemes(self, security_schemes_spec):
return SecuritySchemesGenerator(self.dereferencer).generate(
security_schemes_spec)

View file

@ -3,8 +3,11 @@ class Components(object):
def __init__( def __init__(
self, schemas=None, responses=None, parameters=None, self, schemas=None, responses=None, parameters=None,
request_bodies=None): request_bodies=None, security_schemes=None):
self.schemas = schemas and dict(schemas) or {} self.schemas = schemas and dict(schemas) or {}
self.responses = responses and dict(responses) or {} self.responses = responses and dict(responses) or {}
self.parameters = parameters and dict(parameters) or {} self.parameters = parameters and dict(parameters) or {}
self.request_bodies = request_bodies and dict(request_bodies) or {} self.request_bodies = request_bodies and dict(request_bodies) or {}
self.security_schemes = (
security_schemes and dict(security_schemes) or {}
)

View file

@ -11,7 +11,9 @@ from openapi_core.schema.operations.models import Operation
from openapi_core.schema.parameters.generators import ParametersGenerator from openapi_core.schema.parameters.generators import ParametersGenerator
from openapi_core.schema.request_bodies.factories import RequestBodyFactory from openapi_core.schema.request_bodies.factories import RequestBodyFactory
from openapi_core.schema.responses.generators import ResponsesGenerator from openapi_core.schema.responses.generators import ResponsesGenerator
from openapi_core.schema.security.factories import SecurityRequirementFactory from openapi_core.schema.security_requirements.generators import (
SecurityRequirementsGenerator,
)
from openapi_core.schema.servers.generators import ServersGenerator from openapi_core.schema.servers.generators import ServersGenerator
@ -39,16 +41,12 @@ class OperationsGenerator(object):
tags_list = operation_deref.get('tags', []) tags_list = operation_deref.get('tags', [])
summary = operation_deref.get('summary') summary = operation_deref.get('summary')
description = operation_deref.get('description') description = operation_deref.get('description')
security_requirements_list = operation_deref.get('security', []) security_spec = operation_deref.get('security', [])
servers_spec = operation_deref.get('servers', []) servers_spec = operation_deref.get('servers', [])
servers = self.servers_generator.generate(servers_spec) servers = self.servers_generator.generate(servers_spec)
security = self.security_requirements_generator.generate(
security = None security_spec)
if security_requirements_list:
security = list(map(
self.security_requirement_factory.create,
security_requirements_list))
external_docs = None external_docs = None
if 'externalDocs' in operation_deref: if 'externalDocs' in operation_deref:
@ -67,10 +65,10 @@ class OperationsGenerator(object):
Operation( Operation(
http_method, path_name, responses, list(parameters), http_method, path_name, responses, list(parameters),
summary=summary, description=description, summary=summary, description=description,
external_docs=external_docs, security=security, external_docs=external_docs, security=list(security),
request_body=request_body, deprecated=deprecated, request_body=request_body, deprecated=deprecated,
operation_id=operation_id, tags=list(tags_list), operation_id=operation_id, tags=list(tags_list),
servers=servers, servers=list(servers),
), ),
) )
@ -96,8 +94,8 @@ class OperationsGenerator(object):
@property @property
@lru_cache() @lru_cache()
def security_requirement_factory(self): def security_requirements_generator(self):
return SecurityRequirementFactory(self.dereferencer) return SecurityRequirementsGenerator(self.dereferencer)
@property @property
@lru_cache() @lru_cache()

View file

@ -1,14 +0,0 @@
"""OpenAPI core security factories module"""
from openapi_core.schema.security.models import SecurityRequirement
class SecurityRequirementFactory(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer
def create(self, security_requirement_spec):
name = next(iter(security_requirement_spec))
scope_names = security_requirement_spec[name]
return SecurityRequirement(name, scope_names=scope_names)

View file

@ -1,9 +0,0 @@
"""OpenAPI core security models module"""
class SecurityRequirement(object):
"""Represents an OpenAPI Security Requirement."""
def __init__(self, name, scope_names=None):
self.name = name
self.scope_names = scope_names or []

View file

@ -0,0 +1,15 @@
"""OpenAPI core security requirements generators module"""
from openapi_core.schema.security_requirements.models import (
SecurityRequirement,
)
class SecurityRequirementsGenerator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer
def generate(self, security_spec):
security_deref = self.dereferencer.dereference(security_spec)
for security_requirement_spec in security_deref:
yield SecurityRequirement(security_requirement_spec)

View file

@ -0,0 +1,6 @@
"""OpenAPI core security requirements models module"""
class SecurityRequirement(dict):
"""Represents an OpenAPI Security Requirement."""
pass

View file

@ -0,0 +1,27 @@
"""OpenAPI core security schemes enums module"""
from enum import Enum
class SecuritySchemeType(Enum):
API_KEY = 'apiKey'
HTTP = 'http'
OAUTH2 = 'oauth2'
OPEN_ID_CONNECT = 'openIdConnect'
class ApiKeyLocation(Enum):
QUERY = 'query'
HEADER = 'header'
COOKIE = 'cookie'
@classmethod
def has_value(cls, value):
return (any(value == item.value for item in cls))
class HttpAuthScheme(Enum):
BASIC = 'basic'
BEARER = 'bearer'

View file

@ -0,0 +1,37 @@
"""OpenAPI core security schemes generators module"""
import logging
from six import iteritems
from openapi_core.schema.security_schemes.models import SecurityScheme
log = logging.getLogger(__name__)
class SecuritySchemesGenerator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer
def generate(self, security_schemes_spec):
security_schemes_deref = self.dereferencer.dereference(
security_schemes_spec)
for scheme_name, scheme_spec in iteritems(security_schemes_deref):
scheme_deref = self.dereferencer.dereference(scheme_spec)
scheme_type = scheme_deref['type']
description = scheme_deref.get('description')
name = scheme_deref.get('name')
apikey_in = scheme_deref.get('in')
scheme = scheme_deref.get('scheme')
bearer_format = scheme_deref.get('bearerFormat')
flows = scheme_deref.get('flows')
open_id_connect_url = scheme_deref.get('openIdConnectUrl')
scheme = SecurityScheme(
scheme_type, description=description, name=name,
apikey_in=apikey_in, scheme=scheme,
bearer_format=bearer_format, flows=flows,
open_id_connect_url=open_id_connect_url,
)
yield scheme_name, scheme

View file

@ -0,0 +1,22 @@
"""OpenAPI core security schemes models module"""
from openapi_core.schema.security_schemes.enums import (
SecuritySchemeType, ApiKeyLocation, HttpAuthScheme,
)
class SecurityScheme(object):
"""Represents an OpenAPI Security Scheme."""
def __init__(
self, scheme_type, description=None, name=None, apikey_in=None,
scheme=None, bearer_format=None, flows=None,
open_id_connect_url=None,
):
self.type = SecuritySchemeType(scheme_type)
self.description = description
self.name = name
self.apikey_in = apikey_in and ApiKeyLocation(apikey_in)
self.scheme = scheme and HttpAuthScheme(scheme)
self.bearer_format = bearer_format
self.flows = flows
self.open_id_connect_url = open_id_connect_url

View file

@ -9,6 +9,9 @@ from openapi_core.schema.components.factories import ComponentsFactory
from openapi_core.schema.infos.factories import InfoFactory from openapi_core.schema.infos.factories import InfoFactory
from openapi_core.schema.paths.generators import PathsGenerator from openapi_core.schema.paths.generators import PathsGenerator
from openapi_core.schema.schemas.registries import SchemaRegistry from openapi_core.schema.schemas.registries import SchemaRegistry
from openapi_core.schema.security_requirements.generators import (
SecurityRequirementsGenerator,
)
from openapi_core.schema.servers.generators import ServersGenerator from openapi_core.schema.servers.generators import ServersGenerator
from openapi_core.schema.specs.models import Spec from openapi_core.schema.specs.models import Spec
@ -29,6 +32,7 @@ class SpecFactory(object):
servers_spec = spec_dict_deref.get('servers', []) servers_spec = spec_dict_deref.get('servers', [])
paths = spec_dict_deref.get('paths', {}) paths = spec_dict_deref.get('paths', {})
components_spec = spec_dict_deref.get('components', {}) components_spec = spec_dict_deref.get('components', {})
security_spec = spec_dict_deref.get('security', [])
if not servers_spec: if not servers_spec:
servers_spec = [ servers_spec = [
@ -39,8 +43,13 @@ class SpecFactory(object):
servers = self.servers_generator.generate(servers_spec) servers = self.servers_generator.generate(servers_spec)
paths = self.paths_generator.generate(paths) paths = self.paths_generator.generate(paths)
components = self.components_factory.create(components_spec) components = self.components_factory.create(components_spec)
security = self.security_requirements_generator.generate(
security_spec)
spec = Spec( spec = Spec(
info, list(paths), servers=list(servers), components=components, info, list(paths), servers=list(servers), components=components,
security=list(security),
_resolver=self.spec_resolver, _resolver=self.spec_resolver,
) )
return spec return spec
@ -74,3 +83,8 @@ class SpecFactory(object):
@lru_cache() @lru_cache()
def components_factory(self): def components_factory(self):
return ComponentsFactory(self.dereferencer, self.schemas_registry) return ComponentsFactory(self.dereferencer, self.schemas_registry)
@property
@lru_cache()
def security_requirements_generator(self):
return SecurityRequirementsGenerator(self.dereferencer)

View file

@ -15,11 +15,13 @@ class Spec(object):
"""Represents an OpenAPI Specification for a service.""" """Represents an OpenAPI Specification for a service."""
def __init__( def __init__(
self, info, paths, servers=None, components=None, _resolver=None): self, info, paths, servers=None, components=None,
security=None, _resolver=None):
self.info = info self.info = info
self.paths = paths and dict(paths) self.paths = paths and dict(paths)
self.servers = servers or [] self.servers = servers or []
self.components = components self.components = components
self.security = security
self._resolver = _resolver self._resolver = _resolver

View file

View file

@ -0,0 +1,5 @@
from openapi_core.exceptions import OpenAPIError
class SecurityError(OpenAPIError):
pass

View file

@ -0,0 +1,19 @@
from openapi_core.schema.security_schemes.enums import SecuritySchemeType
from openapi_core.security.providers import (
ApiKeyProvider, HttpProvider, UnsupportedProvider,
)
class SecurityProviderFactory(object):
PROVIDERS = {
SecuritySchemeType.API_KEY: ApiKeyProvider,
SecuritySchemeType.HTTP: HttpProvider,
}
def create(self, scheme):
if scheme.type == SecuritySchemeType.API_KEY:
return ApiKeyProvider(scheme)
elif scheme.type == SecuritySchemeType.HTTP:
return HttpProvider(scheme)
return UnsupportedProvider(scheme)

View file

@ -0,0 +1,47 @@
import base64
import binascii
import warnings
from openapi_core.security.exceptions import SecurityError
class BaseProvider(object):
def __init__(self, scheme):
self.scheme = scheme
class UnsupportedProvider(BaseProvider):
def __call__(self, request):
warnings.warn("Unsupported scheme type")
class ApiKeyProvider(BaseProvider):
def __call__(self, request):
source = getattr(request.parameters, self.scheme.apikey_in.value)
if self.scheme.name not in source:
raise SecurityError("Missing api key parameter.")
return source.get(self.scheme.name)
class HttpProvider(BaseProvider):
def __call__(self, request):
if 'Authorization' not in request.parameters.header:
raise SecurityError('Missing authorization header.')
auth_header = request.parameters.header['Authorization']
try:
auth_type, encoded_credentials = auth_header.split(' ', 1)
except ValueError:
raise SecurityError('Could not parse authorization header.')
if auth_type.lower() != self.scheme.scheme.value:
raise SecurityError(
'Unknown authorization method %s' % auth_type)
try:
return base64.b64decode(
encoded_credentials.encode('ascii')).decode('latin1')
except binascii.Error:
raise SecurityError('Invalid base64 encoding.')

View file

@ -0,0 +1,15 @@
"""OpenAPI core validation exceptions module"""
import attr
from openapi_core.exceptions import OpenAPIError
class ValidationError(OpenAPIError):
pass
@attr.s(hash=True)
class InvalidSecurity(ValidationError):
def __str__(self):
return "Security not valid for any requirement"

View file

@ -71,3 +71,4 @@ class OpenAPIRequest(object):
class RequestValidationResult(BaseValidationResult): class RequestValidationResult(BaseValidationResult):
body = attr.ib(default=None) body = attr.ib(default=None)
parameters = attr.ib(factory=RequestParameters) parameters = attr.ib(factory=RequestParameters)
security = attr.ib(default=None)

View file

@ -12,9 +12,11 @@ from openapi_core.schema.parameters.exceptions import (
from openapi_core.schema.paths.exceptions import InvalidPath from openapi_core.schema.paths.exceptions import InvalidPath
from openapi_core.schema.request_bodies.exceptions import MissingRequestBody from openapi_core.schema.request_bodies.exceptions import MissingRequestBody
from openapi_core.schema.servers.exceptions import InvalidServer from openapi_core.schema.servers.exceptions import InvalidServer
from openapi_core.security.exceptions import SecurityError
from openapi_core.unmarshalling.schemas.exceptions import ( from openapi_core.unmarshalling.schemas.exceptions import (
UnmarshalError, ValidateError, UnmarshalError, ValidateError,
) )
from openapi_core.validation.exceptions import InvalidSecurity
from openapi_core.validation.request.datatypes import ( from openapi_core.validation.request.datatypes import (
RequestParameters, RequestValidationResult, RequestParameters, RequestValidationResult,
) )
@ -37,7 +39,12 @@ class RequestValidator(object):
operation = self._get_operation(request) operation = self._get_operation(request)
# don't process if operation errors # don't process if operation errors
except (InvalidServer, InvalidPath, InvalidOperation) as exc: except (InvalidServer, InvalidPath, InvalidOperation) as exc:
return RequestValidationResult([exc, ], None, None) return RequestValidationResult([exc, ], None, None, None)
try:
security = self._get_security(request, operation)
except InvalidSecurity as exc:
return RequestValidationResult([exc, ], None, None, None)
params, params_errors = self._get_parameters( params, params_errors = self._get_parameters(
request, chain( request, chain(
@ -49,7 +56,7 @@ class RequestValidator(object):
body, body_errors = self._get_body(request, operation) body, body_errors = self._get_body(request, operation)
errors = params_errors + body_errors errors = params_errors + body_errors
return RequestValidationResult(errors, body, params) return RequestValidationResult(errors, body, params, security)
def _validate_parameters(self, request): def _validate_parameters(self, request):
try: try:
@ -64,7 +71,7 @@ class RequestValidator(object):
iteritems(path.parameters) iteritems(path.parameters)
) )
) )
return RequestValidationResult(params_errors, None, params) return RequestValidationResult(params_errors, None, params, None)
def _validate_body(self, request): def _validate_body(self, request):
try: try:
@ -73,7 +80,7 @@ class RequestValidator(object):
return RequestValidationResult([exc, ], None, None) return RequestValidationResult([exc, ], None, None)
body, body_errors = self._get_body(request, operation) body, body_errors = self._get_body(request, operation)
return RequestValidationResult(body_errors, body, None) return RequestValidationResult(body_errors, body, None, None)
def _get_operation_pattern(self, request): def _get_operation_pattern(self, request):
server = self.spec.get_server(request.full_url_pattern) server = self.spec.get_server(request.full_url_pattern)
@ -92,6 +99,23 @@ class RequestValidator(object):
return self.spec.get_operation(operation_pattern, request.method) return self.spec.get_operation(operation_pattern, request.method)
def _get_security(self, request, operation):
security = operation.security or self.spec.security
if not security:
return {}
for security_requirement in security:
try:
return {
scheme_name: self._get_security_value(
scheme_name, request)
for scheme_name in security_requirement
}
except SecurityError:
continue
raise InvalidSecurity()
def _get_parameters(self, request, params): def _get_parameters(self, request, params):
errors = [] errors = []
seen = set() seen = set()
@ -166,6 +190,16 @@ class RequestValidator(object):
return body, [] return body, []
def _get_security_value(self, scheme_name, request):
scheme = self.spec.components.security_schemes.get(scheme_name)
if not scheme:
return
from openapi_core.security.factories import SecurityProviderFactory
security_provider_factory = SecurityProviderFactory()
security_provider = security_provider_factory.create(scheme)
return security_provider(request)
def _get_parameter_value(self, param, request): def _get_parameter_value(self, param, request):
location = request.parameters[param.location.value] location = request.parameters[param.location.value]

View file

@ -11,6 +11,9 @@ info:
license: license:
name: MIT name: MIT
url: https://opensource.org/licenses/MIT url: https://opensource.org/licenses/MIT
security:
- api_key: []
- {}
servers: servers:
- url: http://petstore.swagger.io/{version} - url: http://petstore.swagger.io/{version}
variables: variables:
@ -73,10 +76,6 @@ paths:
externalDocs: externalDocs:
url: https://example.com url: https://example.com
description: Find more info here description: Find more info here
security:
- petstore_auth:
- write:pets
- read:pets
servers: servers:
- url: https://development.gigantic-server.com/v1 - url: https://development.gigantic-server.com/v1
description: Development server description: Development server
@ -126,6 +125,10 @@ paths:
schema: schema:
type: integer type: integer
format: int64 format: int64
security:
- petstore_auth:
- write:pets
- read:pets
responses: responses:
'200': '200':
description: Expected response to a valid request description: Expected response to a valid request
@ -363,3 +366,11 @@ components:
application/json: application/json:
schema: schema:
$ref: "#/components/schemas/PetsData" $ref: "#/components/schemas/PetsData"
securitySchemes:
api_key:
type: apiKey
name: api_key
in: query
petstore_auth:
type: http
scheme: basic

View file

@ -9,7 +9,9 @@ from openapi_core.schema.paths.models import Path
from openapi_core.schema.request_bodies.models import RequestBody from openapi_core.schema.request_bodies.models import RequestBody
from openapi_core.schema.responses.models import Response from openapi_core.schema.responses.models import Response
from openapi_core.schema.schemas.models import Schema from openapi_core.schema.schemas.models import Schema
from openapi_core.schema.security.models import SecurityRequirement from openapi_core.schema.security_requirements.models import (
SecurityRequirement,
)
from openapi_core.schema.servers.models import Server, ServerVariable from openapi_core.schema.servers.models import Server, ServerVariable
from openapi_core.shortcuts import create_spec from openapi_core.shortcuts import create_spec
from openapi_core.validation.request.validators import RequestValidator from openapi_core.validation.request.validators import RequestValidator
@ -64,6 +66,14 @@ class TestPetstore(object):
assert spec.info.license.name == license_spec['name'] assert spec.info.license.name == license_spec['name']
assert spec.info.license.url == license_spec['url'] assert spec.info.license.url == license_spec['url']
security_spec = spec_dict.get('security', [])
for idx, security_req in enumerate(spec.security):
assert type(security_req) == SecurityRequirement
security_req_spec = security_spec[idx]
for scheme_name in security_req:
security_req[scheme_name] == security_req_spec[scheme_name]
assert spec.get_server_url() == url assert spec.get_server_url() == url
for idx, server in enumerate(spec.servers): for idx, server in enumerate(spec.servers):
@ -104,15 +114,6 @@ class TestPetstore(object):
assert ext_docs.description == ext_docs_spec.get( assert ext_docs.description == ext_docs_spec.get(
'description') 'description')
security_spec = operation_spec.get('security')
if security_spec:
for idx, sec_req in enumerate(operation.security):
assert type(sec_req) == SecurityRequirement
sec_req_spec = security_spec[idx]
sec_req_nam = next(iter(sec_req_spec))
assert sec_req.name == sec_req_nam
assert sec_req.scope_names == sec_req_spec[sec_req_nam]
servers_spec = operation_spec.get('servers', []) servers_spec = operation_spec.get('servers', [])
for idx, server in enumerate(operation.servers): for idx, server in enumerate(operation.servers):
assert type(server) == Server assert type(server) == Server
@ -130,6 +131,15 @@ class TestPetstore(object):
assert variable.default == variable_spec['default'] assert variable.default == variable_spec['default']
assert variable.enum == variable_spec.get('enum') assert variable.enum == variable_spec.get('enum')
security_spec = operation_spec.get('security', [])
for idx, security_req in enumerate(operation.security):
assert type(security_req) == SecurityRequirement
security_req_spec = security_spec[idx]
for scheme_name in security_req:
security_req[scheme_name] == security_req_spec[
scheme_name]
responses_spec = operation_spec.get('responses') responses_spec = operation_spec.get('responses')
for http_status, response in iteritems(operation.responses): for http_status, response in iteritems(operation.responses):

View file

@ -20,6 +20,7 @@ from openapi_core.schema.servers.exceptions import InvalidServer
from openapi_core.shortcuts import create_spec from openapi_core.shortcuts import create_spec
from openapi_core.testing import MockRequest, MockResponse from openapi_core.testing import MockRequest, MockResponse
from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue
from openapi_core.validation.exceptions import InvalidSecurity
from openapi_core.validation.request.datatypes import RequestParameters from openapi_core.validation.request.datatypes import RequestParameters
from openapi_core.validation.request.validators import RequestValidator from openapi_core.validation.request.validators import RequestValidator
from openapi_core.validation.response.validators import ResponseValidator from openapi_core.validation.response.validators import ResponseValidator
@ -37,15 +38,15 @@ class TestRequestValidator(object):
api_key_bytes_enc = b64encode(api_key_bytes) api_key_bytes_enc = b64encode(api_key_bytes)
return text_type(api_key_bytes_enc, 'utf8') return text_type(api_key_bytes_enc, 'utf8')
@pytest.fixture @pytest.fixture(scope='session')
def spec_dict(self, factory): def spec_dict(self, factory):
return factory.spec_from_file("data/v3.0/petstore.yaml") return factory.spec_from_file("data/v3.0/petstore.yaml")
@pytest.fixture @pytest.fixture(scope='session')
def spec(self, spec_dict): def spec(self, spec_dict):
return create_spec(spec_dict) return create_spec(spec_dict)
@pytest.fixture @pytest.fixture(scope='session')
def validator(self, spec): def validator(self, spec):
return RequestValidator(spec) return RequestValidator(spec)
@ -94,9 +95,10 @@ class TestRequestValidator(object):
) )
def test_get_pets(self, validator): def test_get_pets(self, validator):
args = {'limit': '10', 'ids': ['1', '2'], 'api_key': self.api_key}
request = MockRequest( request = MockRequest(
self.host_url, 'get', '/v1/pets', self.host_url, 'get', '/v1/pets',
path_pattern='/v1/pets', args={'limit': '10', 'ids': ['1', '2']}, path_pattern='/v1/pets', args=args,
) )
result = validator.validate(request) result = validator.validate(request)
@ -111,6 +113,9 @@ class TestRequestValidator(object):
'ids': [1, 2], 'ids': [1, 2],
}, },
) )
assert result.security == {
'api_key': self.api_key,
}
def test_get_pets_webob(self, validator): def test_get_pets_webob(self, validator):
from webob.multidict import GetDict from webob.multidict import GetDict
@ -231,6 +236,7 @@ class TestRequestValidator(object):
'user': 123, 'user': 123,
}, },
) )
assert result.security == {}
schemas = spec_dict['components']['schemas'] schemas = spec_dict['components']['schemas']
pet_model = schemas['PetCreate']['x-model'] pet_model = schemas['PetCreate']['x-model']
@ -243,7 +249,7 @@ class TestRequestValidator(object):
assert result.body.address.street == pet_street assert result.body.address.street == pet_street
assert result.body.address.city == pet_city assert result.body.address.city == pet_city
def test_get_pet(self, validator): def test_get_pet_unauthorized(self, validator):
request = MockRequest( request = MockRequest(
self.host_url, 'get', '/v1/pets/1', self.host_url, 'get', '/v1/pets/1',
path_pattern='/v1/pets/{petId}', view_args={'petId': '1'}, path_pattern='/v1/pets/{petId}', view_args={'petId': '1'},
@ -251,6 +257,24 @@ class TestRequestValidator(object):
result = validator.validate(request) result = validator.validate(request)
assert result.errors == [InvalidSecurity(), ]
assert result.body is None
assert result.parameters is None
assert result.security is None
def test_get_pet(self, validator):
authorization = 'Basic ' + self.api_key_encoded
headers = {
'Authorization': authorization,
}
request = MockRequest(
self.host_url, 'get', '/v1/pets/1',
path_pattern='/v1/pets/{petId}', view_args={'petId': '1'},
headers=headers,
)
result = validator.validate(request)
assert result.errors == [] assert result.errors == []
assert result.body is None assert result.body is None
assert result.parameters == RequestParameters( assert result.parameters == RequestParameters(
@ -258,11 +282,14 @@ class TestRequestValidator(object):
'petId': 1, 'petId': 1,
}, },
) )
assert result.security == {
'petstore_auth': self.api_key,
}
class TestPathItemParamsValidator(object): class TestPathItemParamsValidator(object):
@pytest.fixture @pytest.fixture(scope='session')
def spec_dict(self): def spec_dict(self):
return { return {
"openapi": "3.0.0", "openapi": "3.0.0",
@ -293,11 +320,11 @@ class TestPathItemParamsValidator(object):
} }
} }
@pytest.fixture @pytest.fixture(scope='session')
def spec(self, spec_dict): def spec(self, spec_dict):
return create_spec(spec_dict) return create_spec(spec_dict)
@pytest.fixture @pytest.fixture(scope='session')
def validator(self, spec): def validator(self, spec):
return RequestValidator(spec) return RequestValidator(spec)