"""OpenAPI core schemas module"""
import logging
from collections import defaultdict
import warnings

from distutils.util import strtobool
from functools import lru_cache

from json import loads
from six import iteritems

from openapi_core.exceptions import (
    InvalidValueType, UndefinedSchemaProperty, MissingProperty, InvalidValue,
)
from openapi_core.models import ModelFactory

log = logging.getLogger(__name__)


def _cast_boolean(value):
    """Convert *value* to a ``bool``.

    Native booleans (e.g. values decoded from a JSON body, or schema
    defaults) are returned unchanged; anything else is interpreted by
    ``strtobool``, which raises ``ValueError`` for unrecognised strings.
    """
    # strtobool() only understands strings; handing it a real bool would
    # fail with AttributeError instead of a cast error.
    if isinstance(value, bool):
        return value
    return bool(strtobool(value))


# Maps primitive OpenAPI type names to the callable used to cast raw values.
DEFAULT_CAST_CALLABLE_GETTER = {
    'integer': int,
    'number': float,
    'boolean': _cast_boolean,
}


class Schema(object):
    """Represents an OpenAPI Schema.

    A schema knows its type, its child property schemas (for objects), its
    item schema (for arrays) and a handful of modifiers (``required``,
    ``default``, ``nullable``, ``enum``, ``deprecated``). It can cast and
    unmarshal raw values according to that description.
    """

    def __init__(
            self, schema_type, model=None, properties=None, items=None,
            spec_format=None, required=False, default=None, nullable=False,
            enum=None, deprecated=False):
        """Store the schema description.

        :param schema_type: type name used to pick the cast callable.
        :param model: model name passed to ``ModelFactory`` for objects.
        :param properties: mapping or iterable of ``(name, schema)`` pairs.
        :param items: schema used to unmarshal each element of an array.
        :param spec_format: stored as ``self.format``.
        :param required: either a boolean flag (consulted by ``unmarshal``)
            or a collection of required property names (consulted when
            unmarshalling an object).
        :param default: value used when a nullable value is ``None`` or a
            property is absent.
        :param nullable: whether ``None`` is an acceptable value.
        :param enum: optional collection of allowed values.
        :param deprecated: whether unmarshalling should emit a warning.
        """
        self.type = schema_type
        self.model = model
        # Materialise into a dict so generators of pairs are accepted too.
        self.properties = dict(properties) if properties else {}
        self.items = items
        self.format = spec_format
        self.required = required
        self.default = default
        self.nullable = nullable
        self.enum = enum
        self.deprecated = deprecated

    def __getitem__(self, name):
        """Return the schema of the property called *name*."""
        return self.properties[name]

    def get_cast_mapping(self):
        """Return a mapping of type name to cast callable.

        Primitive casters come from ``DEFAULT_CAST_CALLABLE_GETTER``; arrays
        and objects are handled by this schema's own unmarshal helpers. Any
        other type name falls back to an identity function.
        """
        mapping = DEFAULT_CAST_CALLABLE_GETTER.copy()
        mapping.update({
            'array': self._unmarshal_collection,
            'object': self._unmarshal_object,
        })

        return defaultdict(lambda: lambda x: x, mapping)

    def cast(self, value):
        """Cast value to schema type.

        :raises InvalidValueType: if *value* is ``None`` on a non-nullable
            schema, or the cast callable raises ``ValueError``.
        """
        if value is None:
            if not self.nullable:
                raise InvalidValueType(
                    "Failed to cast value of {0} to {1}".format(
                        value, self.type)
                )
            return self.default

        # Values that are already native booleans (decoded JSON, schema
        # defaults) need no string parsing.
        if self.type == 'boolean' and isinstance(value, bool):
            return value

        cast_mapping = self.get_cast_mapping()

        # Membership on the defaultdict is only true for explicitly mapped
        # types, so an empty string becomes None for those types only.
        if self.type in cast_mapping and value == '':
            return None

        cast_callable = cast_mapping[self.type]
        try:
            return cast_callable(value)
        except ValueError:
            raise InvalidValueType(
                "Failed to cast value of {0} to {1}".format(value, self.type)
            )

    def unmarshal(self, value):
        """Unmarshal parameter from the value.

        Emits a ``DeprecationWarning`` for deprecated schemas, casts the
        value and checks it against ``enum`` when one is defined.

        :raises InvalidValue: if the cast value is not one of ``enum``.
        """
        if self.deprecated:
            warnings.warn(
                "The schema is deprecated", DeprecationWarning)
        casted = self.cast(value)

        if casted is None and not self.required:
            return None

        if self.enum and casted not in self.enum:
            raise InvalidValue(
                "Value of {0} not in enum choices: {1}".format(
                    value, self.enum)
            )

        return casted

    def _required_property_names(self):
        """Return the collection of required property names.

        ``required`` doubles as a plain boolean flag; in that case (or for
        any other non-collection value) no property is considered required.
        """
        required = self.required
        if isinstance(required, (list, tuple, set, frozenset)):
            return required
        return ()

    def _unmarshal_collection(self, value):
        """Unmarshal every element of *value* with the ``items`` schema."""
        return list(map(self.items.unmarshal, value))

    def _unmarshal_object(self, value):
        """Unmarshal a mapping (or its JSON text) into a model instance.

        :raises UndefinedSchemaProperty: if *value* has keys that are not
            declared in ``properties``.
        :raises MissingProperty: if a required property is absent.
        """
        if isinstance(value, (str, bytes)):
            value = loads(value)

        extra_props = set(value.keys()) - set(self.properties.keys())

        if extra_props:
            raise UndefinedSchemaProperty(
                "Undefined properties in schema: {0}".format(extra_props))

        required_names = self._required_property_names()

        properties = {}
        for prop_name, prop in iteritems(self.properties):
            try:
                prop_value = value[prop_name]
            except KeyError:
                if prop_name in required_names:
                    raise MissingProperty(
                        "Missing schema property {0}".format(prop_name))
                # Skip absent properties only when there is truly no
                # default; falsy defaults such as 0 or False still apply.
                if not prop.nullable and prop.default is None:
                    continue
                prop_value = prop.default
            properties[prop_name] = prop.unmarshal(prop_value)
        return ModelFactory().create(properties, name=self.model)


