mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2024-11-14 03:00:19 +00:00
321 lines
12 KiB
Python
321 lines
12 KiB
Python
import asyncio
|
|
import contextlib
|
|
import dataclasses
|
|
import logging
|
|
import os
|
|
import time
|
|
import typing
|
|
from distutils import util
|
|
|
|
import aiopg
|
|
import psycopg2
|
|
from aiopg import pool
|
|
from psycopg2 import errors, extras
|
|
from tornado import ioloop, web
|
|
|
|
LOGGER = logging.getLogger('sprockets-postgres')

# Fallback values used when the matching POSTGRES_* environment variable is
# not set.  Numeric values are converted with ``int()`` at the point of use;
# the boolean-ish flags are kept as strings because they are parsed with
# ``distutils.util.strtobool`` when the pool is configured.
DEFAULT_POSTGRES_CONNECTION_TIMEOUT = 10
DEFAULT_POSTGRES_CONNECTION_TTL = 300
DEFAULT_POSTGRES_HSTORE = 'FALSE'
DEFAULT_POSTGRES_JSON = 'FALSE'
# NOTE(review): passed straight through as aiopg's ``maxsize``; confirm the
# meaning of 0 for the aiopg version in use.
DEFAULT_POSTGRES_MAX_POOL_SIZE = 0
DEFAULT_POSTGRES_MIN_POOL_SIZE = 1
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
DEFAULT_POSTGRES_UUID = 'TRUE'

# Query parameters may be a positional sequence or omitted entirely.
QueryParameters = typing.Optional[typing.Union[list, tuple]]
# A timeout may be integral, fractional, or None ("use the default").
Timeout = typing.Optional[typing.Union[int, float]]
|
|
|
|
|
|
@dataclasses.dataclass
class QueryResult:
    """Outcome of one query run through :class:`PostgresConnector`.

    ``row_count`` is the cursor's reported row count.  When exactly one row
    was fetched it is exposed as ``row``; when several were fetched they are
    exposed as ``rows``.  Whichever does not apply is left as ``None``.

    """
    row_count: int  # rowcount reported by the cursor
    row: typing.Optional[dict]  # the single fetched row, if any
    rows: typing.Optional[typing.List[dict]]  # every fetched row, if > 1
