Merge pull request #2 from p1c2u/feature/models

Component schemas models
This commit is contained in:
A 2017-09-22 09:19:22 +01:00 committed by GitHub
commit 17915a4dd6
9 changed files with 234 additions and 27 deletions

View file

@ -0,0 +1,48 @@
from openapi_core.schemas import SchemasGenerator
class Components(object):
"""Represents an OpenAPI Components in a service."""
def __init__(
self, schemas=None, responses=None, parameters=None,
request_bodies=None):
self.schemas = schemas and dict(schemas) or {}
self.responses = responses and dict(responses) or {}
self.parameters = parameters and dict(parameters) or {}
self.request_bodies = request_bodies and dict(request_bodies) or {}
class ComponentsFactory(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer
def create(self, components_spec):
components_deref = self.dereferencer.dereference(components_spec)
schemas_spec = components_deref.get('schemas', [])
responses_spec = components_deref.get('responses', [])
parameters_spec = components_deref.get('parameters', [])
request_bodies_spec = components_deref.get('request_bodies', [])
schemas = self._generate_schemas(schemas_spec)
responses = self._generate_response(responses_spec)
parameters = self._generate_parameters(parameters_spec)
request_bodies = self._generate_request_bodies(request_bodies_spec)
return Components(
schemas=list(schemas), responses=responses, parameters=parameters,
request_bodies=request_bodies,
)
def _generate_schemas(self, schemas_spec):
return SchemasGenerator(self.dereferencer).generate(schemas_spec)
def _generate_response(self, responses_spec):
return responses_spec
def _generate_parameters(self, parameters_spec):
return parameters_spec
def _generate_request_bodies(self, request_bodies_spec):
return request_bodies_spec

17
openapi_core/infos.py Normal file
View file

@ -0,0 +1,17 @@
class Info(object):
def __init__(self, title, version):
self.title = title
self.version = version
class InfoFactory(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer
def create(self, info_spec):
info_deref = self.dereferencer.dereference(info_spec)
title = info_deref['title']
version = info_deref['version']
return Info(title, version)

View file

@ -1,7 +1,7 @@
"""OpenAPI core mediaTypes module""" """OpenAPI core mediaTypes module"""
from six import iteritems from six import iteritems
from openapi_core.schemas import SchemaFactory from openapi_core.schemas import SchemaRegistry
class MediaType(object): class MediaType(object):
@ -34,4 +34,6 @@ class MediaTypeGenerator(object):
yield content_type, MediaType(content_type, schema) yield content_type, MediaType(content_type, schema)
def _create_schema(self, schema_spec): def _create_schema(self, schema_spec):
return SchemaFactory(self.dereferencer).create(schema_spec) schema, _ = SchemaRegistry(self.dereferencer).get_or_create(
schema_spec)
return schema

27
openapi_core/models.py Normal file
View file

@ -0,0 +1,27 @@
"""OpenAPI core models module"""
class BaseModel(dict):
"""Base class for OpenAPI models."""
def __getattr__(self, attr_name):
"""Only search through properties if attribute not found normally.
:type attr_name: str
"""
try:
return self[attr_name]
except KeyError:
raise AttributeError(
'type object {0!r} has no attribute {1!r}'
.format(type(self).__name__, attr_name)
)
class ModelFactory(object):
def create(self, properties, name=None):
model = BaseModel
if name is not None:
model = type(name, (BaseModel, ), {})
return model(**properties)

View file

@ -1,7 +1,7 @@
"""OpenAPI core parameters module""" """OpenAPI core parameters module"""
import logging import logging
from openapi_core.schemas import SchemaFactory from openapi_core.schemas import SchemaRegistry
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -56,4 +56,6 @@ class ParametersGenerator(object):
) )
def _create_schema(self, schema_spec): def _create_schema(self, schema_spec):
return SchemaFactory(self.dereferencer).create(schema_spec) schema, _ = SchemaRegistry(self.dereferencer).get_or_create(
schema_spec)
return schema

View file

