diff --git a/VERSION b/VERSION index 347f583..bc80560 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.4.1 +1.5.0 diff --git a/setup.cfg b/setup.cfg index d63605c..20f299b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,6 +54,7 @@ testing = flake8-rst-docstrings flake8-tuple pygments + tornado-problem-details [coverage:run] branch = True diff --git a/sprockets_postgres.py b/sprockets_postgres.py index d3da35b..48586f2 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -16,6 +16,10 @@ from aiodns import error as aiodns_error from aiopg import pool from psycopg2 import errors, extras from tornado import ioloop, web +try: + import problemdetails +except ImportError: # pragma: nocover + problemdetails = None LOGGER = logging.getLogger('sprockets-postgres') @@ -699,7 +703,13 @@ class RequestHandlerMixin: def on_postgres_error(self, metric_name: str, exc: Exception) -> typing.Optional[Exception]: - """Override for different error handling behaviors + """Invoked when an error occurs when executing a query + + If `tornado-problem-details` is available, + :exc:`problemdetails.Problem` will be raised instead of + :exc:`tornado.web.HTTPError`. + + Override for different error handling behaviors. Return an exception if you would like for it to be raised, or swallow it here. @@ -709,12 +719,25 @@ class RequestHandlerMixin: exc.__class__.__name__, self.__class__.__name__, metric_name, str(exc).split('\n')[0]) if isinstance(exc, ConnectionException): + if problemdetails: + raise problemdetails.Problem( + status_code=503, title='Database Connection Error', + detail=str(exc)) raise web.HTTPError(503, reason='Database Connection Error') elif isinstance(exc, asyncio.TimeoutError): + if problemdetails: + raise problemdetails.Problem( + status_code=500, title='Query Timeout', detail=str(exc)) raise web.HTTPError(500, reason='Query Timeout') elif isinstance(exc, errors.UniqueViolation): + if problemdetails: + raise problemdetails.Problem( + status_code=409, title='Unique Violation', detail=str(exc)) raise web.HTTPError(409, reason='Unique Violation') elif isinstance(exc, psycopg2.Error): + if problemdetails: + raise problemdetails.Problem( + status_code=500, title='Database Error', detail=str(exc)) raise web.HTTPError(500, reason='Database Error') return exc diff --git a/tests.py b/tests.py index a28fb29..c4bdb8a 100644 --- a/tests.py +++ b/tests.py @@ -10,6 +10,7 @@ from urllib import parse import aiopg import asynctest +import problemdetails import psycopg2 import pycares from asynctest import mock @@ -99,6 +100,11 @@ class ExecuteRequestHandler(RequestHandler): 'value': result.row['value'] if result.row else None}) +class ProblemDetailsExecuteRequestHandler( + problemdetails.ErrorWriter, ExecuteRequestHandler): + pass + + class InfluxDBRequestHandler(ExecuteRequestHandler): def __init__(self, *args, **kwargs): @@ -305,6 +311,7 @@ class TestCase(testing.SprocketsHttpTestCase): web.url('/multi-row', MultiRowRequestHandler), web.url('/no-error', NoErrorRequestHandler), web.url('/no-row', NoRowRequestHandler), + web.url('/pdexecute', ProblemDetailsExecuteRequestHandler), web.url('/row-count-no-rows', RowCountNoRowsRequestHandler), web.url('/status', sprockets_postgres.StatusRequestHandler), web.url('/timeout-error', TimeoutErrorRequestHandler), @@ -425,40 +432,6 @@ class RequestHandlerMixinTestCase(TestCase): self.assertEqual(body['count'], 0) self.assertListEqual(body['rows'], []) - @mock.patch('aiopg.cursor.Cursor.execute') - def test_postgres_execute_timeout_error(self, execute): - execute.side_effect = asyncio.TimeoutError() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 500) - self.assertIn(b'Query Timeout', response.body) - - @mock.patch('aiopg.cursor.Cursor.execute') - def test_postgres_execute_unique_violation(self, execute): - execute.side_effect = errors.UniqueViolation() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 409) - self.assertIn(b'Unique Violation', response.body) - - @mock.patch('aiopg.cursor.Cursor.execute') - def test_postgres_execute_error(self, execute): - execute.side_effect = psycopg2.Error() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 500) - self.assertIn(b'Database Error', response.body) - - @mock.patch('aiopg.cursor.Cursor.fetchone') - def test_postgres_programming_error(self, fetchone): - fetchone.side_effect = psycopg2.ProgrammingError() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 200) - self.assertIsNone(json.loads(response.body)['value']) - - @mock.patch('aiopg.connection.Connection.cursor') - def test_postgres_cursor_raises(self, cursor): - cursor.side_effect = asyncio.TimeoutError() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 503) - def test_postgres_cursor_operational_error_reconnects(self): original = aiopg.connection.Connection.cursor @@ -524,6 +497,83 @@ class RequestHandlerMixinTestCase(TestCase): response = self.fetch('/unhandled-exception') self.assertEqual(response.code, 422) + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_timeout_error(self, execute): + execute.side_effect = asyncio.TimeoutError() + response = self.fetch('/pdexecute?value=1') + self.assertEqual(response.code, 500) + problem = json.loads(response.body) + self.assertEqual(problem['title'], 'Query Timeout') + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_unique_violation(self, execute): + execute.side_effect = errors.UniqueViolation() + response = self.fetch('/pdexecute?value=1') + self.assertEqual(response.code, 409) + problem = json.loads(response.body) + self.assertEqual(problem['title'], 'Unique Violation') + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_error(self, execute): + execute.side_effect = psycopg2.Error() + response = self.fetch('/pdexecute?value=1') + self.assertEqual(response.code, 500) + problem = json.loads(response.body) + self.assertEqual(problem['title'], 'Database Error') + + @mock.patch('aiopg.cursor.Cursor.fetchone') + def test_postgres_programming_error(self, fetchone): + fetchone.side_effect = psycopg2.ProgrammingError() + response = self.fetch('/pdexecute?value=1') + self.assertEqual(response.code, 200) + self.assertIsNone(json.loads(response.body)['value']) + + +class RequestHandlerMixinHTTPErrorTestCase(TestCase): + + def setUp(self): + super().setUp() + self._problemdetails = sprockets_postgres.problemdetails + sprockets_postgres.problemdetails = None + + def tearDown(self): + sprockets_postgres.problemdetails = self._problemdetails + super().tearDown() + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_timeout_error(self, execute): + execute.side_effect = asyncio.TimeoutError() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 500) + self.assertIn(b'Query Timeout', response.body) + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_unique_violation(self, execute): + execute.side_effect = errors.UniqueViolation() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 409) + self.assertIn(b'Unique Violation', response.body) + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_error(self, execute): + execute.side_effect = psycopg2.Error() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 500) + self.assertIn(b'Database Error', response.body) + + @mock.patch('aiopg.cursor.Cursor.fetchone') + def test_postgres_programming_error(self, fetchone): + fetchone.side_effect = psycopg2.ProgrammingError() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 200) + self.assertIsNone(json.loads(response.body)['value']) + + @mock.patch('aiopg.connection.Connection.cursor') + def test_postgres_cursor_raises(self, cursor): + cursor.side_effect = asyncio.TimeoutError() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 503) + class TransactionTestCase(TestCase):