mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2025-04-09 09:21:01 -09:00
Change sprockets_postgres.QueryResult to a class
After use in a couple APIs, I found having to check against QueryResult.row and QueryResult.rows too difficult, so this change will always have content in QueryResult.rows, even if it's one row. In addition, it turns the object into an iterator and adds __repr__ and __len__ magic methods to make interacting with it easier
This commit is contained in:
parent
86b6c022f0
commit
10b98cba10
3 changed files with 44 additions and 11 deletions
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
||||||
1.0.2
|
1.1.0
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import dataclasses
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
@ -31,10 +30,8 @@ Timeout = typing.Union[int, float, None]
|
||||||
"""Type annotation for timeout values"""
|
"""Type annotation for timeout values"""
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class QueryResult:
|
class QueryResult:
|
||||||
"""A :func:`Data Class <dataclasses.dataclass>` that is generated as a
|
"""Contains the results of the query that was executed.
|
||||||
result of each query that is executed.
|
|
||||||
|
|
||||||
:param row_count: The quantity of rows impacted by the query
|
:param row_count: The quantity of rows impacted by the query
|
||||||
:param row: If a single row is returned, the data for that row
|
:param row: If a single row is returned, the data for that row
|
||||||
|
@ -42,9 +39,41 @@ class QueryResult:
|
||||||
list of rows, in order.
|
list of rows, in order.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
row_count: int
|
def __init__(self,
|
||||||
row: typing.Optional[dict]
|
row_count: int,
|
||||||
rows: typing.Optional[typing.List[dict]]
|
row: typing.Optional[dict],
|
||||||
|
rows: typing.Optional[typing.List[dict]]):
|
||||||
|
self._row_count = row_count
|
||||||
|
self._row = row
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return '<QueryResult row_count={}>'.format(self._row_count)
|
||||||
|
|
||||||
|
def __iter__(self) -> typing.Iterator[dict]:
|
||||||
|
"""Iterate across all rows in the result"""
|
||||||
|
for row in self.rows:
|
||||||
|
yield row
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Returns the number of rows impacted by the query"""
|
||||||
|
return self._row_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def row(self) -> typing.Optional[dict]:
|
||||||
|
return self._row
|
||||||
|
|
||||||
|
@property
|
||||||
|
def row_count(self) -> int:
|
||||||
|
"""Return the number of rows for the result"""
|
||||||
|
return self._row_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rows(self) -> typing.List[dict]:
|
||||||
|
"""Return the result as a list of one or more rows"""
|
||||||
|
if self.row_count == 1:
|
||||||
|
return [self._row]
|
||||||
|
return self._rows or []
|
||||||
|
|
||||||
|
|
||||||
class PostgresConnector:
|
class PostgresConnector:
|
||||||
|
|
10
tests.py
10
tests.py
|
@ -44,6 +44,8 @@ class CountRequestHandler(RequestHandler):
|
||||||
|
|
||||||
async def get(self):
|
async def get(self):
|
||||||
result = await self.postgres_execute(self.GET_SQL)
|
result = await self.postgres_execute(self.GET_SQL)
|
||||||
|
assert '<QueryResult row_count=1>' == repr(result)
|
||||||
|
assert result.rows[0] == result.row
|
||||||
await self.finish(self.cast_data(result.row))
|
await self.finish(self.cast_data(result.row))
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,7 +119,8 @@ class MultiRowRequestHandler(RequestHandler):
|
||||||
result = await self.postgres_execute(self.GET_SQL)
|
result = await self.postgres_execute(self.GET_SQL)
|
||||||
await self.finish({
|
await self.finish({
|
||||||
'count': result.row_count,
|
'count': result.row_count,
|
||||||
'rows': self.cast_data(result.rows)})
|
'rows': self.cast_data(result.rows),
|
||||||
|
'iterator_rows': [self.cast_data(r) for r in result]})
|
||||||
|
|
||||||
async def post(self):
|
async def post(self):
|
||||||
body = json.loads(self.request.body.decode('utf-8'))
|
body = json.loads(self.request.body.decode('utf-8'))
|
||||||
|
@ -144,6 +147,7 @@ class NoRowRequestHandler(RequestHandler):
|
||||||
|
|
||||||
async def get(self):
|
async def get(self):
|
||||||
result = await self.postgres_execute(self.GET_SQL)
|
result = await self.postgres_execute(self.GET_SQL)
|
||||||
|
assert len(result) == result.row_count
|
||||||
await self.finish({
|
await self.finish({
|
||||||
'count': result.row_count,
|
'count': result.row_count,
|
||||||
'rows': self.cast_data(result.rows)})
|
'rows': self.cast_data(result.rows)})
|
||||||
|
@ -333,14 +337,14 @@ class RequestHandlerMixinTestCase(TestCase):
|
||||||
self.assertEqual(response.code, 200)
|
self.assertEqual(response.code, 200)
|
||||||
body = json.loads(response.body)
|
body = json.loads(response.body)
|
||||||
self.assertEqual(body['count'], 5)
|
self.assertEqual(body['count'], 5)
|
||||||
self.assertIsNone(body['rows'])
|
self.assertListEqual(body['rows'], [])
|
||||||
|
|
||||||
def test_postgres_norow(self):
|
def test_postgres_norow(self):
|
||||||
response = self.fetch('/no-row')
|
response = self.fetch('/no-row')
|
||||||
self.assertEqual(response.code, 200)
|
self.assertEqual(response.code, 200)
|
||||||
body = json.loads(response.body)
|
body = json.loads(response.body)
|
||||||
self.assertEqual(body['count'], 0)
|
self.assertEqual(body['count'], 0)
|
||||||
self.assertIsNone(body['rows'])
|
self.assertListEqual(body['rows'], [])
|
||||||
|
|
||||||
@mock.patch('aiopg.cursor.Cursor.execute')
|
@mock.patch('aiopg.cursor.Cursor.execute')
|
||||||
def test_postgres_execute_timeout_error(self, execute):
|
def test_postgres_execute_timeout_error(self, execute):
|
||||||
|
|
Loading…
Add table
Reference in a new issue