@ -6,13 +6,14 @@ from collections import defaultdict
from json import loads from json import loads
from six import iteritems from six import iteritems
from openapi_core.models import ModelFactory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
DEFAULT_CAST_CALLABLE_GETTER = { DEFAULT_CAST_CALLABLE_GETTER = {
'integer': int, 'integer': int,
'number': float, 'number': float,
'boolean': lambda x: bool(strtobool(x)), 'boolean': lambda x: bool(strtobool(x)),
'object': loads,
} }
@ -20,9 +21,10 @@ class Schema(object):
"""Represents an OpenAPI Schema.""" """Represents an OpenAPI Schema."""
def __init__( def __init__(
self, schema_type, properties=None, items=None, spec_format=None, self, schema_type, model=None, properties=None, items=None,
required=False): spec_format=None, required=False):
self.type = schema_type self.type = schema_type
self.model = model
self.properties = properties and dict(properties) or {} self.properties = properties and dict(properties) or {}
self.items = items self.items = items
self.format = spec_format self.format = spec_format
@ -33,9 +35,9 @@ class Schema(object):
def get_cast_mapping(self): def get_cast_mapping(self):
mapping = DEFAULT_CAST_CALLABLE_GETTER.copy() mapping = DEFAULT_CAST_CALLABLE_GETTER.copy()
if self.items:
mapping.update({ mapping.update({
'array': lambda x: list(map(self.items.unmarshal, x)), 'array': self._unmarshal_collection,
'object': self._unmarshal_object,
}) })
return defaultdict(lambda: lambda x: x, mapping) return defaultdict(lambda: lambda x: x, mapping)
@ -68,6 +70,19 @@ class Schema(object):
return casted return casted
def _unmarshal_collection(self, value):
return list(map(self.items.unmarshal, value))
def _unmarshal_object(self, value):
if isinstance(value, (str, bytes)):
value = loads(value)
properties = {}
for prop_name, prop in iteritems(self.properties):
prop_value = value.get(prop_name)
properties[prop_name] = prop.unmarshal(prop_value)
return ModelFactory().create(properties, name=self.model)
class PropertiesGenerator(object): class PropertiesGenerator(object):
@ -90,7 +105,9 @@ class SchemaFactory(object):
def create(self, schema_spec): def create(self, schema_spec):
schema_deref = self.dereferencer.dereference(schema_spec) schema_deref = self.dereferencer.dereference(schema_spec)
schema_type = schema_deref['type'] schema_type = schema_deref['type']
model = schema_deref.get('x-model', None)
required = schema_deref.get('required', False) required = schema_deref.get('required', False)
properties_spec = schema_deref.get('properties', None) properties_spec = schema_deref.get('properties', None)
items_spec = schema_deref.get('items', None) items_spec = schema_deref.get('items', None)
@ -104,10 +121,50 @@ class SchemaFactory(object):
items = self._create_items(items_spec) items = self._create_items(items_spec)
return Schema( return Schema(
schema_type, properties=properties, items=items, required=required) schema_type, model=model, properties=properties, items=items,
required=required,
)
def _generate_properties(self, properties_spec): def _generate_properties(self, properties_spec):
return PropertiesGenerator(self.dereferencer).generate(properties_spec) return PropertiesGenerator(self.dereferencer).generate(properties_spec)
def _create_items(self, items_spec): def _create_items(self, items_spec):
return SchemaFactory(self.dereferencer).create(items_spec) return self.create(items_spec)
class SchemaRegistry(SchemaFactory):
def __init__(self, dereferencer):
super(SchemaRegistry, self).__init__(dereferencer)
self._schemas = {}
def get_or_create(self, schema_spec):
schema_deref = self.dereferencer.dereference(schema_spec)
model = schema_deref.get('x-model', None)
if model and model in self._schemas:
return self._schemas[model], False
return self.create(schema_deref), True
def _create_items(self, items_spec):
schema, _ = self.get_or_create(items_spec)
return schema
class SchemasGenerator(object):
def __init__(self, dereferencer):
self.dereferencer = dereferencer
def generate(self, schemas_spec):
schemas_deref = self.dereferencer.dereference(schemas_spec)
for schema_name, schema_spec in iteritems(schemas_deref):
schema = self._create_schema(schema_spec)
yield schema_name, schema
def _create_schema(self, schema_spec):
schema, _ = SchemaRegistry(self.dereferencer).get_or_create(
schema_spec)
return schema

View file

