diff --git a/VERSION b/VERSION index 6d7de6e..9084fa2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0.2 +1.1.0 diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 5b23848..c05a93b 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -1,6 +1,5 @@ import asyncio import contextlib -import dataclasses import logging import os import time @@ -31,10 +30,8 @@ Timeout = typing.Union[int, float, None] """Type annotation for timeout values""" -@dataclasses.dataclass class QueryResult: - """A :func:`Data Class ` that is generated as a - result of each query that is executed. + """Contains the results of the query that was executed. :param row_count: The quantity of rows impacted by the query :param row: If a single row is returned, the data for that row @@ -42,9 +39,41 @@ class QueryResult: list of rows, in order. """ - row_count: int - row: typing.Optional[dict] - rows: typing.Optional[typing.List[dict]] + def __init__(self, + row_count: int, + row: typing.Optional[dict], + rows: typing.Optional[typing.List[dict]]): + self._row_count = row_count + self._row = row + self._rows = rows + + def __repr__(self) -> str: + return ''.format(self._row_count) + + def __iter__(self) -> typing.Iterator[dict]: + """Iterate across all rows in the result""" + for row in self.rows: + yield row + + def __len__(self) -> int: + """Returns the number of rows impacted by the query""" + return self._row_count + + @property + def row(self) -> typing.Optional[dict]: + return self._row + + @property + def row_count(self) -> int: + """Return the number of rows for the result""" + return self._row_count + + @property + def rows(self) -> typing.List[dict]: + """Return the result as a list of one or more rows""" + if self.row_count == 1: + return [self._row] + return self._rows or [] class PostgresConnector: diff --git a/tests.py b/tests.py index 67d0402..f96d8b5 100644 --- a/tests.py +++ b/tests.py @@ -44,6 +44,8 @@ class CountRequestHandler(RequestHandler): async def get(self): result = await self.postgres_execute(self.GET_SQL) + assert '' == repr(result) + assert result.rows[0] == result.row await self.finish(self.cast_data(result.row)) @@ -117,7 +119,8 @@ class MultiRowRequestHandler(RequestHandler): result = await self.postgres_execute(self.GET_SQL) await self.finish({ 'count': result.row_count, - 'rows': self.cast_data(result.rows)}) + 'rows': self.cast_data(result.rows), + 'iterator_rows': [self.cast_data(r) for r in result]}) async def post(self): body = json.loads(self.request.body.decode('utf-8')) @@ -144,6 +147,7 @@ class NoRowRequestHandler(RequestHandler): async def get(self): result = await self.postgres_execute(self.GET_SQL) + assert len(result) == result.row_count await self.finish({ 'count': result.row_count, 'rows': self.cast_data(result.rows)}) @@ -333,14 +337,14 @@ class RequestHandlerMixinTestCase(TestCase): self.assertEqual(response.code, 200) body = json.loads(response.body) self.assertEqual(body['count'], 5) - self.assertIsNone(body['rows']) + self.assertListEqual(body['rows'], []) def test_postgres_norow(self): response = self.fetch('/no-row') self.assertEqual(response.code, 200) body = json.loads(response.body) self.assertEqual(body['count'], 0) - self.assertIsNone(body['rows']) + self.assertListEqual(body['rows'], []) @mock.patch('aiopg.cursor.Cursor.execute') def test_postgres_execute_timeout_error(self, execute):