|
|
|
|
|
|
class PostgresConnector:
    """Execute queries on a single :class:`aiopg.Cursor`.

    Driver errors (``psycopg2.Error``) and query timeouts
    (``asyncio.TimeoutError``) are logged and handed to the ``on_error``
    callback.  When ``record_duration`` is given, it receives the metric
    name and elapsed time of every query, whether or not the query failed.

    """

    def __init__(self,
                 cursor: aiopg.Cursor,
                 on_error: typing.Callable,
                 record_duration: typing.Optional[typing.Callable] = None,
                 timeout: Timeout = None):
        """
        :param cursor: The cursor all queries are executed on
        :param on_error: Called as ``on_error(metric_name, exc)`` when a
            query fails.  If it returns an exception, that exception is
            raised to the caller; if it returns ``None`` the failure is
            swallowed.
        :param record_duration: Optional ``callback(metric_name, elapsed)``
            invoked after every query
        :param timeout: Default per-query timeout.  When falsy, the
            ``POSTGRES_QUERY_TIMEOUT`` environment variable (or
            ``DEFAULT_POSTGRES_QUERY_TIMEOUT``) is used instead.

        """
        self.cursor = cursor
        self._on_error = on_error
        self._record_duration = record_duration
        # ``or`` rather than ``is None``: an explicit 0 also falls back.
        self._timeout = timeout or int(
            os.environ.get(
                'POSTGRES_QUERY_TIMEOUT',
                DEFAULT_POSTGRES_QUERY_TIMEOUT))

    async def callproc(self,
                       name: str,
                       parameters: QueryParameters = None,
                       metric_name: str = '',
                       *,
                       timeout: Timeout = None) -> QueryResult:
        """Call the stored procedure ``name`` with optional ``parameters``.

        :param timeout: Per-call timeout; ``None`` uses the connector default

        """
        return await self._query(
            self.cursor.callproc,
            metric_name,
            procname=name,
            parameters=parameters,
            timeout=timeout)

    async def execute(self,
                      sql: str,
                      parameters: QueryParameters = None,
                      metric_name: str = '',
                      *,
                      timeout: Timeout = None) -> QueryResult:
        """Execute ``sql`` with optional bound ``parameters``.

        :param timeout: Per-call timeout; ``None`` uses the connector default

        """
        return await self._query(
            self.cursor.execute,
            metric_name,
            operation=sql,
            parameters=parameters,
            timeout=timeout)

    @contextlib.asynccontextmanager
    async def transaction(self) \
            -> typing.AsyncIterator['PostgresConnector']:
        """Run the enclosed queries inside ``self.cursor.begin()``.

        Yields this connector.  Commit/rollback is handled by aiopg's
        ``Cursor.begin()`` context manager.

        """
        async with self.cursor.begin():
            yield self

    async def _query(self,
                     method: typing.Callable,
                     metric_name: str,
                     **kwargs) -> QueryResult:
        """Invoke ``method(**kwargs)`` and collect its results.

        A ``None`` timeout in ``kwargs`` is replaced by the connector
        default.  On failure the error is logged and passed to
        ``on_error``.  If that returns an exception, it is raised.
        Otherwise the failure is swallowed, and whatever the cursor reports
        afterwards is returned.

        """
        if kwargs['timeout'] is None:
            kwargs['timeout'] = self._timeout
        start_time = time.monotonic()
        try:
            await method(**kwargs)
        except (asyncio.TimeoutError, psycopg2.Error) as err:
            LOGGER.error('Caught %r', err)
            exc = self._on_error(metric_name, err)
            if exc is not None:
                raise exc
        finally:
            # Timing is recorded for failed queries as well.
            if self._record_duration:
                self._record_duration(
                    metric_name, time.monotonic() - start_time)
        return await self._query_results()

    async def _query_results(self) -> QueryResult:
        """Build a :class:`QueryResult` from the cursor's current state.

        One row is returned as ``row`` and several as ``rows``.  If
        fetching raises ``psycopg2.ProgrammingError`` (for example a
        statement that produced no result set), both stay ``None``.

        """
        count = self.cursor.rowcount
        row, rows = None, None
        if count == 1:
            with contextlib.suppress(psycopg2.ProgrammingError):
                row = dict(await self.cursor.fetchone())
        elif count > 1:
            with contextlib.suppress(psycopg2.ProgrammingError):
                rows = [dict(record)
                        for record in await self.cursor.fetchall()]
        return QueryResult(count, row, rows)
|
|
|
|
|
|
class ConnectionException(Exception):
    """Signals that no usable connection to Postgres could be obtained.

    The exception message carries the text of the underlying driver or
    timeout error.

    """
