diff --git a/openapi_core/components.py b/openapi_core/components.py new file mode 100644 index 0000000..54a5311 --- /dev/null +++ b/openapi_core/components.py @@ -0,0 +1,48 @@ +from openapi_core.schemas import SchemasGenerator + + +class Components(object): + """Represents an OpenAPI Components in a service.""" + + def __init__( + self, schemas=None, responses=None, parameters=None, + request_bodies=None): + self.schemas = schemas and dict(schemas) or {} + self.responses = responses and dict(responses) or {} + self.parameters = parameters and dict(parameters) or {} + self.request_bodies = request_bodies and dict(request_bodies) or {} + + +class ComponentsFactory(object): + + def __init__(self, dereferencer): + self.dereferencer = dereferencer + + def create(self, components_spec): + components_deref = self.dereferencer.dereference(components_spec) + + schemas_spec = components_deref.get('schemas', []) + responses_spec = components_deref.get('responses', []) + parameters_spec = components_deref.get('parameters', []) + request_bodies_spec = components_deref.get('request_bodies', []) + + schemas = self._generate_schemas(schemas_spec) + responses = self._generate_response(responses_spec) + parameters = self._generate_parameters(parameters_spec) + request_bodies = self._generate_request_bodies(request_bodies_spec) + return Components( + schemas=list(schemas), responses=responses, parameters=parameters, + request_bodies=request_bodies, + ) + + def _generate_schemas(self, schemas_spec): + return SchemasGenerator(self.dereferencer).generate(schemas_spec) + + def _generate_response(self, responses_spec): + return responses_spec + + def _generate_parameters(self, parameters_spec): + return parameters_spec + + def _generate_request_bodies(self, request_bodies_spec): + return request_bodies_spec diff --git a/openapi_core/infos.py b/openapi_core/infos.py new file mode 100644 index 0000000..e92eacd --- /dev/null +++ b/openapi_core/infos.py @@ -0,0 +1,17 @@ +class Info(object): + + def __init__(self, title, version): + self.title = title + self.version = version + + +class InfoFactory(object): + + def __init__(self, dereferencer): + self.dereferencer = dereferencer + + def create(self, info_spec): + info_deref = self.dereferencer.dereference(info_spec) + title = info_deref['title'] + version = info_deref['version'] + return Info(title, version) diff --git a/openapi_core/media_types.py b/openapi_core/media_types.py index 3689078..f4c8295 100644 --- a/openapi_core/media_types.py +++ b/openapi_core/media_types.py @@ -1,7 +1,7 @@ """OpenAPI core mediaTypes module""" from six import iteritems -from openapi_core.schemas import SchemaFactory +from openapi_core.schemas import SchemaRegistry class MediaType(object): @@ -34,4 +34,6 @@ class MediaTypeGenerator(object): yield content_type, MediaType(content_type, schema) def _create_schema(self, schema_spec): - return SchemaFactory(self.dereferencer).create(schema_spec) + schema, _ = SchemaRegistry(self.dereferencer).get_or_create( + schema_spec) + return schema diff --git a/openapi_core/models.py b/openapi_core/models.py new file mode 100644 index 0000000..5917709 --- /dev/null +++ b/openapi_core/models.py @@ -0,0 +1,27 @@ +"""OpenAPI core models module""" + + +class BaseModel(dict): + """Base class for OpenAPI models.""" + + def __getattr__(self, attr_name): + """Only search through properties if attribute not found normally. + :type attr_name: str + """ + try: + return self[attr_name] + except KeyError: + raise AttributeError( + 'type object {0!r} has no attribute {1!r}' + .format(type(self).__name__, attr_name) + ) + + +class ModelFactory(object): + + def create(self, properties, name=None): + model = BaseModel + if name is not None: + model = type(name, (BaseModel, ), {}) + + return model(**properties) diff --git a/openapi_core/parameters.py b/openapi_core/parameters.py index 5a49707..676bfc8 100644 --- a/openapi_core/parameters.py +++ b/openapi_core/parameters.py @@ -1,7 +1,7 @@ """OpenAPI core parameters module""" import logging -from openapi_core.schemas import SchemaFactory +from openapi_core.schemas import SchemaRegistry log = logging.getLogger(__name__) @@ -56,4 +56,6 @@ class ParametersGenerator(object): ) def _create_schema(self, schema_spec): - return SchemaFactory(self.dereferencer).create(schema_spec) + schema, _ = SchemaRegistry(self.dereferencer).get_or_create( + schema_spec) + return schema diff --git a/openapi_core/schemas.py b/openapi_core/schemas.py index 6acbf94..97630d5 100644 --- a/openapi_core/schemas.py +++ b/openapi_core/schemas.py @@ -6,13 +6,14 @@ from collections import defaultdict from json import loads from six import iteritems +from openapi_core.models import ModelFactory + log = logging.getLogger(__name__) DEFAULT_CAST_CALLABLE_GETTER = { 'integer': int, 'number': float, 'boolean': lambda x: bool(strtobool(x)), - 'object': loads, } @@ -20,9 +21,10 @@ class Schema(object): """Represents an OpenAPI Schema.""" def __init__( - self, schema_type, properties=None, items=None, spec_format=None, - required=False): + self, schema_type, model=None, properties=None, items=None, + spec_format=None, required=False): self.type = schema_type + self.model = model self.properties = properties and dict(properties) or {} self.items = items self.format = spec_format @@ -33,10 +35,10 @@ class Schema(object): def get_cast_mapping(self): mapping = DEFAULT_CAST_CALLABLE_GETTER.copy() - if self.items: - mapping.update({ - 'array': lambda x: list(map(self.items.unmarshal, x)), - }) + mapping.update({ + 'array': self._unmarshal_collection, + 'object': self._unmarshal_object, + }) return defaultdict(lambda: lambda x: x, mapping) @@ -68,6 +70,19 @@ class Schema(object): return casted + def _unmarshal_collection(self, value): + return list(map(self.items.unmarshal, value)) + + def _unmarshal_object(self, value): + if isinstance(value, (str, bytes)): + value = loads(value) + + properties = {} + for prop_name, prop in iteritems(self.properties): + prop_value = value.get(prop_name) + properties[prop_name] = prop.unmarshal(prop_value) + return ModelFactory().create(properties, name=self.model) + class PropertiesGenerator(object): @@ -90,7 +105,9 @@ class SchemaFactory(object): def create(self, schema_spec): schema_deref = self.dereferencer.dereference(schema_spec) + schema_type = schema_deref['type'] + model = schema_deref.get('x-model', None) required = schema_deref.get('required', False) properties_spec = schema_deref.get('properties', None) items_spec = schema_deref.get('items', None) @@ -104,10 +121,50 @@ class SchemaFactory(object): items = self._create_items(items_spec) return Schema( - schema_type, properties=properties, items=items, required=required) + schema_type, model=model, properties=properties, items=items, + required=required, + ) def _generate_properties(self, properties_spec): return PropertiesGenerator(self.dereferencer).generate(properties_spec) def _create_items(self, items_spec): - return SchemaFactory(self.dereferencer).create(items_spec) + return self.create(items_spec) + + +class SchemaRegistry(SchemaFactory): + + def __init__(self, dereferencer): + super(SchemaRegistry, self).__init__(dereferencer) + self._schemas = {} + + def get_or_create(self, schema_spec): + schema_deref = self.dereferencer.dereference(schema_spec) + model = schema_deref.get('x-model', None) + + if model and model in self._schemas: + return self._schemas[model], False + + return self.create(schema_deref), True + + def _create_items(self, items_spec): + schema, _ = self.get_or_create(items_spec) + return schema + + +class SchemasGenerator(object): + + def __init__(self, dereferencer): + self.dereferencer = dereferencer + + def generate(self, schemas_spec): + schemas_deref = self.dereferencer.dereference(schemas_spec) + + for schema_name, schema_spec in iteritems(schemas_deref): + schema = self._create_schema(schema_spec) + yield schema_name, schema + + def _create_schema(self, schema_spec): + schema, _ = SchemaRegistry(self.dereferencer).get_or_create( + schema_spec) + return schema diff --git a/openapi_core/specs.py b/openapi_core/specs.py index 37be7b1..f10dc46 100644 --- a/openapi_core/specs.py +++ b/openapi_core/specs.py @@ -5,6 +5,8 @@ from functools import partialmethod from openapi_spec_validator import openapi_v3_spec_validator +from openapi_core.components import ComponentsFactory +from openapi_core.infos import InfoFactory from openapi_core.paths import PathsGenerator @@ -14,9 +16,11 @@ log = logging.getLogger(__name__) class Spec(object): """Represents an OpenAPI Specification for a service.""" - def __init__(self, servers=None, paths=None): + def __init__(self, info, paths, servers=None, components=None): + self.info = info + self.paths = paths and dict(paths) self.servers = servers or [] - self.paths = paths and dict(paths) or {} + self.components = components def __getitem__(self, path_name): return self.paths[path_name] @@ -27,6 +31,9 @@ class Spec(object): def get_operation(self, path_pattern, http_method): return self.paths[path_pattern].operations[http_method] + def get_schema(self, name): + return self.components.schemas[name] + # operations shortcuts get = partialmethod(get_operation, http_method='get') @@ -50,11 +57,22 @@ class SpecFactory(object): spec_dict_deref = self.dereferencer.dereference(spec_dict) + info_spec = spec_dict_deref.get('info', []) servers = spec_dict_deref.get('servers', []) - paths = spec_dict_deref.get('paths', []) + components_spec = spec_dict_deref.get('components', []) + + info = self._create_info(info_spec) paths = self._generate_paths(paths) - return Spec(servers=servers, paths=list(paths)) + components = self._create_components(components_spec) + spec = Spec(info, list(paths), servers=servers, components=components) + return spec + + def _create_info(self, info_spec): + return InfoFactory(self.dereferencer).create(info_spec) def _generate_paths(self, paths): return PathsGenerator(self.dereferencer).generate(paths) + + def _create_components(self, components_spec): + return ComponentsFactory(self.dereferencer).create(components_spec) diff --git a/tests/integration/data/v3.0/petstore.yaml b/tests/integration/data/v3.0/petstore.yaml index 3089f7a..1228873 100644 --- a/tests/integration/data/v3.0/petstore.yaml +++ b/tests/integration/data/v3.0/petstore.yaml @@ -95,21 +95,30 @@ paths: $ref: "#/components/schemas/Error" components: schemas: + Address: + type: object + x-model: Address + required: + - city + properties: + street: + type: string + city: + type: string Pet: type: object + x-model: Pet + allOf: + - $ref: "#/components/schemas/PetCreate" required: - id - - name properties: id: type: integer format: int64 - name: - type: string - tag: - type: string PetCreate: type: object + x-model: PetCreate required: - name properties: @@ -117,6 +126,8 @@ components: type: string tag: type: string + address: + $ref: "#/components/schemas/Address" Pets: type: array items: diff --git a/tests/integration/test_petstore.py b/tests/integration/test_petstore.py index ccd3e3b..82ca975 100644 --- a/tests/integration/test_petstore.py +++ b/tests/integration/test_petstore.py @@ -44,6 +44,9 @@ class TestPetstore(object): return create_spec(spec_dict) def test_spec(self, spec, spec_dict): + assert spec.info.title == spec_dict['info']['title'] + assert spec.info.version == spec_dict['info']['version'] + assert spec.servers == spec_dict['servers'] assert spec.get_server_url() == spec_dict['servers'][0]['url'] @@ -89,6 +92,12 @@ class TestPetstore(object): assert media_type.schema.required == schema_spec.get( 'required', False) + if not spec.components: + return + + for schema_name, schema in iteritems(spec.components.schemas): + assert type(schema) == Schema + def test_get_pets(self, spec): query_params = { 'limit': '20', @@ -157,18 +166,34 @@ class TestPetstore(object): } } - def test_post_pets(self, spec): + def test_post_pets(self, spec, spec_dict): + pet_name = 'Cat' + pet_tag = 'cats' + pet_street = 'Piekna' + pet_city = 'Warsaw' data_json = { - 'name': 'Cat', - 'tag': 'cats', + 'name': pet_name, + 'tag': pet_tag, + 'address': { + 'street': pet_street, + 'city': pet_city, + } } data = json.dumps(data_json) request = RequestMock('post', '/pets', data=data) - body = request.get_body(spec) + pet = request.get_body(spec) - assert body == data_json + schemas = spec_dict['components']['schemas'] + pet_model = schemas['PetCreate']['x-model'] + address_model = schemas['Address']['x-model'] + assert pet.__class__.__name__ == pet_model + assert pet.name == pet_name + assert pet.tag == pet_tag + assert pet.address.__class__.__name__ == address_model + assert pet.address.street == pet_street + assert pet.address.city == pet_city def test_post_pets_raises_invalid_content_type(self, spec): data_json = {