Merge pull request #7 from gmr/master

Add tornado-problem-details support and fix an issue when omitting an on_error callback
This commit is contained in:
Andrew Rabert 2020-09-17 12:06:44 -04:00 committed by GitHub
commit 6c2e26ef0f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 155 additions and 38 deletions

View file

@ -1 +1 @@
1.4.1 1.5.0

View file

@ -54,6 +54,7 @@ testing =
flake8-rst-docstrings flake8-rst-docstrings
flake8-tuple flake8-tuple
pygments pygments
tornado-problem-details
[coverage:run] [coverage:run]
branch = True branch = True

View file

@ -16,6 +16,10 @@ from aiodns import error as aiodns_error
from aiopg import pool from aiopg import pool
from psycopg2 import errors, extras from psycopg2 import errors, extras
from tornado import ioloop, web from tornado import ioloop, web
try:
import problemdetails
except ImportError: # pragma: nocover
problemdetails = None
LOGGER = logging.getLogger('sprockets-postgres') LOGGER = logging.getLogger('sprockets-postgres')
@ -369,7 +373,10 @@ class ApplicationMixin:
_attempt + 1) as connector: _attempt + 1) as connector:
yield connector yield connector
return return
exc = on_error('postgres_connector', ConnectionException(str(err))) if on_error is None:
raise ConnectionException(str(err))
exc = on_error(
'postgres_connector', ConnectionException(str(err)))
if exc: if exc:
raise exc raise exc
else: # postgres_status.on_error does not return an exception else: # postgres_status.on_error does not return an exception
@ -699,7 +706,13 @@ class RequestHandlerMixin:
def on_postgres_error(self, def on_postgres_error(self,
metric_name: str, metric_name: str,
exc: Exception) -> typing.Optional[Exception]: exc: Exception) -> typing.Optional[Exception]:
"""Override for different error handling behaviors """Invoked when an error occurs when executing a query
If `tornado-problem-details` is available,
:exc:`problemdetails.Problem` will be raised instead of
:exc:`tornado.web.HTTPError`.
Override for different error handling behaviors.
Return an exception if you would like for it to be raised, or swallow Return an exception if you would like for it to be raised, or swallow
it here. it here.
@ -709,12 +722,24 @@ class RequestHandlerMixin:
exc.__class__.__name__, self.__class__.__name__, exc.__class__.__name__, self.__class__.__name__,
metric_name, str(exc).split('\n')[0]) metric_name, str(exc).split('\n')[0])
if isinstance(exc, ConnectionException): if isinstance(exc, ConnectionException):
if problemdetails:
raise problemdetails.Problem(
status_code=503, title='Database Connection Error')
raise web.HTTPError(503, reason='Database Connection Error') raise web.HTTPError(503, reason='Database Connection Error')
elif isinstance(exc, asyncio.TimeoutError): elif isinstance(exc, asyncio.TimeoutError):
if problemdetails:
raise problemdetails.Problem(
status_code=500, title='Query Timeout')
raise web.HTTPError(500, reason='Query Timeout') raise web.HTTPError(500, reason='Query Timeout')
elif isinstance(exc, errors.UniqueViolation): elif isinstance(exc, errors.UniqueViolation):
if problemdetails:
raise problemdetails.Problem(
status_code=409, title='Unique Violation')
raise web.HTTPError(409, reason='Unique Violation') raise web.HTTPError(409, reason='Unique Violation')
elif isinstance(exc, psycopg2.Error): elif isinstance(exc, psycopg2.Error):
if problemdetails:
raise problemdetails.Problem(
status_code=500, title='Database Error')
raise web.HTTPError(500, reason='Database Error') raise web.HTTPError(500, reason='Database Error')
return exc return exc

161
tests.py
View file

