Merge pull request #292 from p1c2u/fix/format-checker-on-validator-scope

Format checker on validation scope
This commit is contained in:
A 2021-02-13 12:31:57 +00:00 committed by GitHub
commit 5673e8f4e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 15 deletions

View file

@ -1,7 +1,6 @@
from copy import copy
import warnings import warnings
from openapi_schema_validator import OAS30Validator, oas30_format_checker from openapi_schema_validator import OAS30Validator
from openapi_core.schema.schemas.enums import SchemaType, SchemaFormat from openapi_core.schema.schemas.enums import SchemaType, SchemaFormat
from openapi_core.schema.schemas.models import Schema from openapi_core.schema.schemas.models import Schema
@ -35,8 +34,11 @@ class SchemaUnmarshallersFactory(object):
UnmarshalContext.RESPONSE: 'read', UnmarshalContext.RESPONSE: 'read',
} }
def __init__(self, resolver=None, custom_formatters=None, context=None): def __init__(
self, resolver=None, format_checker=None,
custom_formatters=None, context=None):
self.resolver = resolver self.resolver = resolver
self.format_checker = format_checker
if custom_formatters is None: if custom_formatters is None:
custom_formatters = {} custom_formatters = {}
self.custom_formatters = custom_formatters self.custom_formatters = custom_formatters
@ -79,17 +81,10 @@ class SchemaUnmarshallersFactory(object):
return default_formatters.get(schema_format) return default_formatters.get(schema_format)
def get_validator(self, schema): def get_validator(self, schema):
format_checker = self._get_format_checker()
kwargs = { kwargs = {
'resolver': self.resolver, 'resolver': self.resolver,
'format_checker': format_checker, 'format_checker': self.format_checker,
} }
if self.context is not None: if self.context is not None:
kwargs[self.CONTEXT_VALIDATION[self.context]] = True kwargs[self.CONTEXT_VALIDATION[self.context]] = True
return OAS30Validator(schema.__dict__, **kwargs) return OAS30Validator(schema.__dict__, **kwargs)
def _get_format_checker(self):
fc = copy(oas30_format_checker)
for name, formatter in self.custom_formatters.items():
fc.checks(name)(formatter.validate)
return fc

View file

@ -1,10 +1,15 @@
"""OpenAPI core schemas util module""" """OpenAPI core schemas util module"""
from base64 import b64decode from base64 import b64decode
from copy import copy
import datetime import datetime
from distutils.util import strtobool from distutils.util import strtobool
from six import string_types, text_type, integer_types from six import string_types, text_type, integer_types
from uuid import UUID from uuid import UUID
from openapi_schema_validator import oas30_format_checker
from openapi_core.compat import lru_cache
def forcebool(val): def forcebool(val):
if isinstance(val, string_types): if isinstance(val, string_types):
@ -32,3 +37,14 @@ def format_number(value):
return value return value
return float(value) return float(value)
@lru_cache()
def build_format_checker(**custom_formatters):
if not custom_formatters:
return oas30_format_checker
fc = copy(oas30_format_checker)
for name, formatter in custom_formatters.items():
fc.checks(name)(formatter.validate)
return fc

View file

@ -1,4 +1,5 @@
"""OpenAPI core validation validators module""" """OpenAPI core validation validators module"""
from openapi_core.unmarshalling.schemas.util import build_format_checker
class BaseValidator(object): class BaseValidator(object):
@ -10,9 +11,11 @@ class BaseValidator(object):
): ):
self.spec = spec self.spec = spec
self.base_url = base_url self.base_url = base_url
self.custom_formatters = custom_formatters self.custom_formatters = custom_formatters or {}
self.custom_media_type_deserializers = custom_media_type_deserializers self.custom_media_type_deserializers = custom_media_type_deserializers
self.format_checker = build_format_checker(**self.custom_formatters)
def _find_path(self, request): def _find_path(self, request):
from openapi_core.templating.paths.finders import PathFinder from openapi_core.templating.paths.finders import PathFinder
finder = PathFinder(self.spec, base_url=self.base_url) finder = PathFinder(self.spec, base_url=self.base_url)
@ -45,8 +48,8 @@ class BaseValidator(object):
SchemaUnmarshallersFactory, SchemaUnmarshallersFactory,
) )
unmarshallers_factory = SchemaUnmarshallersFactory( unmarshallers_factory = SchemaUnmarshallersFactory(
self.spec._resolver, self.custom_formatters, self.spec._resolver, self.format_checker,
context=context, self.custom_formatters, context=context,
) )
unmarshaller = unmarshallers_factory.create( unmarshaller = unmarshallers_factory.create(
param_or_media_type.schema) param_or_media_type.schema)

View file

@ -18,12 +18,16 @@ from openapi_core.unmarshalling.schemas.factories import (
SchemaUnmarshallersFactory, SchemaUnmarshallersFactory,
) )
from openapi_core.unmarshalling.schemas.formatters import Formatter from openapi_core.unmarshalling.schemas.formatters import Formatter
from openapi_core.unmarshalling.schemas.util import build_format_checker
@pytest.fixture @pytest.fixture
def unmarshaller_factory(): def unmarshaller_factory():
def create_unmarshaller(schema, custom_formatters=None, context=None): def create_unmarshaller(schema, custom_formatters=None, context=None):
custom_formatters = custom_formatters or {}
format_checker = build_format_checker(**custom_formatters)
return SchemaUnmarshallersFactory( return SchemaUnmarshallersFactory(
format_checker=format_checker,
custom_formatters=custom_formatters, context=context).create( custom_formatters=custom_formatters, context=context).create(
schema) schema)
return create_unmarshaller return create_unmarshaller

View file

@ -12,6 +12,7 @@ from openapi_core.unmarshalling.schemas.factories import (
from openapi_core.unmarshalling.schemas.exceptions import ( from openapi_core.unmarshalling.schemas.exceptions import (
FormatterNotFoundError, InvalidSchemaValue, FormatterNotFoundError, InvalidSchemaValue,
) )
from openapi_core.unmarshalling.schemas.util import build_format_checker
from six import b, u from six import b, u
@ -21,7 +22,9 @@ class TestSchemaValidate(object):
@pytest.fixture @pytest.fixture
def validator_factory(self): def validator_factory(self):
def create_validator(schema): def create_validator(schema):
return SchemaUnmarshallersFactory().create(schema) format_checker = build_format_checker()
return SchemaUnmarshallersFactory(
format_checker=format_checker).create(schema)
return create_validator return create_validator
@pytest.mark.parametrize('schema_type', [ @pytest.mark.parametrize('schema_type', [