class PropertiesGenerator(object):
    """Lazily builds schemas for the entries of a properties mapping."""

    def __init__(self, dereferencer):
        self.dereferencer = dereferencer

    def generate(self, properties):
        """Yield a ``(property_name, schema)`` pair per property spec.

        Nothing is evaluated until the returned generator is iterated.
        """
        for name, spec in iteritems(properties):
            yield name, self._create_schema(spec)

    def _create_schema(self, schema_spec):
        """Build the schema for one property spec with a fresh factory."""
        factory = SchemaFactory(self.dereferencer)
        return factory.create(schema_spec)


class SchemaFactory(object):
    """Creates ``Schema`` objects from raw schema specifications."""

    def __init__(self, dereferencer):
        self.dereferencer = dereferencer
        # Created on first access by ``properties_generator``.
        self._properties_generator = None

    def create(self, schema_spec):
        """Dereference *schema_spec* and build the matching ``Schema``.

        Nested ``properties`` and ``items`` specs are converted to schemas
        as well.

        :raises KeyError: if the dereferenced spec has no ``type`` key.
        """
        schema_deref = self.dereferencer.dereference(schema_spec)

        schema_type = schema_deref['type']
        model = schema_deref.get('x-model', None)
        # Forward the spec's format; Schema accepts it as ``spec_format``.
        schema_format = schema_deref.get('format', None)
        required = schema_deref.get('required', False)
        default = schema_deref.get('default', None)
        properties_spec = schema_deref.get('properties', None)
        items_spec = schema_deref.get('items', None)
        nullable = schema_deref.get('nullable', False)
        enum = schema_deref.get('enum', None)
        deprecated = schema_deref.get('deprecated', False)

        properties = None
        if properties_spec:
            properties = self.properties_generator.generate(properties_spec)

        items = None
        if items_spec:
            items = self._create_items(items_spec)

        return Schema(
            schema_type, model=model, properties=properties, items=items,
            spec_format=schema_format, required=required, default=default,
            nullable=nullable, enum=enum, deprecated=deprecated,
        )

    @property
    def properties_generator(self):
        """Return this factory's ``PropertiesGenerator``, creating it once.

        The generator is stored on the instance rather than in a shared
        ``lru_cache``, so it lives and dies with the factory.
        """
        # getattr keeps subclasses working even if they skip __init__.
        generator = getattr(self, '_properties_generator', None)
        if generator is None:
            generator = PropertiesGenerator(self.dereferencer)
            self._properties_generator = generator
        return generator

    def _create_items(self, items_spec):
        """Build the schema describing the elements of an array."""
        return self.create(items_spec)


class SchemaRegistry(SchemaFactory):
    """Schema factory that reuses schemas identified by an ``x-model`` name."""

    def __init__(self, dereferencer):
        super(SchemaRegistry, self).__init__(dereferencer)
        # Created schemas keyed by their ``x-model`` name.
        self._schemas = {}

    def get_or_create(self, schema_spec):
        """Return ``(schema, created)`` for *schema_spec*.

        A spec carrying an ``x-model`` name that was seen before yields the
        previously created schema with ``created`` set to ``False``.
        Otherwise a new schema is built and, when it has a model name,
        remembered for later calls. Specs without a model name are never
        stored.
        """
        schema_deref = self.dereferencer.dereference(schema_spec)
        model = schema_deref.get('x-model', None)

        if model and model in self._schemas:
            return self._schemas[model], False

        schema = self.create(schema_deref)
        if model:
            # Register the schema so the lookup above can actually hit.
            self._schemas[model] = schema
        return schema, True


class SchemasGenerator(object):
    """Lazily produces named schemas from a schemas specification."""

    def __init__(self, dereferencer, schemas_registry):
        self.schemas_registry = schemas_registry
        self.dereferencer = dereferencer

    def generate(self, schemas_spec):
        """Yield a ``(schema_name, schema)`` pair per entry of the spec.

        The spec is dereferenced first; each entry is then resolved through
        the registry's ``get_or_create``.
        """
        resolved_specs = self.dereferencer.dereference(schemas_spec)

        for name, spec in iteritems(resolved_specs):
            # get_or_create returns (schema, created); the flag is unused.
            schema, _created = self.schemas_registry.get_or_create(spec)
            yield name, schema