diff --git a/fixtures/testing.sql b/fixtures/testing.sql index b6b3ce0..8b683b3 100644 --- a/fixtures/testing.sql +++ b/fixtures/testing.sql @@ -28,3 +28,9 @@ INSERT INTO public.test_rows (toggle) VALUES (FALSE); INSERT INTO public.test_rows (toggle) VALUES (FALSE); INSERT INTO public.test_rows (toggle) VALUES (FALSE); INSERT INTO public.test_rows (toggle) VALUES (FALSE); + +CREATE TABLE public.row_count_no_rows ( + id INTEGER NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + value UUID NOT NULL DEFAULT uuid_generate_v4() +); diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 252fb46..8c2655c 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -258,16 +258,24 @@ class PostgresConnector: async def _query_results(self) -> QueryResult: count, row, rows = self.cursor.rowcount, None, None + + def _on_programming_error(err: psycopg2.ProgrammingError) -> None: + # Should always be empty in this context + if err.pgcode is not None: # pragma: nocover + LOGGER.warning( + 'Unexpected value for ProgrammingError(%s).pgcode: %r', + err, err.pgcode) + if self.cursor.rowcount == 1: try: row = dict(await self.cursor.fetchone()) - except psycopg2.ProgrammingError: - pass + except psycopg2.ProgrammingError as exc: + _on_programming_error(exc) elif self.cursor.rowcount > 1: try: rows = [dict(row) for row in await self.cursor.fetchall()] - except psycopg2.ProgrammingError: - pass + except psycopg2.ProgrammingError as exc: + _on_programming_error(exc) return QueryResult(count, row, rows) diff --git a/tests.py b/tests.py index 64f221c..2542e25 100644 --- a/tests.py +++ b/tests.py @@ -107,6 +107,19 @@ class InfluxDBRequestHandler(ExecuteRequestHandler): self.influxdb.add_field = mock.Mock() +class RowCountNoRowsRequestHandler(RequestHandler): + + GET_SQL = 'INSERT INTO public.row_count_no_rows (value) VALUES (%(value)s)' + + async def get(self): + count = 0 + for iteration in range(0, 5): + result = await self.postgres_execute( + self.GET_SQL, {'value': uuid.uuid4()}) + count += len(result) + await self.finish({'count': count}) + + class MetricsMixinRequestHandler(ExecuteRequestHandler): def __init__(self, *args, **kwargs): @@ -301,6 +314,7 @@ class TestCase(testing.SprocketsHttpTestCase): web.url('/multi-row', MultiRowRequestHandler), web.url('/no-error', NoErrorRequestHandler), web.url('/no-row', NoRowRequestHandler), + web.url('/row-count-no-rows', RowCountNoRowsRequestHandler), web.url('/status', StatusRequestHandler), web.url('/timeout-error', TimeoutErrorRequestHandler), web.url('/transaction', TransactionRequestHandler), @@ -498,6 +512,12 @@ class RequestHandlerMixinTestCase(TestCase): with self.assertRaises(RuntimeError): await asyncio.gather(invoke_cursor(), invoke_cursor()) + def test_row_count_no_rows(self): + response = self.fetch('/row-count-no-rows') + self.assertEqual(response.code, 200) + data = json.loads(response.body) + self.assertEqual(data['count'], 5) + @mock.patch('aiopg.cursor.Cursor.execute') def test_timeout_error_when_overriding_on_postgres_error(self, execute): execute.side_effect = asyncio.TimeoutError