Change sprockets_postgres.QueryResult to a class

After use in a couple APIs, I found having to check against QueryResult.row and QueryResult.rows too difficult, so this change will always have content in QueryResult.rows, even if it's one row.

In addition, it turns the object into an iterator and adds __repr__ and __len__ magic methods to make interacting with it easier
This commit is contained in:
Gavin M. Roy 2020-04-27 18:16:18 -04:00
parent 86b6c022f0
commit 10b98cba10
3 changed files with 44 additions and 11 deletions

View file

@ -1 +1 @@
1.0.2 1.1.0

View file

@ -1,6 +1,5 @@
import asyncio import asyncio
import contextlib import contextlib
import dataclasses
import logging import logging
import os import os
import time import time
@ -31,10 +30,8 @@ Timeout = typing.Union[int, float, None]
"""Type annotation for timeout values""" """Type annotation for timeout values"""
@dataclasses.dataclass
class QueryResult: class QueryResult:
"""A :func:`Data Class <dataclasses.dataclass>` that is generated as a """Contains the results of the query that was executed.
result of each query that is executed.
:param row_count: The quantity of rows impacted by the query :param row_count: The quantity of rows impacted by the query
:param row: If a single row is returned, the data for that row :param row: If a single row is returned, the data for that row
@ -42,9 +39,41 @@ class QueryResult:
list of rows, in order. list of rows, in order.
""" """
row_count: int def __init__(self,
row: typing.Optional[dict] row_count: int,
rows: typing.Optional[typing.List[dict]] row: typing.Optional[dict],
rows: typing.Optional[typing.List[dict]]):
self._row_count = row_count
self._row = row
self._rows = rows
def __repr__(self) -> str:
return '<QueryResult row_count={}>'.format(self._row_count)
def __iter__(self) -> typing.Iterator[dict]:
"""Iterate across all rows in the result"""
for row in self.rows:
yield row
def __len__(self) -> int:
"""Returns the number of rows impacted by the query"""
return self._row_count
@property
def row(self) -> typing.Optional[dict]:
return self._row
@property
def row_count(self) -> int:
"""Return the number of rows for the result"""
return self._row_count
@property
def rows(self) -> typing.List[dict]:
"""Return the result as a list of one or more rows"""
if self.row_count == 1:
return [self._row]
return self._rows or []
class PostgresConnector: class PostgresConnector:

View file

@ -44,6 +44,8 @@ class CountRequestHandler(RequestHandler):
async def get(self): async def get(self):
result = await self.postgres_execute(self.GET_SQL) result = await self.postgres_execute(self.GET_SQL)
assert '<QueryResult row_count=1>' == repr(result)
assert result.rows[0] == result.row
await self.finish(self.cast_data(result.row)) await self.finish(self.cast_data(result.row))
@ -117,7 +119,8 @@ class MultiRowRequestHandler(RequestHandler):
result = await self.postgres_execute(self.GET_SQL) result = await self.postgres_execute(self.GET_SQL)
await self.finish({ await self.finish({
'count': result.row_count, 'count': result.row_count,
'rows': self.cast_data(result.rows)}) 'rows': self.cast_data(result.rows),
'iterator_rows': [self.cast_data(r) for r in result]})
async def post(self): async def post(self):
body = json.loads(self.request.body.decode('utf-8')) body = json.loads(self.request.body.decode('utf-8'))
@ -144,6 +147,7 @@ class NoRowRequestHandler(RequestHandler):
async def get(self): async def get(self):
result = await self.postgres_execute(self.GET_SQL) result = await self.postgres_execute(self.GET_SQL)
assert len(result) == result.row_count
await self.finish({ await self.finish({
'count': result.row_count, 'count': result.row_count,
'rows': self.cast_data(result.rows)}) 'rows': self.cast_data(result.rows)})
@ -333,14 +337,14 @@ class RequestHandlerMixinTestCase(TestCase):
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
body = json.loads(response.body) body = json.loads(response.body)
self.assertEqual(body['count'], 5) self.assertEqual(body['count'], 5)
self.assertIsNone(body['rows']) self.assertListEqual(body['rows'], [])
def test_postgres_norow(self): def test_postgres_norow(self):
response = self.fetch('/no-row') response = self.fetch('/no-row')
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
body = json.loads(response.body) body = json.loads(response.body)
self.assertEqual(body['count'], 0) self.assertEqual(body['count'], 0)
self.assertIsNone(body['rows']) self.assertListEqual(body['rows'], [])
@mock.patch('aiopg.cursor.Cursor.execute') @mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_timeout_error(self, execute): def test_postgres_execute_timeout_error(self, execute):