diff --git a/VERSION b/VERSION index 347f583..bc80560 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.4.1 +1.5.0 diff --git a/setup.cfg b/setup.cfg index d63605c..20f299b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,6 +54,7 @@ testing = flake8-rst-docstrings flake8-tuple pygments + tornado-problem-details [coverage:run] branch = True diff --git a/sprockets_postgres.py b/sprockets_postgres.py index d3da35b..0c253ec 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -16,6 +16,10 @@ from aiodns import error as aiodns_error from aiopg import pool from psycopg2 import errors, extras from tornado import ioloop, web +try: + import problemdetails +except ImportError: # pragma: nocover + problemdetails = None LOGGER = logging.getLogger('sprockets-postgres') @@ -369,7 +373,10 @@ class ApplicationMixin: _attempt + 1) as connector: yield connector return - exc = on_error('postgres_connector', ConnectionException(str(err))) + if on_error is None: + raise ConnectionException(str(err)) + exc = on_error( + 'postgres_connector', ConnectionException(str(err))) if exc: raise exc else: # postgres_status.on_error does not return an exception @@ -699,7 +706,13 @@ class RequestHandlerMixin: def on_postgres_error(self, metric_name: str, exc: Exception) -> typing.Optional[Exception]: - """Override for different error handling behaviors + """Invoked when an error occurs when executing a query + + If `tornado-problem-details` is available, + :exc:`problemdetails.Problem` will be raised instead of + :exc:`tornado.web.HTTPError`. + + Override for different error handling behaviors. Return an exception if you would like for it to be raised, or swallow it here. @@ -709,12 +722,24 @@ class RequestHandlerMixin: exc.__class__.__name__, self.__class__.__name__, metric_name, str(exc).split('\n')[0]) if isinstance(exc, ConnectionException): + if problemdetails: + raise problemdetails.Problem( + status_code=503, title='Database Connection Error') raise web.HTTPError(503, reason='Database Connection Error') elif isinstance(exc, asyncio.TimeoutError): + if problemdetails: + raise problemdetails.Problem( + status_code=500, title='Query Timeout') raise web.HTTPError(500, reason='Query Timeout') elif isinstance(exc, errors.UniqueViolation): + if problemdetails: + raise problemdetails.Problem( + status_code=409, title='Unique Violation') raise web.HTTPError(409, reason='Unique Violation') elif isinstance(exc, psycopg2.Error): + if problemdetails: + raise problemdetails.Problem( + status_code=500, title='Database Error') raise web.HTTPError(500, reason='Database Error') return exc diff --git a/tests.py b/tests.py index a28fb29..56b6152 100644 --- a/tests.py +++ b/tests.py @@ -10,6 +10,7 @@ from urllib import parse import aiopg import asynctest +import problemdetails import psycopg2 import pycares from asynctest import mock @@ -99,6 +100,11 @@ class ExecuteRequestHandler(RequestHandler): 'value': result.row['value'] if result.row else None}) +class ProblemDetailsExecuteRequestHandler( + problemdetails.ErrorWriter, ExecuteRequestHandler): + pass + + class InfluxDBRequestHandler(ExecuteRequestHandler): def __init__(self, *args, **kwargs): @@ -162,6 +168,18 @@ class NoErrorRequestHandler(ErrorRequestHandler): return None +class NoMixinRequestHandler(web.RequestHandler): + + async def get(self): + try: + async with self.application.postgres_connector() as connector: + await connector.execute('SELECT 1', {}) + except sprockets_postgres.ConnectionException: + raise web.HTTPError(503) + else: + raise web.HTTPError(418) + + class NoRowRequestHandler(RequestHandler): GET_SQL = """\ @@ -251,7 +269,7 @@ class UnhandledExceptionRequestHandler(RequestHandler): await self.postgres_execute(self.GET_SQL) except psycopg2.DataError: raise web.HTTPError(422) - raise web.HTTPError(500, 'This should have failed') + raise web.HTTPError(418, 'This should have failed') def on_postgres_error(self, metric_name: str, @@ -265,6 +283,14 @@ class UnhandledExceptionRequestHandler(RequestHandler): return exc +class DontRaiseRequestHandler(UnhandledExceptionRequestHandler): + + def on_postgres_error(self, + metric_name: str, + exc: Exception) -> typing.Optional[Exception]: + return None + + class Application(sprockets_postgres.ApplicationMixin, app.Application): @@ -297,6 +323,7 @@ class TestCase(testing.SprocketsHttpTestCase): self.app = Application(handlers=[ web.url('/callproc', CallprocRequestHandler), web.url('/count', CountRequestHandler), + web.url('/dont-raise', DontRaiseRequestHandler), web.url('/error', ErrorRequestHandler), web.url('/error-passthrough', ErrorPassthroughRequestHandler), web.url('/execute', ExecuteRequestHandler), @@ -304,7 +331,9 @@ class TestCase(testing.SprocketsHttpTestCase): web.url('/metrics-mixin', MetricsMixinRequestHandler), web.url('/multi-row', MultiRowRequestHandler), web.url('/no-error', NoErrorRequestHandler), + web.url('/no-mixin', NoMixinRequestHandler), web.url('/no-row', NoRowRequestHandler), + web.url('/pdexecute', ProblemDetailsExecuteRequestHandler), web.url('/row-count-no-rows', RowCountNoRowsRequestHandler), web.url('/status', sprockets_postgres.StatusRequestHandler), web.url('/timeout-error', TimeoutErrorRequestHandler), @@ -425,40 +454,6 @@ class RequestHandlerMixinTestCase(TestCase): self.assertEqual(body['count'], 0) self.assertListEqual(body['rows'], []) - @mock.patch('aiopg.cursor.Cursor.execute') - def test_postgres_execute_timeout_error(self, execute): - execute.side_effect = asyncio.TimeoutError() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 500) - self.assertIn(b'Query Timeout', response.body) - - @mock.patch('aiopg.cursor.Cursor.execute') - def test_postgres_execute_unique_violation(self, execute): - execute.side_effect = errors.UniqueViolation() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 409) - self.assertIn(b'Unique Violation', response.body) - - @mock.patch('aiopg.cursor.Cursor.execute') - def test_postgres_execute_error(self, execute): - execute.side_effect = psycopg2.Error() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 500) - self.assertIn(b'Database Error', response.body) - - @mock.patch('aiopg.cursor.Cursor.fetchone') - def test_postgres_programming_error(self, fetchone): - fetchone.side_effect = psycopg2.ProgrammingError() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 200) - self.assertIsNone(json.loads(response.body)['value']) - - @mock.patch('aiopg.connection.Connection.cursor') - def test_postgres_cursor_raises(self, cursor): - cursor.side_effect = asyncio.TimeoutError() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 503) - def test_postgres_cursor_operational_error_reconnects(self): original = aiopg.connection.Connection.cursor @@ -524,6 +519,102 @@ class RequestHandlerMixinTestCase(TestCase): response = self.fetch('/unhandled-exception') self.assertEqual(response.code, 422) + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_timeout_error(self, execute): + execute.side_effect = asyncio.TimeoutError() + response = self.fetch('/pdexecute?value=1') + self.assertEqual(response.code, 500) + problem = json.loads(response.body) + self.assertEqual(problem['title'], 'Query Timeout') + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_unique_violation(self, execute): + execute.side_effect = errors.UniqueViolation() + response = self.fetch('/pdexecute?value=1') + self.assertEqual(response.code, 409) + problem = json.loads(response.body) + self.assertEqual(problem['title'], 'Unique Violation') + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_error(self, execute): + execute.side_effect = psycopg2.Error() + response = self.fetch('/pdexecute?value=1') + self.assertEqual(response.code, 500) + problem = json.loads(response.body) + self.assertEqual(problem['title'], 'Database Error') + + @mock.patch('aiopg.cursor.Cursor.fetchone') + def test_postgres_programming_error(self, fetchone): + fetchone.side_effect = psycopg2.ProgrammingError() + response = self.fetch('/pdexecute?value=1') + self.assertEqual(response.code, 200) + self.assertIsNone(json.loads(response.body)['value']) + + +class HTTPErrorTestCase(TestCase): + + def setUp(self): + super().setUp() + self._problemdetails = sprockets_postgres.problemdetails + sprockets_postgres.problemdetails = None + + def tearDown(self): + sprockets_postgres.problemdetails = self._problemdetails + super().tearDown() + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_timeout_error(self, execute): + execute.side_effect = asyncio.TimeoutError() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 500) + self.assertIn(b'Query Timeout', response.body) + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_unique_violation(self, execute): + execute.side_effect = errors.UniqueViolation() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 409) + self.assertIn(b'Unique Violation', response.body) + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_execute_error(self, execute): + execute.side_effect = psycopg2.Error() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 500) + self.assertIn(b'Database Error', response.body) + + @mock.patch('aiopg.cursor.Cursor.fetchone') + def test_postgres_programming_error(self, fetchone): + fetchone.side_effect = psycopg2.ProgrammingError() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 200) + self.assertIsNone(json.loads(response.body)['value']) + + @mock.patch('aiopg.connection.Connection.cursor') + def test_postgres_cursor_raises(self, cursor): + cursor.side_effect = asyncio.TimeoutError() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 503) + + def test_on_error_no_exception_branch(self): + response = self.fetch('/dont-raise') + self.assertEqual(response.code, 418) + + +class NoMixinTestCase(TestCase): + + @mock.patch('aiopg.cursor.Cursor.execute') + def test_postgres_cursor_raises(self, execute): + execute.side_effect = psycopg2.ProgrammingError() + response = self.fetch('/no-mixin') + self.assertEqual(response.code, 500) + + @mock.patch('aiopg.pool.Pool.acquire') + def test_postgres_status_connect_error(self, acquire): + acquire.side_effect = asyncio.TimeoutError() + response = self.fetch('/no-mixin') + self.assertEqual(response.code, 503) + class TransactionTestCase(TestCase):