@ -10,6 +10,7 @@ from urllib import parse
import aiopg import aiopg
import asynctest import asynctest
import problemdetails
import psycopg2 import psycopg2
import pycares import pycares
from asynctest import mock from asynctest import mock
@ -99,6 +100,11 @@ class ExecuteRequestHandler(RequestHandler):
'value': result.row['value'] if result.row else None}) 'value': result.row['value'] if result.row else None})
class ProblemDetailsExecuteRequestHandler(
problemdetails.ErrorWriter, ExecuteRequestHandler):
pass
class InfluxDBRequestHandler(ExecuteRequestHandler): class InfluxDBRequestHandler(ExecuteRequestHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -162,6 +168,18 @@ class NoErrorRequestHandler(ErrorRequestHandler):
return None return None
class NoMixinRequestHandler(web.RequestHandler):
async def get(self):
try:
async with self.application.postgres_connector() as connector:
await connector.execute('SELECT 1', {})
except sprockets_postgres.ConnectionException:
raise web.HTTPError(503)
else:
raise web.HTTPError(418)
class NoRowRequestHandler(RequestHandler): class NoRowRequestHandler(RequestHandler):
GET_SQL = """\ GET_SQL = """\
@ -251,7 +269,7 @@ class UnhandledExceptionRequestHandler(RequestHandler):
await self.postgres_execute(self.GET_SQL) await self.postgres_execute(self.GET_SQL)
except psycopg2.DataError: except psycopg2.DataError:
raise web.HTTPError(422) raise web.HTTPError(422)
raise web.HTTPError(500, 'This should have failed') raise web.HTTPError(418, 'This should have failed')
def on_postgres_error(self, def on_postgres_error(self,
metric_name: str, metric_name: str,
@ -265,6 +283,14 @@ class UnhandledExceptionRequestHandler(RequestHandler):
return exc return exc
class DontRaiseRequestHandler(UnhandledExceptionRequestHandler):
def on_postgres_error(self,
metric_name: str,
exc: Exception) -> typing.Optional[Exception]:
return None
class Application(sprockets_postgres.ApplicationMixin, class Application(sprockets_postgres.ApplicationMixin,
app.Application): app.Application):
@ -297,6 +323,7 @@ class TestCase(testing.SprocketsHttpTestCase):
self.app = Application(handlers=[ self.app = Application(handlers=[
web.url('/callproc', CallprocRequestHandler), web.url('/callproc', CallprocRequestHandler),
web.url('/count', CountRequestHandler), web.url('/count', CountRequestHandler),
web.url('/dont-raise', DontRaiseRequestHandler),
web.url('/error', ErrorRequestHandler), web.url('/error', ErrorRequestHandler),
web.url('/error-passthrough', ErrorPassthroughRequestHandler), web.url('/error-passthrough', ErrorPassthroughRequestHandler),
web.url('/execute', ExecuteRequestHandler), web.url('/execute', ExecuteRequestHandler),
@ -304,7 +331,9 @@ class TestCase(testing.SprocketsHttpTestCase):
web.url('/metrics-mixin', MetricsMixinRequestHandler), web.url('/metrics-mixin', MetricsMixinRequestHandler),
web.url('/multi-row', MultiRowRequestHandler), web.url('/multi-row', MultiRowRequestHandler),
web.url('/no-error', NoErrorRequestHandler), web.url('/no-error', NoErrorRequestHandler),
web.url('/no-mixin', NoMixinRequestHandler),
web.url('/no-row', NoRowRequestHandler), web.url('/no-row', NoRowRequestHandler),
web.url('/pdexecute', ProblemDetailsExecuteRequestHandler),
web.url('/row-count-no-rows', RowCountNoRowsRequestHandler), web.url('/row-count-no-rows', RowCountNoRowsRequestHandler),
web.url('/status', sprockets_postgres.StatusRequestHandler), web.url('/status', sprockets_postgres.StatusRequestHandler),
web.url('/timeout-error', TimeoutErrorRequestHandler), web.url('/timeout-error', TimeoutErrorRequestHandler),
@ -425,40 +454,6 @@ class RequestHandlerMixinTestCase(TestCase):
self.assertEqual(body['count'], 0) self.assertEqual(body['count'], 0)
self.assertListEqual(body['rows'], []) self.assertListEqual(body['rows'], [])
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_timeout_error(self, execute):
execute.side_effect = asyncio.TimeoutError()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 500)
self.assertIn(b'Query Timeout', response.body)
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_unique_violation(self, execute):
execute.side_effect = errors.UniqueViolation()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 409)
self.assertIn(b'Unique Violation', response.body)
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_error(self, execute):
execute.side_effect = psycopg2.Error()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 500)
self.assertIn(b'Database Error', response.body)
@mock.patch('aiopg.cursor.Cursor.fetchone')
def test_postgres_programming_error(self, fetchone):
fetchone.side_effect = psycopg2.ProgrammingError()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 200)
self.assertIsNone(json.loads(response.body)['value'])
@mock.patch('aiopg.connection.Connection.cursor')
def test_postgres_cursor_raises(self, cursor):
cursor.side_effect = asyncio.TimeoutError()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 503)
def test_postgres_cursor_operational_error_reconnects(self): def test_postgres_cursor_operational_error_reconnects(self):
original = aiopg.connection.Connection.cursor original = aiopg.connection.Connection.cursor
@ -524,6 +519,102 @@ class RequestHandlerMixinTestCase(TestCase):
response = self.fetch('/unhandled-exception') response = self.fetch('/unhandled-exception')
self.assertEqual(response.code, 422) self.assertEqual(response.code, 422)
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_timeout_error(self, execute):
execute.side_effect = asyncio.TimeoutError()
response = self.fetch('/pdexecute?value=1')
self.assertEqual(response.code, 500)
problem = json.loads(response.body)
self.assertEqual(problem['title'], 'Query Timeout')
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_unique_violation(self, execute):
execute.side_effect = errors.UniqueViolation()
response = self.fetch('/pdexecute?value=1')
self.assertEqual(response.code, 409)
problem = json.loads(response.body)
self.assertEqual(problem['title'], 'Unique Violation')
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_error(self, execute):
execute.side_effect = psycopg2.Error()
response = self.fetch('/pdexecute?value=1')
self.assertEqual(response.code, 500)
problem = json.loads(response.body)
self.assertEqual(problem['title'], 'Database Error')
@mock.patch('aiopg.cursor.Cursor.fetchone')
def test_postgres_programming_error(self, fetchone):
fetchone.side_effect = psycopg2.ProgrammingError()
response = self.fetch('/pdexecute?value=1')
self.assertEqual(response.code, 200)
self.assertIsNone(json.loads(response.body)['value'])
class HTTPErrorTestCase(TestCase):
def setUp(self):
super().setUp()
self._problemdetails = sprockets_postgres.problemdetails
sprockets_postgres.problemdetails = None
def tearDown(self):
sprockets_postgres.problemdetails = self._problemdetails
super().tearDown()
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_timeout_error(self, execute):
execute.side_effect = asyncio.TimeoutError()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 500)
self.assertIn(b'Query Timeout', response.body)
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_unique_violation(self, execute):
execute.side_effect = errors.UniqueViolation()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 409)
self.assertIn(b'Unique Violation', response.body)
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_execute_error(self, execute):
execute.side_effect = psycopg2.Error()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 500)
self.assertIn(b'Database Error', response.body)
@mock.patch('aiopg.cursor.Cursor.fetchone')
def test_postgres_programming_error(self, fetchone):
fetchone.side_effect = psycopg2.ProgrammingError()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 200)
self.assertIsNone(json.loads(response.body)['value'])
@mock.patch('aiopg.connection.Connection.cursor')
def test_postgres_cursor_raises(self, cursor):
cursor.side_effect = asyncio.TimeoutError()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 503)
def test_on_error_no_exception_branch(self):
response = self.fetch('/dont-raise')
self.assertEqual(response.code, 418)
class NoMixinTestCase(TestCase):
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_cursor_raises(self, execute):
execute.side_effect = psycopg2.ProgrammingError()
response = self.fetch('/no-mixin')
self.assertEqual(response.code, 500)
@mock.patch('aiopg.pool.Pool.acquire')
def test_postgres_status_connect_error(self, acquire):
acquire.side_effect = asyncio.TimeoutError()
response = self.fetch('/no-mixin')
self.assertEqual(response.code, 503)
class TransactionTestCase(TestCase): class TransactionTestCase(TestCase):