diff --git a/sprockets/mixins/mediatype/transcoders.py b/sprockets/mixins/mediatype/transcoders.py index 1db8658..c852d4d 100644 --- a/sprockets/mixins/mediatype/transcoders.py +++ b/sprockets/mixins/mediatype/transcoders.py @@ -186,6 +186,12 @@ class MsgPackTranscoder(handlers.BinaryContentHandler): +-------------------------------+-------------------------------+ | String | `str family`_ | +-------------------------------+-------------------------------+ + | :class:`bytes` | `bin family`_ | + +-------------------------------+-------------------------------+ + | :class:`bytearray` | `bin family`_ | + +-------------------------------+-------------------------------+ + | :class:`memoryview` | `bin family`_ | + +-------------------------------+-------------------------------+ | :class:`collections.Sequence` | `array family`_ | +-------------------------------+-------------------------------+ | :class:`collections.Set` | `array family`_ | @@ -223,16 +229,26 @@ class MsgPackTranscoder(handlers.BinaryContentHandler): if isinstance(datum, uuid.UUID): datum = str(datum) + if isinstance(datum, bytearray): + datum = bytes(datum) + + if isinstance(datum, memoryview): + datum = datum.tobytes() + if hasattr(datum, 'isoformat'): datum = datum.isoformat() - if isinstance(datum, bytes): - datum = datum.decode('utf-8') - - if isinstance(datum, str): + if sys.version_info[0] < 3 and isinstance(datum, (str, unicode)): + if isinstance(datum, str): + # try to decode this into a string to make the common + # case work. If we fail, then send along the bytes. + try: + datum = datum.decode('utf-8') + except UnicodeDecodeError: + pass return datum - if sys.version_info[0] < 3 and isinstance(datum, unicode): + if isinstance(datum, (bytes, str)): return datum if isinstance(datum, (collections.Sequence, collections.Set)): diff --git a/tests.py b/tests.py index eded984..c6c8b4b 100644 --- a/tests.py +++ b/tests.py @@ -48,6 +48,18 @@ def pack_string(obj): return prefix + payload +def pack_bytes(payload): + """Optimally pack a byte string according to msgpack format""" + l = len(payload) + if l < (2 ** 8): + prefix = struct.pack('BB', 0xC4, l) + elif l < (2 ** 16): + prefix = struct.pack('>BH', 0xC5, l) + else: + prefix = struct.pack('>BI', 0xC6, l) + return prefix + payload + + class SendResponseTests(testing.AsyncHTTPTestCase): def get_app(self): @@ -283,3 +295,21 @@ class MsgPackTranscoderTests(unittest.TestCase): dumped = self.transcoder.packb(now) self.assertEqual(self.transcoder.unpackb(dumped), now.isoformat()) self.assertEqual(dumped, pack_string(now.isoformat())) + + def test_that_bytes_are_sent_as_bytes(self): + data = bytes(os.urandom(127)) + dumped = self.transcoder.packb(data) + self.assertEqual(self.transcoder.unpackb(dumped), data) + self.assertEqual(dumped, pack_bytes(data)) + + def test_that_bytearrays_are_sent_as_bytes(self): + data = bytearray(os.urandom(127)) + dumped = self.transcoder.packb(data) + self.assertEqual(self.transcoder.unpackb(dumped), data) + self.assertEqual(dumped, pack_bytes(data)) + + def test_that_memoryviews_are_sent_as_bytes(self): + data = memoryview(os.urandom(127)) + dumped = self.transcoder.packb(data) + self.assertEqual(self.transcoder.unpackb(dumped), data) + self.assertEqual(dumped, pack_bytes(data.tobytes()))