WIP refactor

This commit is contained in:
Gavin M. Roy 2020-04-07 13:42:02 -04:00
parent b9c9495545
commit 611dfd1ec7
2 changed files with 277 additions and 159 deletions

@@ -1,24 +1,16 @@
"""
:class:`sprockets.http.app.Application` mixin for handling the connection to
Postgres and exporting functions for querying the database, getting the status,
and proving a cursor.
Automatically creates and shuts down :class:`aio.pool.Pool` on startup
and shutdown.
"""
import asyncio import asyncio
import contextlib import contextlib
import functools import dataclasses
import logging import logging
import os import os
import time
import typing import typing
from distutils import util from distutils import util
import aiopg import aiopg
import psycopg2 import psycopg2
from aiopg import pool from aiopg import pool
from psycopg2 import errors, extensions, extras from psycopg2 import errors, extras
from tornado import ioloop, web from tornado import ioloop, web
LOGGER = logging.getLogger('sprockets-postgres') LOGGER = logging.getLogger('sprockets-postgres')
@@ -30,13 +22,115 @@ DEFAULT_POSTGRES_JSON = 'FALSE'
 DEFAULT_POSTGRES_MAX_POOL_SIZE = 0
 DEFAULT_POSTGRES_MIN_POOL_SIZE = 1
 DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
-DEFAULT_POSTGRES_URL = 'postgresql://localhost:5432'
 DEFAULT_POSTGRES_UUID = 'TRUE'

+QueryParameters = typing.Union[list, tuple, None]
+Timeout = typing.Union[int, float, None]
+
+
+@dataclasses.dataclass
+class QueryResult:
+    row_count: int
+    row: typing.Optional[dict]
+    rows: typing.Optional[typing.List[dict]]
+
+
+class PostgresConnector:
+
+    def __init__(self,
+                 cursor: aiopg.Cursor,
+                 on_error: typing.Callable,
+                 record_duration: typing.Optional[typing.Callable] = None,
+                 timeout: Timeout = None):
+        self.cursor = cursor
+        self._on_error = on_error
+        self._record_duration = record_duration
+        self._timeout = timeout or int(
+            os.environ.get(
+                'POSTGRES_QUERY_TIMEOUT',
+                DEFAULT_POSTGRES_QUERY_TIMEOUT))
+
+    async def callproc(self,
+                       name: str,
+                       parameters: QueryParameters = None,
+                       metric_name: str = '',
+                       *,
+                       timeout: Timeout = None) -> QueryResult:
+        return await self._query(
+            self.cursor.callproc,
+            metric_name,
+            procname=name,
+            parameters=parameters,
+            timeout=timeout)
+
+    async def execute(self,
+                      sql: str,
+                      parameters: QueryParameters = None,
+                      metric_name: str = '',
+                      *,
+                      timeout: Timeout = None) -> QueryResult:
+        return await self._query(
+            self.cursor.execute,
+            metric_name,
+            operation=sql,
+            parameters=parameters,
+            timeout=timeout)
+
+    @contextlib.asynccontextmanager
+    async def transaction(self) \
+            -> typing.AsyncContextManager['PostgresConnector']:
+        async with self.cursor.begin():
+            yield self
+
+    async def _query(self,
+                     method: typing.Callable,
+                     metric_name: str,
+                     **kwargs):
+        if kwargs['timeout'] is None:
+            kwargs['timeout'] = self._timeout
+        start_time = time.monotonic()
+        try:
+            await method(**kwargs)
+        except (asyncio.TimeoutError, psycopg2.Error) as err:
+            LOGGER.error('Caught %r', err)
+            exc = self._on_error(metric_name, err)
+            if exc:
+                raise exc
+        finally:
+            if self._record_duration:
+                self._record_duration(
+                    metric_name, time.monotonic() - start_time)
+        return await self._query_results()
+
+    async def _query_results(self) -> QueryResult:
+        row, rows = None, None
+        if self.cursor.rowcount == 1:
+            try:
+                row = dict(await self.cursor.fetchone())
+            except psycopg2.ProgrammingError:
+                pass
+        elif self.cursor.rowcount > 1:
+            try:
+                rows = [dict(row) for row in await self.cursor.fetchall()]
+            except psycopg2.ProgrammingError:
+                pass
+        return QueryResult(self.cursor.rowcount, row, rows)
+
+
+class ConnectionException(Exception):
+    """Raised when the connection to Postgres can not be established"""
 class ApplicationMixin:
-    """Application mixin for setting up the PostgreSQL client pool"""
+    """
+    :class:`sprockets.http.app.Application` mixin for handling the connection
+    to Postgres and exporting functions for querying the database,
+    getting the status, and providing a cursor.
+
+    Automatically creates and shuts down :class:`aio.pool.Pool` on startup
+    and shutdown.
+
+    """
     POSTGRES_STATUS_TIMEOUT = 3

     def __init__(self, *args, **kwargs):
@@ -46,155 +140,57 @@ class ApplicationMixin:
         self.runner_callbacks['shutdown'].append(self._postgres_shutdown)

     @contextlib.asynccontextmanager
-    async def postgres_cursor(self,
-                              timeout: typing.Optional[int] = None,
-                              raise_http_error: bool = True) \
-            -> typing.AsyncContextManager[extensions.cursor]:
-        """Return a Postgres cursor for the pool"""
+    async def postgres_connector(self,
+                                 on_error: typing.Callable,
+                                 record_duration: typing.Optional[
+                                     typing.Callable] = None,
+                                 timeout: Timeout = None) \
+            -> typing.AsyncContextManager[PostgresConnector]:
         try:
             async with self._postgres_pool.acquire() as conn:
                 async with conn.cursor(
                         cursor_factory=extras.RealDictCursor,
-                        timeout=self._postgres_query_timeout(timeout)) as pgc:
-                    yield pgc
-        except (asyncio.TimeoutError,
-                psycopg2.OperationalError,
-                psycopg2.Error) as error:
-            LOGGER.critical('Error connecting to Postgres: %s', error)
-            if raise_http_error:
-                raise web.HTTPError(503, 'Database Unavailable')
-            raise
-
-    async def postgres_callproc(self,
-                                name: str,
-                                params: typing.Union[list, tuple, None] = None,
-                                timeout: typing.Optional[int] = None) \
-            -> typing.Union[dict, list, None]:
-        """Execute a stored procedure, specifying the name, SQL, passing in
-        optional parameters.
-
-        :param name: The stored-proc to call
-        :param params: Optional parameters to pass into the function
-        :param timeout: Optional timeout to override the default query timeout
-
-        """
-        async with self.postgres_cursor(timeout) as cursor:
-            return await self._postgres_query(
-                cursor, cursor.callproc, name, name, params)
-
-    async def postgres_execute(self, name: str, sql: str,
-                               *args,
-                               timeout: typing.Optional[int] = None) \
-            -> typing.Union[dict, list, None]:
-        """Execute a query, specifying a name for the query, the SQL statement,
-        and optional positional arguments to pass in with the query.
-
-        Parameters may be provided as sequence or mapping and will be
-        bound to variables in the operation. Variables are specified
-        either with positional ``%s`` or named ``%({name})s`` placeholders.
-
-        :param name: The stored-proc to call
-        :param sql: The SQL statement to execute
-        :param timeout: Optional timeout to override the default query timeout
-
-        """
-        async with self.postgres_cursor(timeout) as cursor:
-            return await self._postgres_query(
-                cursor, cursor.execute, name, sql, args)
+                        timeout=timeout) as cursor:
+                    yield PostgresConnector(
+                        cursor, on_error, record_duration, timeout)
+        except (asyncio.TimeoutError, psycopg2.Error) as err:
+            on_error('postgres_connector', ConnectionException(str(err)))

     async def postgres_status(self) -> dict:
         """Invoke from the ``/status`` RequestHandler to check that there is
         a Postgres connection handler available and return info about the
         pool.

         """
-        available = True
-        try:
-            async with self.postgres_cursor(
-                    self.POSTGRES_STATUS_TIMEOUT, False) as cursor:
-                await cursor.execute('SELECT 1')
-        except (asyncio.TimeoutError, psycopg2.OperationalError):
-            available = False
+        query_error = asyncio.Event()
+
+        def on_error(_metric_name, _exc) -> None:
+            query_error.set()
+            return None
+
+        async with self.postgres_connector(
+                on_error,
+                timeout=self.POSTGRES_STATUS_TIMEOUT) as connector:
+            await connector.execute('SELECT 1')
+
         return {
-            'available': available,
+            'available': not query_error.is_set(),
             'pool_size': self._postgres_pool.size,
             'pool_free': self._postgres_pool.freesize
         }
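The docstring above suggests wiring this into a ``/status`` endpoint; a minimal sketch (illustrative only, not the StatusRequestHandler used in the tests below):

from tornado import web


class StatusRequestHandler(web.RequestHandler):

    async def get(self):
        # postgres_status() reports availability plus pool size and free slots
        status = await self.application.postgres_status()
        if not status['available']:
            self.set_status(503)
        await self.finish(status)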

-    async def _postgres_query(self,
-                              cursor: aiopg.Cursor,
-                              method: typing.Callable,
-                              name: str,
-                              sql: str,
-                              parameters: typing.Union[dict, list, tuple]) \
-            -> typing.Union[dict, list, None]:
-        """Execute a query, specifying the name, SQL, passing in
-
-        """
-        try:
-            await method(sql, parameters)
-        except asyncio.TimeoutError as err:
-            LOGGER.error('Query timeout for %s: %s',
-                         name, str(err).split('\n')[0])
-            raise web.HTTPError(500, reason='Query Timeout')
-        except errors.UniqueViolation as err:
-            LOGGER.error('Database error for %s: %s',
-                         name, str(err).split('\n')[0])
-            raise web.HTTPError(409, reason='Unique Violation')
-        except psycopg2.Error as err:
-            LOGGER.error('Database error for %s: %s',
-                         name, str(err).split('\n')[0])
-            raise web.HTTPError(500, reason='Database Error')
-
-        try:
-            return await self._postgres_query_results(cursor)
-        except psycopg2.ProgrammingError:
-            return
-
-    @staticmethod
-    async def _postgres_query_results(cursor: aiopg.Cursor) \
-            -> typing.Union[dict, list, None]:
-        """Invoked by self.postgres_query to return all of the query results
-        as either a ``dict`` or ``list`` depending on the quantity of rows.
-
-        This can raise a ``psycopg2.ProgrammingError`` for an INSERT/UPDATE
-        without RETURNING or a DELETE. That exception is caught by the caller.
-
-        :raises psycopg2.ProgrammingError: when there are no rows to fetch
-            even though the rowcount is > 0
-
-        """
-        if cursor.rowcount == 1:
-            return await cursor.fetchone()
-        elif cursor.rowcount > 1:
-            return await cursor.fetchall()
-        return None
-
-    @functools.lru_cache(2)
-    def _postgres_query_timeout(self,
-                                timeout: typing.Optional[int] = None) -> int:
-        """Return query timeout, either from the specified value or
-        ``POSTGRES_QUERY_TIMEOUT`` environment variable, if set.
-
-        Defaults to sprockets_postgres.DEFAULT_POSTGRES_QUERY_TIMEOUT.
-
-        """
-        return timeout if timeout else int(
-            os.environ.get(
-                'POSTGRES_QUERY_TIMEOUT',
-                DEFAULT_POSTGRES_QUERY_TIMEOUT))

     async def _postgres_setup(self,
                               _app: web.Application,
-                              _ioloop: ioloop.IOLoop) -> None:
+                              loop: ioloop.IOLoop) -> None:
         """Setup the Postgres pool of connections and log if there is an error.

         This is invoked by the Application on start callback mechanism.

         """
-        url = os.environ.get('POSTGRES_URL', DEFAULT_POSTGRES_URL)
-        LOGGER.debug('Connecting to PostgreSQL: %s', url)
+        if 'POSTGRES_URL' not in os.environ:
+            LOGGER.critical('Missing POSTGRES_URL environment variable')
+            return self.stop(loop)
         self._postgres_pool = pool.Pool(
-            url,
+            os.environ['POSTGRES_URL'],
             minsize=int(
                 os.environ.get(
                     'POSTGRES_MIN_POOL_SIZE',
@@ -236,3 +232,90 @@ class ApplicationMixin:
         """
         self._postgres_pool.close()
         await self._postgres_pool.wait_closed()
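For context, a sketch of attaching the mixin to an application, reusing the StatusRequestHandler sketched earlier; it assumes ``sprockets.http.app.Application`` (referenced in the class docstring) and an illustrative connection URL:

import os

from sprockets.http import app
from tornado import web

import sprockets_postgres


class Application(sprockets_postgres.ApplicationMixin, app.Application):
    """The mixin registers the pool setup and shutdown runner callbacks."""


def make_application(**settings) -> Application:
    # With DEFAULT_POSTGRES_URL removed, POSTGRES_URL must be set before
    # startup or _postgres_setup() logs a critical error and stops the app.
    os.environ.setdefault(
        'POSTGRES_URL', 'postgresql://postgres@localhost:5432/postgres')
    return Application([web.url(r'/status', StatusRequestHandler)], **settings)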
+
+
+class RequestHandlerMixin:
+    """
+    RequestHandler mixin class exposing functions for querying the database,
+    recording the duration to either `sprockets-influxdb` or
+    `sprockets.mixins.metrics`, and handling exceptions.
+
+    """
+    async def postgres_callproc(self,
+                                name: str,
+                                parameters: QueryParameters = None,
+                                metric_name: str = '',
+                                *,
+                                timeout: Timeout = None) -> QueryResult:
+        async with self._postgres_connector(timeout) as connector:
+            return await connector.callproc(
+                name, parameters, metric_name, timeout=timeout)
+
+    async def postgres_execute(self,
+                               sql: str,
+                               parameters: QueryParameters = None,
+                               metric_name: str = '',
+                               *,
+                               timeout: Timeout = None) -> QueryResult:
+        """Execute a query, specifying the SQL statement, optional query
+        parameters, and an optional metric name for duration recording.
+
+        Parameters may be provided as sequence or mapping and will be
+        bound to variables in the operation. Variables are specified
+        either with positional ``%s`` or named ``%({name})s`` placeholders.
+
+        """
+        async with self._postgres_connector(timeout) as connector:
+            return await connector.execute(
+                sql, parameters, metric_name, timeout=timeout)
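To illustrate the two placeholder styles the docstring above describes, a hypothetical handler (table, column, and metric names are invented):

from tornado import web

import sprockets_postgres


class WidgetRequestHandler(sprockets_postgres.RequestHandlerMixin,
                           web.RequestHandler):

    async def get(self):
        sku = self.get_argument('sku')

        # Positional ``%s`` placeholders bound from a sequence of parameters
        by_position = await self.postgres_execute(
            'SELECT name FROM v1.widgets WHERE sku = %s', [sku],
            metric_name='widget-by-sku')

        # Named ``%(sku)s`` placeholders bound from a mapping; psycopg2 accepts
        # a dict here even though QueryParameters is typed as list/tuple/None.
        by_name = await self.postgres_execute(
            'SELECT name FROM v1.widgets WHERE sku = %(sku)s', {'sku': sku},
            metric_name='widget-by-sku')

        await self.finish({
            'positional': by_position.row['name'] if by_position.row else None,
            'named': by_name.row['name'] if by_name.row else None})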
+    @contextlib.asynccontextmanager
+    async def postgres_transaction(self, timeout: Timeout = None) \
+            -> typing.AsyncContextManager[PostgresConnector]:
+        """Yields a :class:`PostgresConnector` instance in a transaction.
+        Will automatically commit or rollback based upon exception.
+
+        """
+        async with self._postgres_connector(timeout) as connector:
+            async with connector.transaction():
+                yield connector
+
+    @contextlib.asynccontextmanager
+    async def _postgres_connector(self, timeout: Timeout = None) \
+            -> typing.AsyncContextManager[PostgresConnector]:
+        async with self.application.postgres_connector(
+                self.__on_postgres_error,
+                self.__on_postgres_timing,
+                timeout) as connector:
+            yield connector
+
+    def __on_postgres_error(self,
+                            metric_name: str,
+                            exc: Exception) -> typing.Optional[Exception]:
+        """Override for different error handling behaviors"""
+        LOGGER.error('%s in %s for %s (%s)',
+                     exc.__class__.__name__,
+                     self.__class__.__name__,
+                     metric_name,
+                     str(exc).split('\n')[0])
+        if isinstance(exc, ConnectionException):
+            raise web.HTTPError(503, reason='Database Connection Error')
+        elif isinstance(exc, asyncio.TimeoutError):
+            raise web.HTTPError(500, reason='Query Timeout')
+        elif isinstance(exc, errors.UniqueViolation):
+            raise web.HTTPError(409, reason='Unique Violation')
+        elif isinstance(exc, psycopg2.Error):
+            raise web.HTTPError(500, reason='Database Error')
+        return exc
+
+    def __on_postgres_timing(self,
+                             metric_name: str,
+                             duration: float) -> None:
+        """Override for custom metric recording"""
+        if hasattr(self, 'influxdb'):  # sprockets-influxdb
+            self.influxdb.set_field(metric_name, duration)
+        elif hasattr(self, 'record_timing'):  # sprockets.mixins.metrics
+            self.record_timing(metric_name, duration)
+        else:
+            LOGGER.debug('Postgres query %s duration: %s',
+                         metric_name, duration)
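A sketch of the ``postgres_transaction`` helper defined above: both statements share one cursor and either commit together or roll back when an exception (including the HTTPError raised by the error callback) escapes the block. SQL, table, and metric names are invented:

from tornado import web

import sprockets_postgres


class TransferRequestHandler(sprockets_postgres.RequestHandlerMixin,
                             web.RequestHandler):

    async def post(self):
        amount = self.get_argument('amount')
        async with self.postgres_transaction() as postgres:
            await postgres.execute(
                'UPDATE v1.accounts SET balance = balance - %s WHERE id = %s',
                [amount, self.get_argument('source')], 'debit')
            await postgres.execute(
                'UPDATE v1.accounts SET balance = balance + %s WHERE id = %s',
                [amount, self.get_argument('target')], 'credit')
        await self.finish({'status': 'ok'})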

@@ -12,40 +12,46 @@ from tornado import web

 import sprockets_postgres


-class CallprocRequestHandler(web.RequestHandler):
+class CallprocRequestHandler(sprockets_postgres.RequestHandlerMixin,
+                             web.RequestHandler):

     async def get(self):
-        result = await self.application.postgres_callproc('uuid_generate_v4')
-        await self.finish({'value': str(result['uuid_generate_v4'])})
+        result = await self.postgres_callproc(
+            'uuid_generate_v4', metric_name='uuid')
+        await self.finish({'value': str(result.row['uuid_generate_v4'])})


-class ExecuteRequestHandler(web.RequestHandler):
+class ExecuteRequestHandler(sprockets_postgres.RequestHandlerMixin,
+                            web.RequestHandler):

     GET_SQL = 'SELECT %s::TEXT AS value;'

     async def get(self):
-        result = await self.application.postgres_execute(
-            'get', self.GET_SQL, self.get_argument('value'))
-        await self.finish({'value': result['value'] if result else None})
+        result = await self.postgres_execute(
+            self.GET_SQL, [self.get_argument('value')], 'get')
+        await self.finish({
+            'value': result.row['value'] if result.row else None})


-class MultiRowRequestHandler(web.RequestHandler):
+class MultiRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
+                             web.RequestHandler):

     GET_SQL = 'SELECT * FROM information_schema.enabled_roles;'

     async def get(self):
-        rows = await self.application.postgres_execute('get', self.GET_SQL)
-        await self.finish({'rows': [row['role_name'] for row in rows]})
+        result = await self.postgres_execute(self.GET_SQL)
+        await self.finish({'rows': [row['role_name'] for row in result.rows]})


-class NoRowRequestHandler(web.RequestHandler):
+class NoRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
+                          web.RequestHandler):

     GET_SQL = """\
 SELECT * FROM information_schema.tables WHERE table_schema = 'public';"""

     async def get(self):
-        rows = await self.application.postgres_execute('get', self.GET_SQL)
-        await self.finish({'rows': rows})
+        result = await self.postgres_execute(self.GET_SQL)
+        await self.finish({'rows': result.rows})


 class StatusRequestHandler(web.RequestHandler):
@@ -62,7 +68,7 @@ class Application(sprockets_postgres.ApplicationMixin,
     pass


-class ExecuteTestCase(testing.SprocketsHttpTestCase):
+class TestCase(testing.SprocketsHttpTestCase):

     @classmethod
     def setUpClass(cls):
@@ -143,15 +149,44 @@ class ExecuteTestCase(testing.SprocketsHttpTestCase):
         self.assertEqual(response.code, 500)
         self.assertIn(b'Database Error', response.body)

-    def test_postgres_programming_error(self):
-        with mock.patch.object(self.app, '_postgres_query_results') as pqr:
-            pqr.side_effect = psycopg2.ProgrammingError()
-            response = self.fetch('/execute?value=1')
-            self.assertEqual(response.code, 200)
-            self.assertIsNone(json.loads(response.body)['value'])
+    @mock.patch('aiopg.cursor.Cursor.fetchone')
+    def test_postgres_programming_error(self, fetchone):
+        fetchone.side_effect = psycopg2.ProgrammingError()
+        response = self.fetch('/execute?value=1')
+        self.assertEqual(response.code, 200)
+        self.assertIsNone(json.loads(response.body)['value'])

     @mock.patch('aiopg.connection.Connection.cursor')
     def test_postgres_cursor_raises(self, cursor):
         cursor.side_effect = asyncio.TimeoutError()
         response = self.fetch('/execute?value=1')
         self.assertEqual(response.code, 503)
"""
class MissingURLTestCase(testing.SprocketsHttpTestCase):
@classmethod
def setUpClass(cls):
with open('build/test-environment') as f:
for line in f:
if line.startswith('export '):
line = line[7:]
name, _, value = line.strip().partition('=')
if name != 'POSTGRES_URL':
os.environ[name] = value
if 'POSTGRES_URL' in os.environ:
del os.environ['POSTGRES_URL']
def setUp(self):
self.stop_mock = None
super().setUp()
def get_app(self):
self.app = Application()
self.stop_mock = mock.Mock(
wraps=self.app.stop, side_effect=RuntimeError)
return self.app
def test_that_stop_is_invoked(self):
self.stop_mock.assert_called_once_with(self.io_loop)
"""