diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 2c04970..9618843 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -1,24 +1,16 @@ -""" -:class:`sprockets.http.app.Application` mixin for handling the connection to -Postgres and exporting functions for querying the database, getting the status, -and proving a cursor. - -Automatically creates and shuts down :class:`aio.pool.Pool` on startup -and shutdown. - -""" import asyncio import contextlib -import functools +import dataclasses import logging import os +import time import typing from distutils import util import aiopg import psycopg2 from aiopg import pool -from psycopg2 import errors, extensions, extras +from psycopg2 import errors, extras from tornado import ioloop, web LOGGER = logging.getLogger('sprockets-postgres') @@ -30,13 +22,115 @@ DEFAULT_POSTGRES_JSON = 'FALSE' DEFAULT_POSTGRES_MAX_POOL_SIZE = 0 DEFAULT_POSTGRES_MIN_POOL_SIZE = 1 DEFAULT_POSTGRES_QUERY_TIMEOUT = 120 -DEFAULT_POSTGRES_URL = 'postgresql://localhost:5432' DEFAULT_POSTGRES_UUID = 'TRUE' +QueryParameters = typing.Union[list, tuple, None] +Timeout = typing.Union[int, float, None] + + +@dataclasses.dataclass +class QueryResult: + row_count: int + row: typing.Optional[dict] + rows: typing.Optional[typing.List[dict]] + + +class PostgresConnector: + + def __init__(self, + cursor: aiopg.Cursor, + on_error: typing.Callable, + record_duration: typing.Optional[typing.Callable] = None, + timeout: Timeout = None): + self.cursor = cursor + self._on_error = on_error + self._record_duration = record_duration + self._timeout = timeout or int( + os.environ.get( + 'POSTGRES_QUERY_TIMEOUT', + DEFAULT_POSTGRES_QUERY_TIMEOUT)) + + async def callproc(self, + name: str, + parameters: QueryParameters = None, + metric_name: str = '', + *, + timeout: Timeout = None) -> QueryResult: + return await self._query( + self.cursor.callproc, + metric_name, + procname=name, + parameters=parameters, + timeout=timeout) + + async def execute(self, + sql: str, + parameters: QueryParameters = None, + metric_name: str = '', + *, + timeout: Timeout = None) -> QueryResult: + return await self._query( + self.cursor.execute, + metric_name, + operation=sql, + parameters=parameters, + timeout=timeout) + + @contextlib.asynccontextmanager + async def transaction(self) \ + -> typing.AsyncContextManager['PostgresConnector']: + async with self.cursor.begin(): + yield self + + async def _query(self, + method: typing.Callable, + metric_name: str, + **kwargs): + if kwargs['timeout'] is None: + kwargs['timeout'] = self._timeout + start_time = time.monotonic() + try: + await method(**kwargs) + except (asyncio.TimeoutError, psycopg2.Error) as err: + LOGGER.error('Caught %r', err) + exc = self._on_error(metric_name, err) + if exc: + raise exc + finally: + if self._record_duration: + self._record_duration( + metric_name, time.monotonic() - start_time) + return await self._query_results() + + async def _query_results(self) -> QueryResult: + row, rows = None, None + if self.cursor.rowcount == 1: + try: + row = dict(await self.cursor.fetchone()) + except psycopg2.ProgrammingError: + pass + elif self.cursor.rowcount > 1: + try: + rows = [dict(row) for row in await self.cursor.fetchall()] + except psycopg2.ProgrammingError: + pass + return QueryResult(self.cursor.rowcount, row, rows) + + +class ConnectionException(Exception): + """Raised when the connection to Postgres can not be established""" + class ApplicationMixin: - """Application mixin for setting up the PostgreSQL client pool""" + """ + :class:`sprockets.http.app.Application` mixin for handling the connection + to Postgres and exporting functions for querying the database, + getting the status, and proving a cursor. + Automatically creates and shuts down :class:`aio.pool.Pool` on startup + and shutdown. + + """ POSTGRES_STATUS_TIMEOUT = 3 def __init__(self, *args, **kwargs): @@ -46,155 +140,57 @@ class ApplicationMixin: self.runner_callbacks['shutdown'].append(self._postgres_shutdown) @contextlib.asynccontextmanager - async def postgres_cursor(self, - timeout: typing.Optional[int] = None, - raise_http_error: bool = True) \ - -> typing.AsyncContextManager[extensions.cursor]: - """Return a Postgres cursor for the pool""" + async def postgres_connector(self, + on_error: typing.Callable, + record_duration: typing.Optional[ + typing.Callable] = None, + timeout: Timeout = None) \ + -> typing.AsyncContextManager[PostgresConnector]: try: async with self._postgres_pool.acquire() as conn: async with conn.cursor( cursor_factory=extras.RealDictCursor, - timeout=self._postgres_query_timeout(timeout)) as pgc: - yield pgc - except (asyncio.TimeoutError, - psycopg2.OperationalError, - psycopg2.Error) as error: - LOGGER.critical('Error connecting to Postgres: %s', error) - if raise_http_error: - raise web.HTTPError(503, 'Database Unavailable') - raise - - async def postgres_callproc(self, - name: str, - params: typing.Union[list, tuple, None] = None, - timeout: typing.Optional[int] = None) \ - -> typing.Union[dict, list, None]: - """Execute a stored procedure, specifying the name, SQL, passing in - optional parameters. - - :param name: The stored-proc to call - :param params: Optional parameters to pass into the function - :param timeout: Optional timeout to override the default query timeout - - """ - async with self.postgres_cursor(timeout) as cursor: - return await self._postgres_query( - cursor, cursor.callproc, name, name, params) - - async def postgres_execute(self, name: str, sql: str, - *args, - timeout: typing.Optional[int] = None) \ - -> typing.Union[dict, list, None]: - """Execute a query, specifying a name for the query, the SQL statement, - and optional positional arguments to pass in with the query. - - Parameters may be provided as sequence or mapping and will be - bound to variables in the operation. Variables are specified - either with positional ``%s`` or named ``%({name})s`` placeholders. - - :param name: The stored-proc to call - :param sql: The SQL statement to execute - :param timeout: Optional timeout to override the default query timeout - - """ - async with self.postgres_cursor(timeout) as cursor: - return await self._postgres_query( - cursor, cursor.execute, name, sql, args) + timeout=timeout) as cursor: + yield PostgresConnector( + cursor, on_error, record_duration, timeout) + except (asyncio.TimeoutError, psycopg2.Error) as err: + on_error('postgres_connector', ConnectionException(str(err))) async def postgres_status(self) -> dict: """Invoke from the ``/status`` RequestHandler to check that there is a Postgres connection handler available and return info about the pool. - """ - available = True - try: - async with self.postgres_cursor( - self.POSTGRES_STATUS_TIMEOUT, False) as cursor: - await cursor.execute('SELECT 1') - except (asyncio.TimeoutError, psycopg2.OperationalError): - available = False + """ + query_error = asyncio.Event() + + def on_error(_metric_name, _exc) -> None: + query_error.set() + return None + + async with self.postgres_connector( + on_error, + timeout=self.POSTGRES_STATUS_TIMEOUT) as connector: + await connector.execute('SELECT 1') return { - 'available': available, + 'available': not query_error.is_set(), 'pool_size': self._postgres_pool.size, 'pool_free': self._postgres_pool.freesize } - async def _postgres_query(self, - cursor: aiopg.Cursor, - method: typing.Callable, - name: str, - sql: str, - parameters: typing.Union[dict, list, tuple]) \ - -> typing.Union[dict, list, None]: - """Execute a query, specifying the name, SQL, passing in - - """ - try: - await method(sql, parameters) - except asyncio.TimeoutError as err: - LOGGER.error('Query timeout for %s: %s', - name, str(err).split('\n')[0]) - raise web.HTTPError(500, reason='Query Timeout') - except errors.UniqueViolation as err: - LOGGER.error('Database error for %s: %s', - name, str(err).split('\n')[0]) - raise web.HTTPError(409, reason='Unique Violation') - except psycopg2.Error as err: - LOGGER.error('Database error for %s: %s', - name, str(err).split('\n')[0]) - raise web.HTTPError(500, reason='Database Error') - try: - return await self._postgres_query_results(cursor) - except psycopg2.ProgrammingError: - return - - @staticmethod - async def _postgres_query_results(cursor: aiopg.Cursor) \ - -> typing.Union[dict, list, None]: - """Invoked by self.postgres_query to return all of the query results - as either a ``dict`` or ``list`` depending on the quantity of rows. - - This can raise a ``psycopg2.ProgrammingError`` for an INSERT/UPDATE - without RETURNING or a DELETE. That exception is caught by the caller. - - :raises psycopg2.ProgrammingError: when there are no rows to fetch - even though the rowcount is > 0 - - """ - if cursor.rowcount == 1: - return await cursor.fetchone() - elif cursor.rowcount > 1: - return await cursor.fetchall() - return None - - @functools.lru_cache(2) - def _postgres_query_timeout(self, - timeout: typing.Optional[int] = None) -> int: - """Return query timeout, either from the specified value or - ``POSTGRES_QUERY_TIMEOUT`` environment variable, if set. - - Defaults to sprockets_postgres.DEFAULT_POSTGRES_QUERY_TIMEOUT. - - """ - return timeout if timeout else int( - os.environ.get( - 'POSTGRES_QUERY_TIMEOUT', - DEFAULT_POSTGRES_QUERY_TIMEOUT)) - async def _postgres_setup(self, _app: web.Application, - _ioloop: ioloop.IOLoop) -> None: + loop: ioloop.IOLoop) -> None: """Setup the Postgres pool of connections and log if there is an error. This is invoked by the Application on start callback mechanism. """ - url = os.environ.get('POSTGRES_URL', DEFAULT_POSTGRES_URL) - LOGGER.debug('Connecting to PostgreSQL: %s', url) + if 'POSTGRES_URL' not in os.environ: + LOGGER.critical('Missing POSTGRES_URL environment variable') + return self.stop(loop) self._postgres_pool = pool.Pool( - url, + os.environ['POSTGRES_URL'], minsize=int( os.environ.get( 'POSTGRES_MIN_POOL_SIZE', @@ -236,3 +232,90 @@ class ApplicationMixin: """ self._postgres_pool.close() await self._postgres_pool.wait_closed() + + +class RequestHandlerMixin: + """ + RequestHandler mixin class exposing functions for querying the database, + recording the duration to either `sprockets-influxdb` or + `sprockets.mixins.metrics`, and handling exceptions. + + """ + async def postgres_callproc(self, + name: str, + parameters: QueryParameters = None, + metric_name: str = '', + *, + timeout: Timeout = None) -> QueryResult: + async with self._postgres_connector(timeout) as connector: + return await connector.callproc( + name, parameters, metric_name, timeout=timeout) + + async def postgres_execute(self, + sql: str, + parameters: QueryParameters = None, + metric_name: str = '', + *, + timeout: Timeout = None) -> QueryResult: + """Execute a query, specifying a name for the query, the SQL statement, + and optional positional arguments to pass in with the query. + + Parameters may be provided as sequence or mapping and will be + bound to variables in the operation. Variables are specified + either with positional ``%s`` or named ``%({name})s`` placeholders. + + """ + async with self._postgres_connector(timeout) as connector: + return await connector.execute( + sql, parameters, metric_name, timeout=timeout) + + @contextlib.asynccontextmanager + async def postgres_transaction(self, timeout: Timeout = None) \ + -> typing.AsyncContextManager[PostgresConnector]: + """Yields a :class:`PostgresConnector` instance in a transaction. + Will automatically commit or rollback based upon exception. + + """ + async with self._postgres_connector(timeout) as connector: + async with connector.transaction(): + yield connector + + @contextlib.asynccontextmanager + async def _postgres_connector(self, timeout: Timeout = None) \ + -> typing.AsyncContextManager[PostgresConnector]: + async with self.application.postgres_connector( + self.__on_postgres_error, + self.__on_postgres_timing, + timeout) as connector: + yield connector + + def __on_postgres_error(self, + metric_name: str, + exc: Exception) -> typing.Optional[Exception]: + """Override for different error handling behaviors""" + LOGGER.error('%s in %s for %s (%s)', + exc.__class__.__name__, + self.__class__.__name__, + metric_name, + str(exc).split('\n')[0]) + if isinstance(exc, ConnectionException): + raise web.HTTPError(503, reason='Database Connection Error') + elif isinstance(exc, asyncio.TimeoutError): + raise web.HTTPError(500, reason='Query Timeout') + elif isinstance(exc, errors.UniqueViolation): + raise web.HTTPError(409, reason='Unique Violation') + elif isinstance(exc, psycopg2.Error): + raise web.HTTPError(500, reason='Database Error') + return exc + + def __on_postgres_timing(self, + metric_name: str, + duration: float) -> None: + """Override for custom metric recording""" + if hasattr(self, 'influxdb'): # sprockets-influxdb + self.influxdb.set_field(metric_name, duration) + elif hasattr(self, 'record_timing'): # sprockets.mixins.metrics + self.record_timing(metric_name, duration) + else: + LOGGER.debug('Postgres query %s duration: %s', + metric_name, duration) diff --git a/tests.py b/tests.py index 2294ff8..acfa9a2 100644 --- a/tests.py +++ b/tests.py @@ -12,40 +12,46 @@ from tornado import web import sprockets_postgres -class CallprocRequestHandler(web.RequestHandler): +class CallprocRequestHandler(sprockets_postgres.RequestHandlerMixin, + web.RequestHandler): async def get(self): - result = await self.application.postgres_callproc('uuid_generate_v4') - await self.finish({'value': str(result['uuid_generate_v4'])}) + result = await self.postgres_callproc( + 'uuid_generate_v4', metric_name='uuid') + await self.finish({'value': str(result.row['uuid_generate_v4'])}) -class ExecuteRequestHandler(web.RequestHandler): +class ExecuteRequestHandler(sprockets_postgres.RequestHandlerMixin, + web.RequestHandler): GET_SQL = 'SELECT %s::TEXT AS value;' async def get(self): - result = await self.application.postgres_execute( - 'get', self.GET_SQL, self.get_argument('value')) - await self.finish({'value': result['value'] if result else None}) + result = await self.postgres_execute( + self.GET_SQL, [self.get_argument('value')], 'get') + await self.finish({ + 'value': result.row['value'] if result.row else None}) -class MultiRowRequestHandler(web.RequestHandler): +class MultiRowRequestHandler(sprockets_postgres.RequestHandlerMixin, + web.RequestHandler): GET_SQL = 'SELECT * FROM information_schema.enabled_roles;' async def get(self): - rows = await self.application.postgres_execute('get', self.GET_SQL) - await self.finish({'rows': [row['role_name'] for row in rows]}) + result = await self.postgres_execute(self.GET_SQL) + await self.finish({'rows': [row['role_name'] for row in result.rows]}) -class NoRowRequestHandler(web.RequestHandler): +class NoRowRequestHandler(sprockets_postgres.RequestHandlerMixin, + web.RequestHandler): GET_SQL = """\ SELECT * FROM information_schema.tables WHERE table_schema = 'public';""" async def get(self): - rows = await self.application.postgres_execute('get', self.GET_SQL) - await self.finish({'rows': rows}) + result = await self.postgres_execute(self.GET_SQL) + await self.finish({'rows': result.rows}) class StatusRequestHandler(web.RequestHandler): @@ -62,7 +68,7 @@ class Application(sprockets_postgres.ApplicationMixin, pass -class ExecuteTestCase(testing.SprocketsHttpTestCase): +class TestCase(testing.SprocketsHttpTestCase): @classmethod def setUpClass(cls): @@ -143,15 +149,44 @@ class ExecuteTestCase(testing.SprocketsHttpTestCase): self.assertEqual(response.code, 500) self.assertIn(b'Database Error', response.body) - def test_postgres_programming_error(self): - with mock.patch.object(self.app, '_postgres_query_results') as pqr: - pqr.side_effect = psycopg2.ProgrammingError() - response = self.fetch('/execute?value=1') - self.assertEqual(response.code, 200) - self.assertIsNone(json.loads(response.body)['value']) + @mock.patch('aiopg.cursor.Cursor.fetchone') + def test_postgres_programming_error(self, fetchone): + fetchone.side_effect = psycopg2.ProgrammingError() + response = self.fetch('/execute?value=1') + self.assertEqual(response.code, 200) + self.assertIsNone(json.loads(response.body)['value']) @mock.patch('aiopg.connection.Connection.cursor') def test_postgres_cursor_raises(self, cursor): cursor.side_effect = asyncio.TimeoutError() response = self.fetch('/execute?value=1') self.assertEqual(response.code, 503) + +""" +class MissingURLTestCase(testing.SprocketsHttpTestCase): + + @classmethod + def setUpClass(cls): + with open('build/test-environment') as f: + for line in f: + if line.startswith('export '): + line = line[7:] + name, _, value = line.strip().partition('=') + if name != 'POSTGRES_URL': + os.environ[name] = value + if 'POSTGRES_URL' in os.environ: + del os.environ['POSTGRES_URL'] + + def setUp(self): + self.stop_mock = None + super().setUp() + + def get_app(self): + self.app = Application() + self.stop_mock = mock.Mock( + wraps=self.app.stop, side_effect=RuntimeError) + return self.app + + def test_that_stop_is_invoked(self): + self.stop_mock.assert_called_once_with(self.io_loop) +"""