Change sprockets_postgres.QueryResult to a class

After use in a couple APIs, I found having to check against QueryResult.row and QueryResult.rows too difficult, so this change will always have content in QueryResult.rows, even if it's one row.

In addition, it turns the object into an iterator and adds __repr__ and __len__ magic methods to make interacting with it easier
This commit is contained in:
Gavin M. Roy 2020-04-27 18:16:18 -04:00
parent 86b6c022f0
commit 10b98cba10
3 changed files with 44 additions and 11 deletions

View file

@ -1 +1 @@
1.0.2
1.1.0

View file

@ -1,6 +1,5 @@
import asyncio
import contextlib
import dataclasses
import logging
import os
import time
@ -31,10 +30,8 @@ Timeout = typing.Union[int, float, None]
"""Type annotation for timeout values"""
@dataclasses.dataclass
class QueryResult:
"""A :func:`Data Class <dataclasses.dataclass>` that is generated as a
result of each query that is executed.
"""Contains the results of the query that was executed.
:param row_count: The quantity of rows impacted by the query
:param row: If a single row is returned, the data for that row
@ -42,9 +39,41 @@ class QueryResult:
list of rows, in order.
"""
row_count: int
row: typing.Optional[dict]
rows: typing.Optional[typing.List[dict]]
def __init__(self,
row_count: int,
row: typing.Optional[dict],
rows: typing.Optional[typing.List[dict]]):
self._row_count = row_count
self._row = row
self._rows = rows
def __repr__(self) -> str:
return '<QueryResult row_count={}>'.format(self._row_count)
def __iter__(self) -> typing.Iterator[dict]:
"""Iterate across all rows in the result"""
for row in self.rows:
yield row
def __len__(self) -> int:
"""Returns the number of rows impacted by the query"""
return self._row_count
@property
def row(self) -> typing.Optional[dict]:
return self._row
@property
def row_count(self) -> int:
"""Return the number of rows for the result"""
return self._row_count
@property
def rows(self) -> typing.List[dict]:
"""Return the result as a list of one or more rows"""
if self.row_count == 1:
return [self._row]
return self._rows or []
class PostgresConnector:

View file

@ -44,6 +44,8 @@ class CountRequestHandler(RequestHandler):
async def get(self):
result = await self.postgres_execute(self.GET_SQL)
assert '<QueryResult row_count=1>' == repr(result)
assert result.rows[0] == result.row
await self.finish(self.cast_data(result.row))
@ -117,7 +119,8 @@ class MultiRowRequestHandler(RequestHandler):
result = await self.postgres_execute(self.GET_SQL)
await self.finish({
'count': result.row_count,
'rows': self.cast_data(result.rows)})
'rows': self.cast_data(result.rows),
'iterator_rows': [self.cast_data(r) for r in result]})
async def post(self):
body = json.loads(self.request.body.decode('utf-8'))
@ -144,6 +147,7 @@ class NoRowRequestHandler(RequestHandler):
async def get(self):
result = await self.postgres_execute(self.GET_SQL)
assert len(result) == result.row_count
await self.finish({
'count': result.row_count,
'rows': self.cast_data(result.rows)})
@ -333,14 +337,14 @@ class RequestHandlerMixinTestCase(TestCase):
self.assertEqual(response.code, 200)
body = json.loads(response.body)
self.assertEqual(body['count'], 5)
self.assertIsNone(body['rows'])
self.assertListEqual(body['rows'], [])
def test_postgres_norow(self):
response = self.fetch('/no-row')
self.assertEqual(response.code, 200)
body = json.loads(response.body)
self.assertEqual(body['count'], 0)
self.assertIsNone(body['rows'])
self.assertListEqual(body['rows'], [])
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_timeout_error(self, execute):