From 10b98cba10a2798ce97a87c0503eaac91c209659 Mon Sep 17 00:00:00 2001 From: "Gavin M. Roy" Date: Mon, 27 Apr 2020 18:16:18 -0400 Subject: [PATCH] Change sprockets_postgres.QueryResult to a class After use in a couple APIs, I found having to check against QueryResult.row and QueryResult.rows too difficult, so this change will always have content in QueryResult.rows, even if it's one row. In addition, it turns the object into an iterator and adds __repr__ and __len__ magic methods to make interacting with it easier --- VERSION | 2 +- sprockets_postgres.py | 43 ++++++++++++++++++++++++++++++++++++------- tests.py | 10 +++++++--- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/VERSION b/VERSION index 6d7de6e..9084fa2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0.2 +1.1.0 diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 5b23848..c05a93b 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -1,6 +1,5 @@ import asyncio import contextlib -import dataclasses import logging import os import time @@ -31,10 +30,8 @@ Timeout = typing.Union[int, float, None] """Type annotation for timeout values""" -@dataclasses.dataclass class QueryResult: - """A :func:`Data Class ` that is generated as a - result of each query that is executed. + """Contains the results of the query that was executed. :param row_count: The quantity of rows impacted by the query :param row: If a single row is returned, the data for that row @@ -42,9 +39,41 @@ class QueryResult: list of rows, in order. """ - row_count: int - row: typing.Optional[dict] - rows: typing.Optional[typing.List[dict]] + def __init__(self, + row_count: int, + row: typing.Optional[dict], + rows: typing.Optional[typing.List[dict]]): + self._row_count = row_count + self._row = row + self._rows = rows + + def __repr__(self) -> str: + return ''.format(self._row_count) + + def __iter__(self) -> typing.Iterator[dict]: + """Iterate across all rows in the result""" + for row in self.rows: + yield row + + def __len__(self) -> int: + """Returns the number of rows impacted by the query""" + return self._row_count + + @property + def row(self) -> typing.Optional[dict]: + return self._row + + @property + def row_count(self) -> int: + """Return the number of rows for the result""" + return self._row_count + + @property + def rows(self) -> typing.List[dict]: + """Return the result as a list of one or more rows""" + if self.row_count == 1: + return [self._row] + return self._rows or [] class PostgresConnector: diff --git a/tests.py b/tests.py index 67d0402..f96d8b5 100644 --- a/tests.py +++ b/tests.py @@ -44,6 +44,8 @@ class CountRequestHandler(RequestHandler): async def get(self): result = await self.postgres_execute(self.GET_SQL) + assert '' == repr(result) + assert result.rows[0] == result.row await self.finish(self.cast_data(result.row)) @@ -117,7 +119,8 @@ class MultiRowRequestHandler(RequestHandler): result = await self.postgres_execute(self.GET_SQL) await self.finish({ 'count': result.row_count, - 'rows': self.cast_data(result.rows)}) + 'rows': self.cast_data(result.rows), + 'iterator_rows': [self.cast_data(r) for r in result]}) async def post(self): body = json.loads(self.request.body.decode('utf-8')) @@ -144,6 +147,7 @@ class NoRowRequestHandler(RequestHandler): async def get(self): result = await self.postgres_execute(self.GET_SQL) + assert len(result) == result.row_count await self.finish({ 'count': result.row_count, 'rows': self.cast_data(result.rows)}) @@ -333,14 +337,14 @@ class RequestHandlerMixinTestCase(TestCase): self.assertEqual(response.code, 200) body = json.loads(response.body) self.assertEqual(body['count'], 5) - self.assertIsNone(body['rows']) + self.assertListEqual(body['rows'], []) def test_postgres_norow(self): response = self.fetch('/no-row') self.assertEqual(response.code, 200) body = json.loads(response.body) self.assertEqual(body['count'], 0) - self.assertIsNone(body['rows']) + self.assertListEqual(body['rows'], []) @mock.patch('aiopg.cursor.Cursor.execute') def test_postgres_execute_timeout_error(self, execute):