From 664bedbb2451d2a2a1a96e1d8c8d4f77606e59d5 Mon Sep 17 00:00:00 2001 From: "Gavin M. Roy" Date: Tue, 11 Aug 2020 19:25:34 -0400 Subject: [PATCH] Don't blindly swallow ProgrammingError ProgrammingError is raised when you try and fetch data from a cursor and there is no data to fetch. When this happens ProgrammingError.pgcode is None. It is also raised when your query has errors in it. Now if that's the case, they will be caught on L249 and not inside the function at L259. This new branch in the code will ensure that should we unexpectedly encounter a "real" programming error from Postgres, it is not blindly swallowed and a warning is issued. This should NEVER happen based upon my understanding of the psycopg2 internals. Unfortunately I couldn't come up with a good test case using mocks to make it happen, as ProgrammingError() takes no keyword arguments and pgcode is a read-only attribute on a ProgrammingError instance. I also couldn't figure out a way to raise ProgrammingError from psycopg2.errors.lookup/1. Thus, the # pragma: nocover --- fixtures/testing.sql | 6 ++++++ sprockets_postgres.py | 16 ++++++++++++---- tests.py | 20 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/fixtures/testing.sql b/fixtures/testing.sql index b6b3ce0..8b683b3 100644 --- a/fixtures/testing.sql +++ b/fixtures/testing.sql @@ -28,3 +28,9 @@ INSERT INTO public.test_rows (toggle) VALUES (FALSE); INSERT INTO public.test_rows (toggle) VALUES (FALSE); INSERT INTO public.test_rows (toggle) VALUES (FALSE); INSERT INTO public.test_rows (toggle) VALUES (FALSE); + +CREATE TABLE public.row_count_no_rows ( + id INTEGER NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + value UUID NOT NULL DEFAULT uuid_generate_v4() +); diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 252fb46..8c2655c 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -258,16 +258,24 @@ class PostgresConnector: async def _query_results(self) -> QueryResult: count, row, rows = self.cursor.rowcount, None, None + + def _on_programming_error(err: psycopg2.ProgrammingError) -> None: + # Should always be empty in this context + if err.pgcode is not None: # pragma: nocover + LOGGER.warning( + 'Unexpected value for ProgrammingError(%s).pgcode: %r', + err, err.pgcode) + if self.cursor.rowcount == 1: try: row = dict(await self.cursor.fetchone()) - except psycopg2.ProgrammingError: - pass + except psycopg2.ProgrammingError as exc: + _on_programming_error(exc) elif self.cursor.rowcount > 1: try: rows = [dict(row) for row in await self.cursor.fetchall()] - except psycopg2.ProgrammingError: - pass + except psycopg2.ProgrammingError as exc: + _on_programming_error(exc) return QueryResult(count, row, rows) diff --git a/tests.py b/tests.py index 64f221c..2542e25 100644 --- a/tests.py +++ b/tests.py @@ -107,6 +107,19 @@ class InfluxDBRequestHandler(ExecuteRequestHandler): self.influxdb.add_field = mock.Mock() +class RowCountNoRowsRequestHandler(RequestHandler): + + GET_SQL = 'INSERT INTO public.row_count_no_rows (value) VALUES (%(value)s)' + + async def get(self): + count = 0 + for iteration in range(0, 5): + result = await self.postgres_execute( + self.GET_SQL, {'value': uuid.uuid4()}) + count += len(result) + await self.finish({'count': count}) + + class MetricsMixinRequestHandler(ExecuteRequestHandler): def __init__(self, *args, **kwargs): @@ -301,6 +314,7 @@ class TestCase(testing.SprocketsHttpTestCase): web.url('/multi-row', MultiRowRequestHandler), web.url('/no-error', NoErrorRequestHandler), web.url('/no-row', NoRowRequestHandler), + web.url('/row-count-no-rows', RowCountNoRowsRequestHandler), web.url('/status', StatusRequestHandler), web.url('/timeout-error', TimeoutErrorRequestHandler), web.url('/transaction', TransactionRequestHandler), @@ -498,6 +512,12 @@ class RequestHandlerMixinTestCase(TestCase): with self.assertRaises(RuntimeError): await asyncio.gather(invoke_cursor(), invoke_cursor()) + def test_row_count_no_rows(self): + response = self.fetch('/row-count-no-rows') + self.assertEqual(response.code, 200) + data = json.loads(response.body) + self.assertEqual(data['count'], 5) + @mock.patch('aiopg.cursor.Cursor.execute') def test_timeout_error_when_overriding_on_postgres_error(self, execute): execute.side_effect = asyncio.TimeoutError