@ -5,6 +5,8 @@ from functools import partialmethod
from openapi_spec_validator import openapi_v3_spec_validator from openapi_spec_validator import openapi_v3_spec_validator
from openapi_core.components import ComponentsFactory
from openapi_core.infos import InfoFactory
from openapi_core.paths import PathsGenerator from openapi_core.paths import PathsGenerator
@ -14,9 +16,11 @@ log = logging.getLogger(__name__)
class Spec(object): class Spec(object):
"""Represents an OpenAPI Specification for a service.""" """Represents an OpenAPI Specification for a service."""
def __init__(self, servers=None, paths=None): def __init__(self, info, paths, servers=None, components=None):
self.info = info
self.paths = paths and dict(paths)
self.servers = servers or [] self.servers = servers or []
self.paths = paths and dict(paths) or {} self.components = components
def __getitem__(self, path_name): def __getitem__(self, path_name):
return self.paths[path_name] return self.paths[path_name]
@ -27,6 +31,9 @@ class Spec(object):
def get_operation(self, path_pattern, http_method): def get_operation(self, path_pattern, http_method):
return self.paths[path_pattern].operations[http_method] return self.paths[path_pattern].operations[http_method]
def get_schema(self, name):
return self.components.schemas[name]
# operations shortcuts # operations shortcuts
get = partialmethod(get_operation, http_method='get') get = partialmethod(get_operation, http_method='get')
@ -50,11 +57,22 @@ class SpecFactory(object):
spec_dict_deref = self.dereferencer.dereference(spec_dict) spec_dict_deref = self.dereferencer.dereference(spec_dict)
info_spec = spec_dict_deref.get('info', [])
servers = spec_dict_deref.get('servers', []) servers = spec_dict_deref.get('servers', [])
paths = spec_dict_deref.get('paths', []) paths = spec_dict_deref.get('paths', [])
components_spec = spec_dict_deref.get('components', [])
info = self._create_info(info_spec)
paths = self._generate_paths(paths) paths = self._generate_paths(paths)
return Spec(servers=servers, paths=list(paths)) components = self._create_components(components_spec)
spec = Spec(info, list(paths), servers=servers, components=components)
return spec
def _create_info(self, info_spec):
return InfoFactory(self.dereferencer).create(info_spec)
def _generate_paths(self, paths): def _generate_paths(self, paths):
return PathsGenerator(self.dereferencer).generate(paths) return PathsGenerator(self.dereferencer).generate(paths)
def _create_components(self, components_spec):
return ComponentsFactory(self.dereferencer).create(components_spec)

View file

@ -95,21 +95,30 @@ paths:
$ref: "#/components/schemas/Error" $ref: "#/components/schemas/Error"
components: components:
schemas: schemas:
Address:
type: object
x-model: Address
required:
- city
properties:
street:
type: string
city:
type: string
Pet: Pet:
type: object type: object
x-model: Pet
allOf:
- $ref: "#/components/schemas/PetCreate"
required: required:
- id - id
- name
properties: properties:
id: id:
type: integer type: integer
format: int64 format: int64
name:
type: string
tag:
type: string
PetCreate: PetCreate:
type: object type: object
x-model: PetCreate
required: required:
- name - name
properties: properties:
@ -117,6 +126,8 @@ components:
type: string type: string
tag: tag:
type: string type: string
address:
$ref: "#/components/schemas/Address"
Pets: Pets:
type: array type: array
items: items:

View file

@ -44,6 +44,9 @@ class TestPetstore(object):
return create_spec(spec_dict) return create_spec(spec_dict)
def test_spec(self, spec, spec_dict): def test_spec(self, spec, spec_dict):
assert spec.info.title == spec_dict['info']['title']
assert spec.info.version == spec_dict['info']['version']
assert spec.servers == spec_dict['servers'] assert spec.servers == spec_dict['servers']
assert spec.get_server_url() == spec_dict['servers'][0]['url'] assert spec.get_server_url() == spec_dict['servers'][0]['url']
@ -89,6 +92,12 @@ class TestPetstore(object):
assert media_type.schema.required == schema_spec.get( assert media_type.schema.required == schema_spec.get(
'required', False) 'required', False)
if not spec.components:
return
for schema_name, schema in iteritems(spec.components.schemas):
assert type(schema) == Schema
def test_get_pets(self, spec): def test_get_pets(self, spec):
query_params = { query_params = {
'limit': '20', 'limit': '20',
@ -157,18 +166,34 @@ class TestPetstore(object):
} }
} }
def test_post_pets(self, spec): def test_post_pets(self, spec, spec_dict):
pet_name = 'Cat'
pet_tag = 'cats'
pet_street = 'Piekna'
pet_city = 'Warsaw'
data_json = { data_json = {
'name': 'Cat', 'name': pet_name,
'tag': 'cats', 'tag': pet_tag,
'address': {
'street': pet_street,
'city': pet_city,
}
} }
data = json.dumps(data_json) data = json.dumps(data_json)
request = RequestMock('post', '/pets', data=data) request = RequestMock('post', '/pets', data=data)
body = request.get_body(spec) pet = request.get_body(spec)
assert body == data_json schemas = spec_dict['components']['schemas']
pet_model = schemas['PetCreate']['x-model']
address_model = schemas['Address']['x-model']
assert pet.__class__.__name__ == pet_model
assert pet.name == pet_name
assert pet.tag == pet_tag
assert pet.address.__class__.__name__ == address_model
assert pet.address.street == pet_street
assert pet.address.city == pet_city
def test_post_pets_raises_invalid_content_type(self, spec): def test_post_pets_raises_invalid_content_type(self, spec):
data_json = { data_json = {