|
|
|
|
|
|
class ApplicationMixin:
    """
    :class:`sprockets.http.app.Application` mixin for handling the connection
    to Postgres and exporting functions for querying the database,
    getting the status, and providing a cursor.

    Automatically creates and shuts down :class:`aiopg.pool.Pool` on startup
    and shutdown.

    """
    # Timeout passed to the connector used by :meth:`postgres_status`.
    POSTGRES_STATUS_TIMEOUT = 3

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Created in _postgres_setup; stays None if setup bails out early.
        self._postgres_pool: typing.Optional[pool.Pool] = None
        self.runner_callbacks['on_start'].append(self._postgres_setup)
        self.runner_callbacks['shutdown'].append(self._postgres_shutdown)

    @contextlib.asynccontextmanager
    async def postgres_connector(self,
                                 on_error: typing.Callable,
                                 record_duration: typing.Optional[
                                     typing.Callable] = None,
                                 timeout: Timeout = None) \
            -> typing.AsyncIterator[PostgresConnector]:
        """Acquire a pooled connection and yield a :class:`PostgresConnector`.

        A ``psycopg2.Error`` or ``asyncio.TimeoutError`` raised while
        acquiring the connection/cursor, or escaping the ``async with``
        body, is wrapped in :exc:`ConnectionException`.  The wrapper is
        passed to ``on_error('postgres_connector', exc)``.

        - If ``on_error`` returns an exception, that exception is raised.
        - If it returns ``None`` after a connector was already yielded,
          the error is suppressed.
        - If it returns ``None`` before anything was yielded, the
          :exc:`ConnectionException` itself is raised, because there is no
          connector to hand to the caller.

        """
        yielded = False
        try:
            async with self._postgres_pool.acquire() as conn:
                async with conn.cursor(
                        cursor_factory=extras.RealDictCursor,
                        timeout=timeout) as cursor:
                    yielded = True
                    yield PostgresConnector(
                        cursor, on_error, record_duration, timeout)
        except (asyncio.TimeoutError, psycopg2.Error) as err:
            conn_exc = ConnectionException(str(err))
            exc = on_error('postgres_connector', conn_exc)
            if exc is not None:
                raise exc from err
            if not yielded:
                # Returning here would make asynccontextmanager raise an
                # opaque RuntimeError("generator didn't yield").
                raise conn_exc from err

    async def postgres_status(self) -> dict:
        """Invoke from the ``/status`` RequestHandler to check that there is
        a Postgres connection handler available and return info about the
        pool.

        ``available`` is ``False`` if acquiring a connection or running
        ``SELECT 1`` failed.  ``pool_size`` and ``pool_free`` come from the
        pool's ``size`` and ``freesize`` attributes.

        """
        query_error = asyncio.Event()

        def on_error(_metric_name, _exc) -> None:
            # Record the failure but never raise, so a status is returned.
            query_error.set()
            return None

        try:
            async with self.postgres_connector(
                    on_error,
                    timeout=self.POSTGRES_STATUS_TIMEOUT) as connector:
                await connector.execute('SELECT 1')
        except ConnectionException:
            # No connection could be acquired at all.
            query_error.set()
        return {
            'available': not query_error.is_set(),
            'pool_size': self._postgres_pool.size,
            'pool_free': self._postgres_pool.freesize
        }

    async def _postgres_setup(self,
                              _app: web.Application,
                              loop: ioloop.IOLoop) -> None:
        """Setup the Postgres pool of connections and log if there is an error.

        This is invoked by the Application on start callback mechanism.
        Each pool option is read from a ``POSTGRES_*`` environment variable,
        falling back to the matching ``DEFAULT_POSTGRES_*`` constant.  If
        ``POSTGRES_URL`` is missing, the application is stopped instead.

        """
        if 'POSTGRES_URL' not in os.environ:
            LOGGER.critical('Missing POSTGRES_URL environment variable')
            return self.stop(loop)
        self._postgres_pool = pool.Pool(
            os.environ['POSTGRES_URL'],
            minsize=int(
                os.environ.get(
                    'POSTGRES_MIN_POOL_SIZE',
                    DEFAULT_POSTGRES_MIN_POOL_SIZE)),
            maxsize=int(
                os.environ.get(
                    'POSTGRES_MAX_POOL_SIZE',
                    DEFAULT_POSTGRES_MAX_POOL_SIZE)),
            timeout=int(
                os.environ.get(
                    'POSTGRES_CONNECT_TIMEOUT',
                    DEFAULT_POSTGRES_CONNECTION_TIMEOUT)),
            enable_hstore=util.strtobool(
                os.environ.get(
                    'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)),
            enable_json=util.strtobool(
                os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)),
            enable_uuid=util.strtobool(
                os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)),
            echo=False,
            on_connect=None,
            pool_recycle=int(
                os.environ.get(
                    'POSTGRES_CONNECTION_TTL',
                    DEFAULT_POSTGRES_CONNECTION_TTL)))
        try:
            # Eagerly fill the pool so connection problems surface at
            # startup.  NOTE(review): relies on aiopg private API
            # (``_cond`` / ``_fill_free_pool``).
            async with self._postgres_pool._cond:
                await self._postgres_pool._fill_free_pool(False)
        except (asyncio.TimeoutError,
                psycopg2.Error) as error:  # pragma: nocover
            LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
                           error)

    async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
        """Shutdown the Postgres connections and wait for them to close.

        This is invoked by the Application shutdown callback mechanism.
        It does nothing if the pool was never created, for example when
        ``POSTGRES_URL`` was missing at startup.

        """
        if self._postgres_pool is None:
            return
        self._postgres_pool.close()
        await self._postgres_pool.wait_closed()
