sprockets-postgres/sprockets_postgres.py

import asyncio
import contextlib
import dataclasses
import logging
import os
import time
import typing
from distutils import util

import aiopg
import psycopg2
from aiopg import pool
from psycopg2 import errors, extras
from tornado import ioloop, web

LOGGER = logging.getLogger('sprockets-postgres')

DEFAULT_POSTGRES_CONNECTION_TIMEOUT = 10
DEFAULT_POSTGRES_CONNECTION_TTL = 300
DEFAULT_POSTGRES_HSTORE = 'FALSE'
DEFAULT_POSTGRES_JSON = 'FALSE'
DEFAULT_POSTGRES_MAX_POOL_SIZE = 0
DEFAULT_POSTGRES_MIN_POOL_SIZE = 1
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
DEFAULT_POSTGRES_UUID = 'TRUE'

QueryParameters = typing.Union[list, tuple, None]
Timeout = typing.Union[int, float, None]


@dataclasses.dataclass
class QueryResult:
    """Results of an executed query or stored procedure"""
    row_count: int
    row: typing.Optional[dict]
    rows: typing.Optional[typing.List[dict]]
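

# A minimal usage sketch, not part of the library: ``row`` is populated
# for single-row results, ``rows`` for multi-row results, and both stay
# None when nothing was returned. The helper name is illustrative only.
def _example_result_rows(result: QueryResult) -> typing.List[dict]:
    """Normalize a QueryResult into a (possibly empty) list of rows."""
    if result.row_count == 1 and result.row is not None:
        return [result.row]
    return result.rows or []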


class PostgresConnector:

    def __init__(self,
                 cursor: aiopg.Cursor,
                 on_error: typing.Callable,
                 record_duration: typing.Optional[typing.Callable] = None,
                 timeout: Timeout = None):
        self.cursor = cursor
        self._on_error = on_error
        self._record_duration = record_duration
        self._timeout = timeout or int(
            os.environ.get(
                'POSTGRES_QUERY_TIMEOUT',
                DEFAULT_POSTGRES_QUERY_TIMEOUT))

    async def callproc(self,
                       name: str,
                       parameters: QueryParameters = None,
                       metric_name: str = '',
                       *,
                       timeout: Timeout = None) -> QueryResult:
        return await self._query(
            self.cursor.callproc,
            metric_name,
            procname=name,
            parameters=parameters,
            timeout=timeout)

    async def execute(self,
                      sql: str,
                      parameters: QueryParameters = None,
                      metric_name: str = '',
                      *,
                      timeout: Timeout = None) -> QueryResult:
        return await self._query(
            self.cursor.execute,
            metric_name,
            operation=sql,
            parameters=parameters,
            timeout=timeout)

    @contextlib.asynccontextmanager
    async def transaction(self) \
            -> typing.AsyncContextManager['PostgresConnector']:
        async with self.cursor.begin():
            yield self

    async def _query(self,
                     method: typing.Callable,
                     metric_name: str,
                     **kwargs):
        if kwargs['timeout'] is None:
            kwargs['timeout'] = self._timeout
        start_time = time.monotonic()
        try:
            await method(**kwargs)
        except (asyncio.TimeoutError, psycopg2.Error) as err:
            LOGGER.error('Caught %r', err)
            # The error callback decides whether the exception propagates
            exc = self._on_error(metric_name, err)
            if exc:
                raise exc
        finally:
            # Record the query duration even when the query failed
            if self._record_duration:
                self._record_duration(
                    metric_name, time.monotonic() - start_time)
        return await self._query_results()

    async def _query_results(self) -> QueryResult:
        row, rows = None, None
        if self.cursor.rowcount == 1:
            try:
                row = dict(await self.cursor.fetchone())
            except psycopg2.ProgrammingError:
                # Statements such as INSERT report a rowcount but have
                # no results to fetch
                pass
        elif self.cursor.rowcount > 1:
            try:
                rows = [dict(row) for row in await self.cursor.fetchall()]
            except psycopg2.ProgrammingError:
                pass
        return QueryResult(self.cursor.rowcount, row, rows)
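

# A minimal usage sketch, not part of the library: wrapping an ad-hoc
# aiopg cursor in a PostgresConnector. The function name and DSN argument
# are illustrative; the error callback returns the exception so that
# _query re-raises it.
async def _example_direct_query(dsn: str) -> QueryResult:
    """Run a single query against a temporary aiopg pool."""
    async with aiopg.create_pool(dsn) as example_pool:
        async with example_pool.acquire() as conn:
            async with conn.cursor(
                    cursor_factory=extras.RealDictCursor) as cursor:
                connector = PostgresConnector(
                    cursor, on_error=lambda _metric, error: error)
                return await connector.execute(
                    'SELECT 1 AS example', metric_name='example')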


class ConnectionException(Exception):
    """Raised when the connection to Postgres can not be established"""


class ApplicationMixin:
    """
    :class:`sprockets.http.app.Application` mixin for handling the
    connection to Postgres and exposing functions for querying the
    database, getting the status, and providing a cursor.

    Automatically creates and shuts down :class:`aiopg.pool.Pool` on
    startup and shutdown.

    """
    POSTGRES_STATUS_TIMEOUT = 3

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._postgres_pool: typing.Optional[pool.Pool] = None
        self.runner_callbacks['on_start'].append(self._postgres_setup)
        self.runner_callbacks['shutdown'].append(self._postgres_shutdown)

    @contextlib.asynccontextmanager
    async def postgres_connector(self,
                                 on_error: typing.Callable,
                                 record_duration: typing.Optional[
                                     typing.Callable] = None,
                                 timeout: Timeout = None) \
            -> typing.AsyncContextManager[PostgresConnector]:
        try:
            async with self._postgres_pool.acquire() as conn:
                async with conn.cursor(
                        cursor_factory=extras.RealDictCursor,
                        timeout=timeout) as cursor:
                    yield PostgresConnector(
                        cursor, on_error, record_duration, timeout)
        except (asyncio.TimeoutError, psycopg2.Error) as err:
            on_error('postgres_connector', ConnectionException(str(err)))

    async def postgres_status(self) -> dict:
        """Invoke from the ``/status`` RequestHandler to check that there is
        a Postgres connection handler available and return info about the
        pool.

        """
        query_error = asyncio.Event()

        def on_error(_metric_name, _exc) -> None:
            query_error.set()
            return None

        async with self.postgres_connector(
                on_error,
                timeout=self.POSTGRES_STATUS_TIMEOUT) as connector:
            await connector.execute('SELECT 1')
        return {
            'available': not query_error.is_set(),
            'pool_size': self._postgres_pool.size,
            'pool_free': self._postgres_pool.freesize
        }

    async def _postgres_setup(self,
                              _app: web.Application,
                              loop: ioloop.IOLoop) -> None:
        """Set up the Postgres pool of connections and log if there is an
        error. This is invoked by the Application on start callback
        mechanism.

        """
        if 'POSTGRES_URL' not in os.environ:
            LOGGER.critical('Missing POSTGRES_URL environment variable')
            return self.stop(loop)
        self._postgres_pool = pool.Pool(
            os.environ['POSTGRES_URL'],
            minsize=int(
                os.environ.get(
                    'POSTGRES_MIN_POOL_SIZE',
                    DEFAULT_POSTGRES_MIN_POOL_SIZE)),
            maxsize=int(
                os.environ.get(
                    'POSTGRES_MAX_POOL_SIZE',
                    DEFAULT_POSTGRES_MAX_POOL_SIZE)),
            timeout=int(
                os.environ.get(
                    'POSTGRES_CONNECT_TIMEOUT',
                    DEFAULT_POSTGRES_CONNECTION_TIMEOUT)),
            enable_hstore=util.strtobool(
                os.environ.get(
                    'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)),
            enable_json=util.strtobool(
                os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)),
            enable_uuid=util.strtobool(
                os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)),
            echo=False,
            on_connect=None,
            pool_recycle=int(
                os.environ.get(
                    'POSTGRES_CONNECTION_TTL',
                    DEFAULT_POSTGRES_CONNECTION_TTL)))
        try:
            # Eagerly open the minimum number of connections so that
            # connection problems surface at startup
            async with self._postgres_pool._cond:
                await self._postgres_pool._fill_free_pool(False)
        except (psycopg2.OperationalError,
                psycopg2.Error) as error:  # pragma: nocover
            LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
                           error)

    async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
        """Shut down the Postgres connections and wait for them to close.
        This is invoked by the Application shutdown callback mechanism.

        """
        self._postgres_pool.close()
        await self._postgres_pool.wait_closed()
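

# A minimal usage sketch, not part of the library: a ``/status`` handler
# for an application that is assumed to combine ApplicationMixin with
# :class:`sprockets.http.app.Application` and to export POSTGRES_URL in
# its environment. The handler name is illustrative only.
class _ExampleStatusHandler(web.RequestHandler):
    """Report Postgres pool availability for health checks."""

    async def get(self) -> None:
        status = await self.application.postgres_status()
        self.set_status(200 if status['available'] else 503)
        self.write(status)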


class RequestHandlerMixin:
    """
    RequestHandler mixin class exposing functions for querying the database,
    recording the duration to either `sprockets-influxdb` or
    `sprockets.mixins.metrics`, and handling exceptions.

    """
    async def postgres_callproc(self,
                                name: str,
                                parameters: QueryParameters = None,
                                metric_name: str = '',
                                *,
                                timeout: Timeout = None) -> QueryResult:
        """Execute a stored procedure, specifying its name, optional
        parameters, and an optional metric name for instrumentation.

        """
        async with self._postgres_connector(timeout) as connector:
            return await connector.callproc(
                name, parameters, metric_name, timeout=timeout)

    async def postgres_execute(self,
                               sql: str,
                               parameters: QueryParameters = None,
                               metric_name: str = '',
                               *,
                               timeout: Timeout = None) -> QueryResult:
        """Execute a query, specifying the SQL statement, optional query
        parameters, and an optional metric name for instrumentation.

        Parameters may be provided as a sequence or mapping and will be
        bound to variables in the operation. Variables are specified
        either with positional ``%s`` or named ``%({name})s`` placeholders.

        """
        async with self._postgres_connector(timeout) as connector:
            return await connector.execute(
                sql, parameters, metric_name, timeout=timeout)

    @contextlib.asynccontextmanager
    async def postgres_transaction(self, timeout: Timeout = None) \
            -> typing.AsyncContextManager[PostgresConnector]:
        """Yields a :class:`PostgresConnector` instance inside a transaction
        that is committed on success and rolled back when an exception is
        raised.

        """
        async with self._postgres_connector(timeout) as connector:
            async with connector.transaction():
                yield connector

    @contextlib.asynccontextmanager
    async def _postgres_connector(self, timeout: Timeout = None) \
            -> typing.AsyncContextManager[PostgresConnector]:
        async with self.application.postgres_connector(
                self.__on_postgres_error,
                self.__on_postgres_timing,
                timeout) as connector:
            yield connector

    def __on_postgres_error(self,
                            metric_name: str,
                            exc: Exception) -> typing.Optional[Exception]:
        """Override for different error handling behaviors"""
        LOGGER.error('%s in %s for %s (%s)',
                     exc.__class__.__name__,
                     self.__class__.__name__,
                     metric_name,
                     str(exc).split('\n')[0])
        if isinstance(exc, ConnectionException):
            raise web.HTTPError(503, reason='Database Connection Error')
        elif isinstance(exc, asyncio.TimeoutError):
            raise web.HTTPError(500, reason='Query Timeout')
        elif isinstance(exc, errors.UniqueViolation):
            raise web.HTTPError(409, reason='Unique Violation')
        elif isinstance(exc, psycopg2.Error):
            raise web.HTTPError(500, reason='Database Error')
        return exc

    def __on_postgres_timing(self,
                             metric_name: str,
                             duration: float) -> None:
        """Override for custom metric recording"""
        if hasattr(self, 'influxdb'):  # sprockets-influxdb
            self.influxdb.set_field(metric_name, duration)
        elif hasattr(self, 'record_timing'):  # sprockets.mixins.metrics
            self.record_timing(metric_name, duration)
        else:
            LOGGER.debug('Postgres query %s duration: %s',
                         metric_name, duration)
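

# A minimal usage sketch, not part of the library: a handler built on
# RequestHandlerMixin, assuming it is registered on an application that
# uses ApplicationMixin. The table, column, and handler names below are
# illustrative only.
class _ExampleWidgetHandler(RequestHandlerMixin, web.RequestHandler):
    """Demonstrates postgres_execute and postgres_transaction."""

    async def get(self, widget_id: str) -> None:
        # Named placeholders are bound from a parameters mapping, which
        # psycopg2 accepts alongside sequences
        result = await self.postgres_execute(
            'SELECT id, name FROM widgets WHERE id = %(id)s',
            {'id': widget_id},
            metric_name='widget-get')
        if not result.row_count:
            raise web.HTTPError(404, reason='Widget Not Found')
        self.write(result.row)

    async def post(self) -> None:
        # Both statements commit together or roll back together
        async with self.postgres_transaction() as connector:
            await connector.execute(
                'INSERT INTO widgets (name) VALUES (%s)',
                [self.get_body_argument('name')],
                metric_name='widget-insert')
            await connector.execute(
                "INSERT INTO audit_log (action) VALUES ('widget-added')",
                metric_name='widget-audit')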