diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index 660affb..60f90dc 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -1,7 +1,6 @@ -from copy import copy import warnings -from openapi_schema_validator import OAS30Validator, oas30_format_checker +from openapi_schema_validator import OAS30Validator from openapi_core.schema.schemas.enums import SchemaType, SchemaFormat from openapi_core.schema.schemas.models import Schema @@ -35,8 +34,11 @@ class SchemaUnmarshallersFactory(object): UnmarshalContext.RESPONSE: 'read', } - def __init__(self, resolver=None, custom_formatters=None, context=None): + def __init__( + self, resolver=None, format_checker=None, + custom_formatters=None, context=None): self.resolver = resolver + self.format_checker = format_checker if custom_formatters is None: custom_formatters = {} self.custom_formatters = custom_formatters @@ -79,17 +81,10 @@ class SchemaUnmarshallersFactory(object): return default_formatters.get(schema_format) def get_validator(self, schema): - format_checker = self._get_format_checker() kwargs = { 'resolver': self.resolver, - 'format_checker': format_checker, + 'format_checker': self.format_checker, } if self.context is not None: kwargs[self.CONTEXT_VALIDATION[self.context]] = True return OAS30Validator(schema.__dict__, **kwargs) - - def _get_format_checker(self): - fc = copy(oas30_format_checker) - for name, formatter in self.custom_formatters.items(): - fc.checks(name)(formatter.validate) - return fc diff --git a/openapi_core/unmarshalling/schemas/util.py b/openapi_core/unmarshalling/schemas/util.py index d5ac76c..66654ca 100644 --- a/openapi_core/unmarshalling/schemas/util.py +++ b/openapi_core/unmarshalling/schemas/util.py @@ -1,10 +1,15 @@ """OpenAPI core schemas util module""" from base64 import b64decode +from copy import copy import datetime from distutils.util import strtobool from six import string_types, text_type, integer_types from uuid import UUID +from openapi_schema_validator import oas30_format_checker + +from openapi_core.compat import lru_cache + def forcebool(val): if isinstance(val, string_types): @@ -32,3 +37,14 @@ def format_number(value): return value return float(value) + + +@lru_cache() +def build_format_checker(**custom_formatters): + if not custom_formatters: + return oas30_format_checker + + fc = copy(oas30_format_checker) + for name, formatter in custom_formatters.items(): + fc.checks(name)(formatter.validate) + return fc diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index 271209c..4d3639c 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -1,4 +1,5 @@ """OpenAPI core validation validators module""" +from openapi_core.unmarshalling.schemas.util import build_format_checker class BaseValidator(object): @@ -10,9 +11,11 @@ class BaseValidator(object): ): self.spec = spec self.base_url = base_url - self.custom_formatters = custom_formatters + self.custom_formatters = custom_formatters or {} self.custom_media_type_deserializers = custom_media_type_deserializers + self.format_checker = build_format_checker(**self.custom_formatters) + def _find_path(self, request): from openapi_core.templating.paths.finders import PathFinder finder = PathFinder(self.spec, base_url=self.base_url) @@ -45,8 +48,8 @@ class BaseValidator(object): SchemaUnmarshallersFactory, ) unmarshallers_factory = SchemaUnmarshallersFactory( - self.spec._resolver, self.custom_formatters, - context=context, + self.spec._resolver, self.format_checker, + self.custom_formatters, context=context, ) unmarshaller = unmarshallers_factory.create( param_or_media_type.schema) diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index 72033fd..1c7e145 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -18,12 +18,16 @@ from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) from openapi_core.unmarshalling.schemas.formatters import Formatter +from openapi_core.unmarshalling.schemas.util import build_format_checker @pytest.fixture def unmarshaller_factory(): def create_unmarshaller(schema, custom_formatters=None, context=None): + custom_formatters = custom_formatters or {} + format_checker = build_format_checker(**custom_formatters) return SchemaUnmarshallersFactory( + format_checker=format_checker, custom_formatters=custom_formatters, context=context).create( schema) return create_unmarshaller diff --git a/tests/unit/unmarshalling/test_validate.py b/tests/unit/unmarshalling/test_validate.py index fdb5d95..6c91ce3 100644 --- a/tests/unit/unmarshalling/test_validate.py +++ b/tests/unit/unmarshalling/test_validate.py @@ -12,6 +12,7 @@ from openapi_core.unmarshalling.schemas.factories import ( from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, InvalidSchemaValue, ) +from openapi_core.unmarshalling.schemas.util import build_format_checker from six import b, u @@ -21,7 +22,9 @@ class TestSchemaValidate(object): @pytest.fixture def validator_factory(self): def create_validator(schema): - return SchemaUnmarshallersFactory().create(schema) + format_checker = build_format_checker() + return SchemaUnmarshallersFactory( + format_checker=format_checker).create(schema) return create_validator @pytest.mark.parametrize('schema_type', [