|
|
|
|
|
|
class RequestHandlerMixin:
    """
    RequestHandler mixin class exposing functions for querying the database,
    recording the duration to either `sprockets-influxdb` or
    `sprockets.mixins.metrics`, and handling exceptions.

    Error handling and metric recording can be customised by overriding
    :meth:`_on_postgres_error` and :meth:`_on_postgres_timing`.

    """

    async def postgres_callproc(self,
                                name: str,
                                parameters: QueryParameters = None,
                                metric_name: str = '',
                                *,
                                timeout: Timeout = None) -> QueryResult:
        """Call the stored procedure ``name`` with optional ``parameters``.

        ``metric_name`` is used for timing and error reporting; ``timeout``
        overrides the default query timeout.

        """
        async with self._postgres_connector(timeout) as connector:
            return await connector.callproc(
                name, parameters, metric_name, timeout=timeout)

    async def postgres_execute(self,
                               sql: str,
                               parameters: QueryParameters = None,
                               metric_name: str = '',
                               *,
                               timeout: Timeout = None) -> QueryResult:
        """Execute a query, specifying a name for the query, the SQL statement,
        and optional positional arguments to pass in with the query.

        Parameters may be provided as sequence or mapping and will be
        bound to variables in the operation.  Variables are specified
        either with positional ``%s`` or named ``%({name})s`` placeholders.

        """
        async with self._postgres_connector(timeout) as connector:
            return await connector.execute(
                sql, parameters, metric_name, timeout=timeout)

    @contextlib.asynccontextmanager
    async def postgres_transaction(self, timeout: Timeout = None) \
            -> typing.AsyncIterator[PostgresConnector]:
        """Yields a :class:`PostgresConnector` instance in a transaction.

        Will automatically commit or rollback based upon exception.

        """
        async with self._postgres_connector(timeout) as connector:
            async with connector.transaction():
                yield connector

    @contextlib.asynccontextmanager
    async def _postgres_connector(self, timeout: Timeout = None) \
            -> typing.AsyncIterator[PostgresConnector]:
        """Obtain a connector from the application, wired to this handler's
        error and timing hooks."""
        async with self.application.postgres_connector(
                self._on_postgres_error,
                self._on_postgres_timing,
                timeout) as connector:
            yield connector

    def _on_postgres_error(self,
                           metric_name: str,
                           exc: Exception) -> typing.Optional[Exception]:
        """Override for different error handling behaviors.

        Logs the first line of the error, then maps it to an HTTP error:

        - :exc:`ConnectionException` becomes 503.
        - Timeouts become 500.
        - Unique violations become 409.
        - Any other ``psycopg2.Error`` becomes 500.

        Any other exception is returned unchanged so the caller re-raises it.

        """
        LOGGER.error('%s in %s for %s (%s)',
                     exc.__class__.__name__,
                     self.__class__.__name__,
                     metric_name,
                     str(exc).split('\n')[0])
        if isinstance(exc, ConnectionException):
            raise web.HTTPError(
                503, reason='Database Connection Error') from exc
        elif isinstance(exc, asyncio.TimeoutError):
            raise web.HTTPError(500, reason='Query Timeout') from exc
        elif isinstance(exc, errors.UniqueViolation):
            # Checked before psycopg2.Error, which it subclasses.
            raise web.HTTPError(409, reason='Unique Violation') from exc
        elif isinstance(exc, psycopg2.Error):
            raise web.HTTPError(500, reason='Database Error') from exc
        return exc

    def _on_postgres_timing(self,
                            metric_name: str,
                            duration: float) -> None:
        """Override for custom metric recording"""
        if hasattr(self, 'influxdb'):  # sprockets-influxdb
            self.influxdb.set_field(metric_name, duration)
        elif hasattr(self, 'record_timing'):  # sprockets.mixins.metrics
            self.record_timing(metric_name, duration)
        else:
            LOGGER.debug('Postgres query %s duration: %s',
                         metric_name, duration)

    # Aliases keeping the former name-mangled attributes
    # (``_RequestHandlerMixin__on_postgres_*``) available to existing callers.
    __on_postgres_error = _on_postgres_error
    __on_postgres_timing = _on_postgres_timing
|