WIP refactor

This commit is contained in:
Gavin M. Roy 2020-04-07 13:42:02 -04:00
parent b9c9495545
commit 611dfd1ec7
2 changed files with 277 additions and 159 deletions

View file

@ -1,24 +1,16 @@
"""
:class:`sprockets.http.app.Application` mixin for handling the connection to
Postgres and exporting functions for querying the database, getting the status,
and proving a cursor.
Automatically creates and shuts down :class:`aio.pool.Pool` on startup
and shutdown.
"""
import asyncio
import contextlib
import functools
import dataclasses
import logging
import os
import time
import typing
from distutils import util
import aiopg
import psycopg2
from aiopg import pool
from psycopg2 import errors, extensions, extras
from psycopg2 import errors, extras
from tornado import ioloop, web
LOGGER = logging.getLogger('sprockets-postgres')
@ -30,13 +22,115 @@ DEFAULT_POSTGRES_JSON = 'FALSE'
DEFAULT_POSTGRES_MAX_POOL_SIZE = 0
DEFAULT_POSTGRES_MIN_POOL_SIZE = 1
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
DEFAULT_POSTGRES_URL = 'postgresql://localhost:5432'
DEFAULT_POSTGRES_UUID = 'TRUE'
QueryParameters = typing.Union[list, tuple, None]
Timeout = typing.Union[int, float, None]
@dataclasses.dataclass
class QueryResult:
row_count: int
row: typing.Optional[dict]
rows: typing.Optional[typing.List[dict]]
class PostgresConnector:
def __init__(self,
cursor: aiopg.Cursor,
on_error: typing.Callable,
record_duration: typing.Optional[typing.Callable] = None,
timeout: Timeout = None):
self.cursor = cursor
self._on_error = on_error
self._record_duration = record_duration
self._timeout = timeout or int(
os.environ.get(
'POSTGRES_QUERY_TIMEOUT',
DEFAULT_POSTGRES_QUERY_TIMEOUT))
async def callproc(self,
name: str,
parameters: QueryParameters = None,
metric_name: str = '',
*,
timeout: Timeout = None) -> QueryResult:
return await self._query(
self.cursor.callproc,
metric_name,
procname=name,
parameters=parameters,
timeout=timeout)
async def execute(self,
sql: str,
parameters: QueryParameters = None,
metric_name: str = '',
*,
timeout: Timeout = None) -> QueryResult:
return await self._query(
self.cursor.execute,
metric_name,
operation=sql,
parameters=parameters,
timeout=timeout)
@contextlib.asynccontextmanager
async def transaction(self) \
-> typing.AsyncContextManager['PostgresConnector']:
async with self.cursor.begin():
yield self
async def _query(self,
method: typing.Callable,
metric_name: str,
**kwargs):
if kwargs['timeout'] is None:
kwargs['timeout'] = self._timeout
start_time = time.monotonic()
try:
await method(**kwargs)
except (asyncio.TimeoutError, psycopg2.Error) as err:
LOGGER.error('Caught %r', err)
exc = self._on_error(metric_name, err)
if exc:
raise exc
finally:
if self._record_duration:
self._record_duration(
metric_name, time.monotonic() - start_time)
return await self._query_results()
async def _query_results(self) -> QueryResult:
row, rows = None, None
if self.cursor.rowcount == 1:
try:
row = dict(await self.cursor.fetchone())
except psycopg2.ProgrammingError:
pass
elif self.cursor.rowcount > 1:
try:
rows = [dict(row) for row in await self.cursor.fetchall()]
except psycopg2.ProgrammingError:
pass
return QueryResult(self.cursor.rowcount, row, rows)
class ConnectionException(Exception):
"""Raised when the connection to Postgres can not be established"""
class ApplicationMixin:
"""Application mixin for setting up the PostgreSQL client pool"""
"""
:class:`sprockets.http.app.Application` mixin for handling the connection
to Postgres and exporting functions for querying the database,
getting the status, and proving a cursor.
Automatically creates and shuts down :class:`aio.pool.Pool` on startup
and shutdown.
"""
POSTGRES_STATUS_TIMEOUT = 3
def __init__(self, *args, **kwargs):
@ -46,61 +140,21 @@ class ApplicationMixin:
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
@contextlib.asynccontextmanager
async def postgres_cursor(self,
timeout: typing.Optional[int] = None,
raise_http_error: bool = True) \
-> typing.AsyncContextManager[extensions.cursor]:
"""Return a Postgres cursor for the pool"""
async def postgres_connector(self,
on_error: typing.Callable,
record_duration: typing.Optional[
typing.Callable] = None,
timeout: Timeout = None) \
-> typing.AsyncContextManager[PostgresConnector]:
try:
async with self._postgres_pool.acquire() as conn:
async with conn.cursor(
cursor_factory=extras.RealDictCursor,
timeout=self._postgres_query_timeout(timeout)) as pgc:
yield pgc
except (asyncio.TimeoutError,
psycopg2.OperationalError,
psycopg2.Error) as error:
LOGGER.critical('Error connecting to Postgres: %s', error)
if raise_http_error:
raise web.HTTPError(503, 'Database Unavailable')
raise
async def postgres_callproc(self,
name: str,
params: typing.Union[list, tuple, None] = None,
timeout: typing.Optional[int] = None) \
-> typing.Union[dict, list, None]:
"""Execute a stored procedure, specifying the name, SQL, passing in
optional parameters.
:param name: The stored-proc to call
:param params: Optional parameters to pass into the function
:param timeout: Optional timeout to override the default query timeout
"""
async with self.postgres_cursor(timeout) as cursor:
return await self._postgres_query(
cursor, cursor.callproc, name, name, params)
async def postgres_execute(self, name: str, sql: str,
*args,
timeout: typing.Optional[int] = None) \
-> typing.Union[dict, list, None]:
"""Execute a query, specifying a name for the query, the SQL statement,
and optional positional arguments to pass in with the query.
Parameters may be provided as sequence or mapping and will be
bound to variables in the operation. Variables are specified
either with positional ``%s`` or named ``%({name})s`` placeholders.
:param name: The stored-proc to call
:param sql: The SQL statement to execute
:param timeout: Optional timeout to override the default query timeout
"""
async with self.postgres_cursor(timeout) as cursor:
return await self._postgres_query(
cursor, cursor.execute, name, sql, args)
timeout=timeout) as cursor:
yield PostgresConnector(
cursor, on_error, record_duration, timeout)
except (asyncio.TimeoutError, psycopg2.Error) as err:
on_error('postgres_connector', ConnectionException(str(err)))
async def postgres_status(self) -> dict:
"""Invoke from the ``/status`` RequestHandler to check that there is
@ -108,93 +162,35 @@ class ApplicationMixin:
pool.
"""
available = True
try:
async with self.postgres_cursor(
self.POSTGRES_STATUS_TIMEOUT, False) as cursor:
await cursor.execute('SELECT 1')
except (asyncio.TimeoutError, psycopg2.OperationalError):
available = False
query_error = asyncio.Event()
def on_error(_metric_name, _exc) -> None:
query_error.set()
return None
async with self.postgres_connector(
on_error,
timeout=self.POSTGRES_STATUS_TIMEOUT) as connector:
await connector.execute('SELECT 1')
return {
'available': available,
'available': not query_error.is_set(),
'pool_size': self._postgres_pool.size,
'pool_free': self._postgres_pool.freesize
}
async def _postgres_query(self,
cursor: aiopg.Cursor,
method: typing.Callable,
name: str,
sql: str,
parameters: typing.Union[dict, list, tuple]) \
-> typing.Union[dict, list, None]:
"""Execute a query, specifying the name, SQL, passing in
"""
try:
await method(sql, parameters)
except asyncio.TimeoutError as err:
LOGGER.error('Query timeout for %s: %s',
name, str(err).split('\n')[0])
raise web.HTTPError(500, reason='Query Timeout')
except errors.UniqueViolation as err:
LOGGER.error('Database error for %s: %s',
name, str(err).split('\n')[0])
raise web.HTTPError(409, reason='Unique Violation')
except psycopg2.Error as err:
LOGGER.error('Database error for %s: %s',
name, str(err).split('\n')[0])
raise web.HTTPError(500, reason='Database Error')
try:
return await self._postgres_query_results(cursor)
except psycopg2.ProgrammingError:
return
@staticmethod
async def _postgres_query_results(cursor: aiopg.Cursor) \
-> typing.Union[dict, list, None]:
"""Invoked by self.postgres_query to return all of the query results
as either a ``dict`` or ``list`` depending on the quantity of rows.
This can raise a ``psycopg2.ProgrammingError`` for an INSERT/UPDATE
without RETURNING or a DELETE. That exception is caught by the caller.
:raises psycopg2.ProgrammingError: when there are no rows to fetch
even though the rowcount is > 0
"""
if cursor.rowcount == 1:
return await cursor.fetchone()
elif cursor.rowcount > 1:
return await cursor.fetchall()
return None
@functools.lru_cache(2)
def _postgres_query_timeout(self,
timeout: typing.Optional[int] = None) -> int:
"""Return query timeout, either from the specified value or
``POSTGRES_QUERY_TIMEOUT`` environment variable, if set.
Defaults to sprockets_postgres.DEFAULT_POSTGRES_QUERY_TIMEOUT.
"""
return timeout if timeout else int(
os.environ.get(
'POSTGRES_QUERY_TIMEOUT',
DEFAULT_POSTGRES_QUERY_TIMEOUT))
async def _postgres_setup(self,
_app: web.Application,
_ioloop: ioloop.IOLoop) -> None:
loop: ioloop.IOLoop) -> None:
"""Setup the Postgres pool of connections and log if there is an error.
This is invoked by the Application on start callback mechanism.
"""
url = os.environ.get('POSTGRES_URL', DEFAULT_POSTGRES_URL)
LOGGER.debug('Connecting to PostgreSQL: %s', url)
if 'POSTGRES_URL' not in os.environ:
LOGGER.critical('Missing POSTGRES_URL environment variable')
return self.stop(loop)
self._postgres_pool = pool.Pool(
url,
os.environ['POSTGRES_URL'],
minsize=int(
os.environ.get(
'POSTGRES_MIN_POOL_SIZE',
@ -236,3 +232,90 @@ class ApplicationMixin:
"""
self._postgres_pool.close()
await self._postgres_pool.wait_closed()
class RequestHandlerMixin:
"""
RequestHandler mixin class exposing functions for querying the database,
recording the duration to either `sprockets-influxdb` or
`sprockets.mixins.metrics`, and handling exceptions.
"""
async def postgres_callproc(self,
name: str,
parameters: QueryParameters = None,
metric_name: str = '',
*,
timeout: Timeout = None) -> QueryResult:
async with self._postgres_connector(timeout) as connector:
return await connector.callproc(
name, parameters, metric_name, timeout=timeout)
async def postgres_execute(self,
sql: str,
parameters: QueryParameters = None,
metric_name: str = '',
*,
timeout: Timeout = None) -> QueryResult:
"""Execute a query, specifying a name for the query, the SQL statement,
and optional positional arguments to pass in with the query.
Parameters may be provided as sequence or mapping and will be
bound to variables in the operation. Variables are specified
either with positional ``%s`` or named ``%({name})s`` placeholders.
"""
async with self._postgres_connector(timeout) as connector:
return await connector.execute(
sql, parameters, metric_name, timeout=timeout)
@contextlib.asynccontextmanager
async def postgres_transaction(self, timeout: Timeout = None) \
-> typing.AsyncContextManager[PostgresConnector]:
"""Yields a :class:`PostgresConnector` instance in a transaction.
Will automatically commit or rollback based upon exception.
"""
async with self._postgres_connector(timeout) as connector:
async with connector.transaction():
yield connector
@contextlib.asynccontextmanager
async def _postgres_connector(self, timeout: Timeout = None) \
-> typing.AsyncContextManager[PostgresConnector]:
async with self.application.postgres_connector(
self.__on_postgres_error,
self.__on_postgres_timing,
timeout) as connector:
yield connector
def __on_postgres_error(self,
metric_name: str,
exc: Exception) -> typing.Optional[Exception]:
"""Override for different error handling behaviors"""
LOGGER.error('%s in %s for %s (%s)',
exc.__class__.__name__,
self.__class__.__name__,
metric_name,
str(exc).split('\n')[0])
if isinstance(exc, ConnectionException):
raise web.HTTPError(503, reason='Database Connection Error')
elif isinstance(exc, asyncio.TimeoutError):
raise web.HTTPError(500, reason='Query Timeout')
elif isinstance(exc, errors.UniqueViolation):
raise web.HTTPError(409, reason='Unique Violation')
elif isinstance(exc, psycopg2.Error):
raise web.HTTPError(500, reason='Database Error')
return exc
def __on_postgres_timing(self,
metric_name: str,
duration: float) -> None:
"""Override for custom metric recording"""
if hasattr(self, 'influxdb'): # sprockets-influxdb
self.influxdb.set_field(metric_name, duration)
elif hasattr(self, 'record_timing'): # sprockets.mixins.metrics
self.record_timing(metric_name, duration)
else:
LOGGER.debug('Postgres query %s duration: %s',
metric_name, duration)

View file

@ -12,40 +12,46 @@ from tornado import web
import sprockets_postgres
class CallprocRequestHandler(web.RequestHandler):
class CallprocRequestHandler(sprockets_postgres.RequestHandlerMixin,
web.RequestHandler):
async def get(self):
result = await self.application.postgres_callproc('uuid_generate_v4')
await self.finish({'value': str(result['uuid_generate_v4'])})
result = await self.postgres_callproc(
'uuid_generate_v4', metric_name='uuid')
await self.finish({'value': str(result.row['uuid_generate_v4'])})
class ExecuteRequestHandler(web.RequestHandler):
class ExecuteRequestHandler(sprockets_postgres.RequestHandlerMixin,
web.RequestHandler):
GET_SQL = 'SELECT %s::TEXT AS value;'
async def get(self):
result = await self.application.postgres_execute(
'get', self.GET_SQL, self.get_argument('value'))
await self.finish({'value': result['value'] if result else None})
result = await self.postgres_execute(
self.GET_SQL, [self.get_argument('value')], 'get')
await self.finish({
'value': result.row['value'] if result.row else None})
class MultiRowRequestHandler(web.RequestHandler):
class MultiRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
web.RequestHandler):
GET_SQL = 'SELECT * FROM information_schema.enabled_roles;'
async def get(self):
rows = await self.application.postgres_execute('get', self.GET_SQL)
await self.finish({'rows': [row['role_name'] for row in rows]})
result = await self.postgres_execute(self.GET_SQL)
await self.finish({'rows': [row['role_name'] for row in result.rows]})
class NoRowRequestHandler(web.RequestHandler):
class NoRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
web.RequestHandler):
GET_SQL = """\
SELECT * FROM information_schema.tables WHERE table_schema = 'public';"""
async def get(self):
rows = await self.application.postgres_execute('get', self.GET_SQL)
await self.finish({'rows': rows})
result = await self.postgres_execute(self.GET_SQL)
await self.finish({'rows': result.rows})
class StatusRequestHandler(web.RequestHandler):
@ -62,7 +68,7 @@ class Application(sprockets_postgres.ApplicationMixin,
pass
class ExecuteTestCase(testing.SprocketsHttpTestCase):
class TestCase(testing.SprocketsHttpTestCase):
@classmethod
def setUpClass(cls):
@ -143,9 +149,9 @@ class ExecuteTestCase(testing.SprocketsHttpTestCase):
self.assertEqual(response.code, 500)
self.assertIn(b'Database Error', response.body)
def test_postgres_programming_error(self):
with mock.patch.object(self.app, '_postgres_query_results') as pqr:
pqr.side_effect = psycopg2.ProgrammingError()
@mock.patch('aiopg.cursor.Cursor.fetchone')
def test_postgres_programming_error(self, fetchone):
fetchone.side_effect = psycopg2.ProgrammingError()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 200)
self.assertIsNone(json.loads(response.body)['value'])
@ -155,3 +161,32 @@ class ExecuteTestCase(testing.SprocketsHttpTestCase):
cursor.side_effect = asyncio.TimeoutError()
response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 503)
"""
class MissingURLTestCase(testing.SprocketsHttpTestCase):
@classmethod
def setUpClass(cls):
with open('build/test-environment') as f:
for line in f:
if line.startswith('export '):
line = line[7:]
name, _, value = line.strip().partition('=')
if name != 'POSTGRES_URL':
os.environ[name] = value
if 'POSTGRES_URL' in os.environ:
del os.environ['POSTGRES_URL']
def setUp(self):
self.stop_mock = None
super().setUp()
def get_app(self):
self.app = Application()
self.stop_mock = mock.Mock(
wraps=self.app.stop, side_effect=RuntimeError)
return self.app
def test_that_stop_is_invoked(self):
self.stop_mock.assert_called_once_with(self.io_loop)
"""