diff --git a/docs/history.rst b/docs/history.rst index 4dea9bc..a6500d3 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -4,6 +4,7 @@ Version History :compare:`Next <3.0.4...master>` -------------------------------- - Add a transcoder for `application/x-www-formurlencoded`_ +- Add support for encoding :class:`decimal.Decimal` - Add type annotations (see :ref:`type-info`) - Return a "406 Not Acceptable" if the :http:header:`Accept` header values cannot be matched and there is no default content type configured diff --git a/sprockets/mixins/mediatype/transcoders.py b/sprockets/mixins/mediatype/transcoders.py index 7f932d0..470908f 100644 --- a/sprockets/mixins/mediatype/transcoders.py +++ b/sprockets/mixins/mediatype/transcoders.py @@ -10,6 +10,7 @@ from __future__ import annotations import base64 import dataclasses +import decimal import json import string import typing @@ -85,7 +86,8 @@ class JSONTranscoder(handlers.TextContentHandler): return typing.cast(type_info.Deserialized, json.loads(str_repr, **self.load_options)) - def dump_object(self, obj: type_info.Serializable) -> str: + def dump_object(self, + obj: type_info.Serializable) -> typing.Union[str, float]: """ Called to encode unrecognized object. @@ -111,6 +113,8 @@ class JSONTranscoder(handlers.TextContentHandler): +----------------------------+---------------------------------------+ | :class:`uuid.UUID` | Same as ``str(value)`` | +----------------------------+---------------------------------------+ + | :class:`decimal.Decimal` | Same as ``float(value)`` | + +----------------------------+---------------------------------------+ """ if isinstance(obj, uuid.UUID): @@ -119,6 +123,8 @@ class JSONTranscoder(handlers.TextContentHandler): return typing.cast(type_info.SupportsIsoFormat, obj).isoformat() if isinstance(obj, (bytes, bytearray, memoryview)): return base64.b64encode(obj).decode('ASCII') + if isinstance(obj, decimal.Decimal): + return float(obj) raise TypeError('{!r} is not JSON serializable'.format(obj)) @@ -196,6 +202,8 @@ class MsgPackTranscoder(handlers.BinaryContentHandler): +-----------------------------------+-------------------------------+ | :class:`uuid.UUID` | Converted to String | +-----------------------------------+-------------------------------+ + | :class:`decimal.Decimal` | `float family`_ | + +-----------------------------------+-------------------------------+ .. _nil byte: https://github.com/msgpack/msgpack/blob/ 0b8f5ac67cdd130f4d4d4fe6afb839b989fdb86a/spec.md#formats-nil @@ -221,6 +229,9 @@ class MsgPackTranscoder(handlers.BinaryContentHandler): if datum is None: return datum + if isinstance(datum, decimal.Decimal): + datum = float(datum) + if isinstance(datum, self.PACKABLE_TYPES): return datum @@ -298,7 +309,8 @@ class FormUrlEncodedTranscoder: +----------------------------+---------------------------------------+ | :data:`None` | the empty string | +----------------------------+---------------------------------------+ - | numbers | ``str(n)`` | + | numbers including | ``str(n)`` | + | :class:`decimal.Decimal` | | +----------------------------+---------------------------------------+ | byte sequences | percent-encoded bytes | +----------------------------+---------------------------------------+ diff --git a/sprockets/mixins/mediatype/type_info.py b/sprockets/mixins/mediatype/type_info.py index ba80aaf..f9d87d5 100644 --- a/sprockets/mixins/mediatype/type_info.py +++ b/sprockets/mixins/mediatype/type_info.py @@ -1,5 +1,6 @@ from __future__ import annotations +import decimal import typing import uuid @@ -27,7 +28,8 @@ class SupportsSettings(Protocol): Serializable = typing.Union[SupportsIsoFormat, None, bool, bytearray, bytes, float, int, memoryview, str, typing.Mapping, - typing.Sequence, typing.Set, uuid.UUID] + typing.Sequence, typing.Set, uuid.UUID, + decimal.Decimal] """Types that can be serialized by this library. This is the set of types that diff --git a/tests.py b/tests.py index 550bb2d..2db3328 100644 --- a/tests.py +++ b/tests.py @@ -1,5 +1,6 @@ import base64 import datetime +import decimal import json import math import os @@ -347,6 +348,12 @@ class JSONTranscoderTests(unittest.TestCase): with self.assertRaises(TypeError): self.transcoder.dumps(object()) + def test_that_decimals_are_converted_to_floats(self): + pi = decimal.Decimal('3.142857142857142857142857143') + dumped = self.transcoder.dumps({'n': pi}) + loaded = json.loads(dumped) + self.assertEqual(loaded['n'], float(pi)) + class ContentSettingsTests(unittest.TestCase): def test_that_handler_listed_in_available_content_types(self): @@ -552,6 +559,13 @@ class MsgPackTranscoderTests(unittest.TestCase): with self.assertRaises(RuntimeError): transcoders.MsgPackTranscoder() + def test_that_decimals_are_converted_to_floats(self): + pi = decimal.Decimal('3.142857142857142857142857143') + dumped = self.transcoder.packb(pi) + # 0xCB -> 8 byte IEEE float in big endian order + self.assertEqual(0xcb, dumped[0]) + self.assertEqual(struct.pack('>d', float(pi)), dumped[1:]) + class FormUrlEncodingTranscoderTests(unittest.TestCase): transcoder: type_info.Transcoder @@ -696,3 +710,8 @@ class FormUrlEncodingTranscoderTests(unittest.TestCase): _, result = self.transcoder.to_bytes(value) self.assertEqual(b'list=1&list=2&tuple=1&tuple=2&set=1&set=2&str=val', result) + + def test_that_decimals_are_stringified(self): + pi = decimal.Decimal('3.142857142857142857142857143') + _, result = self.transcoder.to_bytes({'pi': pi}) + self.assertEqual('pi={}'.format(str(pi)).encode(), result)