mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2024-12-27 11:17:30 +00:00
WIP refactor
This commit is contained in:
parent
b9c9495545
commit
611dfd1ec7
2 changed files with 277 additions and 159 deletions
|
@ -1,24 +1,16 @@
|
|||
"""
|
||||
:class:`sprockets.http.app.Application` mixin for handling the connection to
|
||||
Postgres and exporting functions for querying the database, getting the status,
|
||||
and proving a cursor.
|
||||
|
||||
Automatically creates and shuts down :class:`aio.pool.Pool` on startup
|
||||
and shutdown.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import contextlib
|
||||
import functools
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import typing
|
||||
from distutils import util
|
||||
|
||||
import aiopg
|
||||
import psycopg2
|
||||
from aiopg import pool
|
||||
from psycopg2 import errors, extensions, extras
|
||||
from psycopg2 import errors, extras
|
||||
from tornado import ioloop, web
|
||||
|
||||
LOGGER = logging.getLogger('sprockets-postgres')
|
||||
|
@ -30,13 +22,115 @@ DEFAULT_POSTGRES_JSON = 'FALSE'
|
|||
DEFAULT_POSTGRES_MAX_POOL_SIZE = 0
|
||||
DEFAULT_POSTGRES_MIN_POOL_SIZE = 1
|
||||
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
|
||||
DEFAULT_POSTGRES_URL = 'postgresql://localhost:5432'
|
||||
DEFAULT_POSTGRES_UUID = 'TRUE'
|
||||
|
||||
QueryParameters = typing.Union[list, tuple, None]
|
||||
Timeout = typing.Union[int, float, None]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class QueryResult:
|
||||
row_count: int
|
||||
row: typing.Optional[dict]
|
||||
rows: typing.Optional[typing.List[dict]]
|
||||
|
||||
|
||||
class PostgresConnector:
|
||||
|
||||
def __init__(self,
|
||||
cursor: aiopg.Cursor,
|
||||
on_error: typing.Callable,
|
||||
record_duration: typing.Optional[typing.Callable] = None,
|
||||
timeout: Timeout = None):
|
||||
self.cursor = cursor
|
||||
self._on_error = on_error
|
||||
self._record_duration = record_duration
|
||||
self._timeout = timeout or int(
|
||||
os.environ.get(
|
||||
'POSTGRES_QUERY_TIMEOUT',
|
||||
DEFAULT_POSTGRES_QUERY_TIMEOUT))
|
||||
|
||||
async def callproc(self,
|
||||
name: str,
|
||||
parameters: QueryParameters = None,
|
||||
metric_name: str = '',
|
||||
*,
|
||||
timeout: Timeout = None) -> QueryResult:
|
||||
return await self._query(
|
||||
self.cursor.callproc,
|
||||
metric_name,
|
||||
procname=name,
|
||||
parameters=parameters,
|
||||
timeout=timeout)
|
||||
|
||||
async def execute(self,
|
||||
sql: str,
|
||||
parameters: QueryParameters = None,
|
||||
metric_name: str = '',
|
||||
*,
|
||||
timeout: Timeout = None) -> QueryResult:
|
||||
return await self._query(
|
||||
self.cursor.execute,
|
||||
metric_name,
|
||||
operation=sql,
|
||||
parameters=parameters,
|
||||
timeout=timeout)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def transaction(self) \
|
||||
-> typing.AsyncContextManager['PostgresConnector']:
|
||||
async with self.cursor.begin():
|
||||
yield self
|
||||
|
||||
async def _query(self,
|
||||
method: typing.Callable,
|
||||
metric_name: str,
|
||||
**kwargs):
|
||||
if kwargs['timeout'] is None:
|
||||
kwargs['timeout'] = self._timeout
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
await method(**kwargs)
|
||||
except (asyncio.TimeoutError, psycopg2.Error) as err:
|
||||
LOGGER.error('Caught %r', err)
|
||||
exc = self._on_error(metric_name, err)
|
||||
if exc:
|
||||
raise exc
|
||||
finally:
|
||||
if self._record_duration:
|
||||
self._record_duration(
|
||||
metric_name, time.monotonic() - start_time)
|
||||
return await self._query_results()
|
||||
|
||||
async def _query_results(self) -> QueryResult:
|
||||
row, rows = None, None
|
||||
if self.cursor.rowcount == 1:
|
||||
try:
|
||||
row = dict(await self.cursor.fetchone())
|
||||
except psycopg2.ProgrammingError:
|
||||
pass
|
||||
elif self.cursor.rowcount > 1:
|
||||
try:
|
||||
rows = [dict(row) for row in await self.cursor.fetchall()]
|
||||
except psycopg2.ProgrammingError:
|
||||
pass
|
||||
return QueryResult(self.cursor.rowcount, row, rows)
|
||||
|
||||
|
||||
class ConnectionException(Exception):
|
||||
"""Raised when the connection to Postgres can not be established"""
|
||||
|
||||
|
||||
class ApplicationMixin:
|
||||
"""Application mixin for setting up the PostgreSQL client pool"""
|
||||
"""
|
||||
:class:`sprockets.http.app.Application` mixin for handling the connection
|
||||
to Postgres and exporting functions for querying the database,
|
||||
getting the status, and proving a cursor.
|
||||
|
||||
Automatically creates and shuts down :class:`aio.pool.Pool` on startup
|
||||
and shutdown.
|
||||
|
||||
"""
|
||||
POSTGRES_STATUS_TIMEOUT = 3
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -46,155 +140,57 @@ class ApplicationMixin:
|
|||
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def postgres_cursor(self,
|
||||
timeout: typing.Optional[int] = None,
|
||||
raise_http_error: bool = True) \
|
||||
-> typing.AsyncContextManager[extensions.cursor]:
|
||||
"""Return a Postgres cursor for the pool"""
|
||||
async def postgres_connector(self,
|
||||
on_error: typing.Callable,
|
||||
record_duration: typing.Optional[
|
||||
typing.Callable] = None,
|
||||
timeout: Timeout = None) \
|
||||
-> typing.AsyncContextManager[PostgresConnector]:
|
||||
try:
|
||||
async with self._postgres_pool.acquire() as conn:
|
||||
async with conn.cursor(
|
||||
cursor_factory=extras.RealDictCursor,
|
||||
timeout=self._postgres_query_timeout(timeout)) as pgc:
|
||||
yield pgc
|
||||
except (asyncio.TimeoutError,
|
||||
psycopg2.OperationalError,
|
||||
psycopg2.Error) as error:
|
||||
LOGGER.critical('Error connecting to Postgres: %s', error)
|
||||
if raise_http_error:
|
||||
raise web.HTTPError(503, 'Database Unavailable')
|
||||
raise
|
||||
|
||||
async def postgres_callproc(self,
|
||||
name: str,
|
||||
params: typing.Union[list, tuple, None] = None,
|
||||
timeout: typing.Optional[int] = None) \
|
||||
-> typing.Union[dict, list, None]:
|
||||
"""Execute a stored procedure, specifying the name, SQL, passing in
|
||||
optional parameters.
|
||||
|
||||
:param name: The stored-proc to call
|
||||
:param params: Optional parameters to pass into the function
|
||||
:param timeout: Optional timeout to override the default query timeout
|
||||
|
||||
"""
|
||||
async with self.postgres_cursor(timeout) as cursor:
|
||||
return await self._postgres_query(
|
||||
cursor, cursor.callproc, name, name, params)
|
||||
|
||||
async def postgres_execute(self, name: str, sql: str,
|
||||
*args,
|
||||
timeout: typing.Optional[int] = None) \
|
||||
-> typing.Union[dict, list, None]:
|
||||
"""Execute a query, specifying a name for the query, the SQL statement,
|
||||
and optional positional arguments to pass in with the query.
|
||||
|
||||
Parameters may be provided as sequence or mapping and will be
|
||||
bound to variables in the operation. Variables are specified
|
||||
either with positional ``%s`` or named ``%({name})s`` placeholders.
|
||||
|
||||
:param name: The stored-proc to call
|
||||
:param sql: The SQL statement to execute
|
||||
:param timeout: Optional timeout to override the default query timeout
|
||||
|
||||
"""
|
||||
async with self.postgres_cursor(timeout) as cursor:
|
||||
return await self._postgres_query(
|
||||
cursor, cursor.execute, name, sql, args)
|
||||
timeout=timeout) as cursor:
|
||||
yield PostgresConnector(
|
||||
cursor, on_error, record_duration, timeout)
|
||||
except (asyncio.TimeoutError, psycopg2.Error) as err:
|
||||
on_error('postgres_connector', ConnectionException(str(err)))
|
||||
|
||||
async def postgres_status(self) -> dict:
|
||||
"""Invoke from the ``/status`` RequestHandler to check that there is
|
||||
a Postgres connection handler available and return info about the
|
||||
pool.
|
||||
|
||||
"""
|
||||
available = True
|
||||
try:
|
||||
async with self.postgres_cursor(
|
||||
self.POSTGRES_STATUS_TIMEOUT, False) as cursor:
|
||||
await cursor.execute('SELECT 1')
|
||||
except (asyncio.TimeoutError, psycopg2.OperationalError):
|
||||
available = False
|
||||
"""
|
||||
query_error = asyncio.Event()
|
||||
|
||||
def on_error(_metric_name, _exc) -> None:
|
||||
query_error.set()
|
||||
return None
|
||||
|
||||
async with self.postgres_connector(
|
||||
on_error,
|
||||
timeout=self.POSTGRES_STATUS_TIMEOUT) as connector:
|
||||
await connector.execute('SELECT 1')
|
||||
return {
|
||||
'available': available,
|
||||
'available': not query_error.is_set(),
|
||||
'pool_size': self._postgres_pool.size,
|
||||
'pool_free': self._postgres_pool.freesize
|
||||
}
|
||||
|
||||
async def _postgres_query(self,
|
||||
cursor: aiopg.Cursor,
|
||||
method: typing.Callable,
|
||||
name: str,
|
||||
sql: str,
|
||||
parameters: typing.Union[dict, list, tuple]) \
|
||||
-> typing.Union[dict, list, None]:
|
||||
"""Execute a query, specifying the name, SQL, passing in
|
||||
|
||||
"""
|
||||
try:
|
||||
await method(sql, parameters)
|
||||
except asyncio.TimeoutError as err:
|
||||
LOGGER.error('Query timeout for %s: %s',
|
||||
name, str(err).split('\n')[0])
|
||||
raise web.HTTPError(500, reason='Query Timeout')
|
||||
except errors.UniqueViolation as err:
|
||||
LOGGER.error('Database error for %s: %s',
|
||||
name, str(err).split('\n')[0])
|
||||
raise web.HTTPError(409, reason='Unique Violation')
|
||||
except psycopg2.Error as err:
|
||||
LOGGER.error('Database error for %s: %s',
|
||||
name, str(err).split('\n')[0])
|
||||
raise web.HTTPError(500, reason='Database Error')
|
||||
try:
|
||||
return await self._postgres_query_results(cursor)
|
||||
except psycopg2.ProgrammingError:
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
async def _postgres_query_results(cursor: aiopg.Cursor) \
|
||||
-> typing.Union[dict, list, None]:
|
||||
"""Invoked by self.postgres_query to return all of the query results
|
||||
as either a ``dict`` or ``list`` depending on the quantity of rows.
|
||||
|
||||
This can raise a ``psycopg2.ProgrammingError`` for an INSERT/UPDATE
|
||||
without RETURNING or a DELETE. That exception is caught by the caller.
|
||||
|
||||
:raises psycopg2.ProgrammingError: when there are no rows to fetch
|
||||
even though the rowcount is > 0
|
||||
|
||||
"""
|
||||
if cursor.rowcount == 1:
|
||||
return await cursor.fetchone()
|
||||
elif cursor.rowcount > 1:
|
||||
return await cursor.fetchall()
|
||||
return None
|
||||
|
||||
@functools.lru_cache(2)
|
||||
def _postgres_query_timeout(self,
|
||||
timeout: typing.Optional[int] = None) -> int:
|
||||
"""Return query timeout, either from the specified value or
|
||||
``POSTGRES_QUERY_TIMEOUT`` environment variable, if set.
|
||||
|
||||
Defaults to sprockets_postgres.DEFAULT_POSTGRES_QUERY_TIMEOUT.
|
||||
|
||||
"""
|
||||
return timeout if timeout else int(
|
||||
os.environ.get(
|
||||
'POSTGRES_QUERY_TIMEOUT',
|
||||
DEFAULT_POSTGRES_QUERY_TIMEOUT))
|
||||
|
||||
async def _postgres_setup(self,
|
||||
_app: web.Application,
|
||||
_ioloop: ioloop.IOLoop) -> None:
|
||||
loop: ioloop.IOLoop) -> None:
|
||||
"""Setup the Postgres pool of connections and log if there is an error.
|
||||
|
||||
This is invoked by the Application on start callback mechanism.
|
||||
|
||||
"""
|
||||
url = os.environ.get('POSTGRES_URL', DEFAULT_POSTGRES_URL)
|
||||
LOGGER.debug('Connecting to PostgreSQL: %s', url)
|
||||
if 'POSTGRES_URL' not in os.environ:
|
||||
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
||||
return self.stop(loop)
|
||||
self._postgres_pool = pool.Pool(
|
||||
url,
|
||||
os.environ['POSTGRES_URL'],
|
||||
minsize=int(
|
||||
os.environ.get(
|
||||
'POSTGRES_MIN_POOL_SIZE',
|
||||
|
@ -236,3 +232,90 @@ class ApplicationMixin:
|
|||
"""
|
||||
self._postgres_pool.close()
|
||||
await self._postgres_pool.wait_closed()
|
||||
|
||||
|
||||
class RequestHandlerMixin:
|
||||
"""
|
||||
RequestHandler mixin class exposing functions for querying the database,
|
||||
recording the duration to either `sprockets-influxdb` or
|
||||
`sprockets.mixins.metrics`, and handling exceptions.
|
||||
|
||||
"""
|
||||
async def postgres_callproc(self,
|
||||
name: str,
|
||||
parameters: QueryParameters = None,
|
||||
metric_name: str = '',
|
||||
*,
|
||||
timeout: Timeout = None) -> QueryResult:
|
||||
async with self._postgres_connector(timeout) as connector:
|
||||
return await connector.callproc(
|
||||
name, parameters, metric_name, timeout=timeout)
|
||||
|
||||
async def postgres_execute(self,
|
||||
sql: str,
|
||||
parameters: QueryParameters = None,
|
||||
metric_name: str = '',
|
||||
*,
|
||||
timeout: Timeout = None) -> QueryResult:
|
||||
"""Execute a query, specifying a name for the query, the SQL statement,
|
||||
and optional positional arguments to pass in with the query.
|
||||
|
||||
Parameters may be provided as sequence or mapping and will be
|
||||
bound to variables in the operation. Variables are specified
|
||||
either with positional ``%s`` or named ``%({name})s`` placeholders.
|
||||
|
||||
"""
|
||||
async with self._postgres_connector(timeout) as connector:
|
||||
return await connector.execute(
|
||||
sql, parameters, metric_name, timeout=timeout)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def postgres_transaction(self, timeout: Timeout = None) \
|
||||
-> typing.AsyncContextManager[PostgresConnector]:
|
||||
"""Yields a :class:`PostgresConnector` instance in a transaction.
|
||||
Will automatically commit or rollback based upon exception.
|
||||
|
||||
"""
|
||||
async with self._postgres_connector(timeout) as connector:
|
||||
async with connector.transaction():
|
||||
yield connector
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _postgres_connector(self, timeout: Timeout = None) \
|
||||
-> typing.AsyncContextManager[PostgresConnector]:
|
||||
async with self.application.postgres_connector(
|
||||
self.__on_postgres_error,
|
||||
self.__on_postgres_timing,
|
||||
timeout) as connector:
|
||||
yield connector
|
||||
|
||||
def __on_postgres_error(self,
|
||||
metric_name: str,
|
||||
exc: Exception) -> typing.Optional[Exception]:
|
||||
"""Override for different error handling behaviors"""
|
||||
LOGGER.error('%s in %s for %s (%s)',
|
||||
exc.__class__.__name__,
|
||||
self.__class__.__name__,
|
||||
metric_name,
|
||||
str(exc).split('\n')[0])
|
||||
if isinstance(exc, ConnectionException):
|
||||
raise web.HTTPError(503, reason='Database Connection Error')
|
||||
elif isinstance(exc, asyncio.TimeoutError):
|
||||
raise web.HTTPError(500, reason='Query Timeout')
|
||||
elif isinstance(exc, errors.UniqueViolation):
|
||||
raise web.HTTPError(409, reason='Unique Violation')
|
||||
elif isinstance(exc, psycopg2.Error):
|
||||
raise web.HTTPError(500, reason='Database Error')
|
||||
return exc
|
||||
|
||||
def __on_postgres_timing(self,
|
||||
metric_name: str,
|
||||
duration: float) -> None:
|
||||
"""Override for custom metric recording"""
|
||||
if hasattr(self, 'influxdb'): # sprockets-influxdb
|
||||
self.influxdb.set_field(metric_name, duration)
|
||||
elif hasattr(self, 'record_timing'): # sprockets.mixins.metrics
|
||||
self.record_timing(metric_name, duration)
|
||||
else:
|
||||
LOGGER.debug('Postgres query %s duration: %s',
|
||||
metric_name, duration)
|
||||
|
|
75
tests.py
75
tests.py
|
@ -12,40 +12,46 @@ from tornado import web
|
|||
import sprockets_postgres
|
||||
|
||||
|
||||
class CallprocRequestHandler(web.RequestHandler):
|
||||
class CallprocRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
|
||||
async def get(self):
|
||||
result = await self.application.postgres_callproc('uuid_generate_v4')
|
||||
await self.finish({'value': str(result['uuid_generate_v4'])})
|
||||
result = await self.postgres_callproc(
|
||||
'uuid_generate_v4', metric_name='uuid')
|
||||
await self.finish({'value': str(result.row['uuid_generate_v4'])})
|
||||
|
||||
|
||||
class ExecuteRequestHandler(web.RequestHandler):
|
||||
class ExecuteRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
|
||||
GET_SQL = 'SELECT %s::TEXT AS value;'
|
||||
|
||||
async def get(self):
|
||||
result = await self.application.postgres_execute(
|
||||
'get', self.GET_SQL, self.get_argument('value'))
|
||||
await self.finish({'value': result['value'] if result else None})
|
||||
result = await self.postgres_execute(
|
||||
self.GET_SQL, [self.get_argument('value')], 'get')
|
||||
await self.finish({
|
||||
'value': result.row['value'] if result.row else None})
|
||||
|
||||
|
||||
class MultiRowRequestHandler(web.RequestHandler):
|
||||
class MultiRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
|
||||
GET_SQL = 'SELECT * FROM information_schema.enabled_roles;'
|
||||
|
||||
async def get(self):
|
||||
rows = await self.application.postgres_execute('get', self.GET_SQL)
|
||||
await self.finish({'rows': [row['role_name'] for row in rows]})
|
||||
result = await self.postgres_execute(self.GET_SQL)
|
||||
await self.finish({'rows': [row['role_name'] for row in result.rows]})
|
||||
|
||||
|
||||
class NoRowRequestHandler(web.RequestHandler):
|
||||
class NoRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
|
||||
GET_SQL = """\
|
||||
SELECT * FROM information_schema.tables WHERE table_schema = 'public';"""
|
||||
|
||||
async def get(self):
|
||||
rows = await self.application.postgres_execute('get', self.GET_SQL)
|
||||
await self.finish({'rows': rows})
|
||||
result = await self.postgres_execute(self.GET_SQL)
|
||||
await self.finish({'rows': result.rows})
|
||||
|
||||
|
||||
class StatusRequestHandler(web.RequestHandler):
|
||||
|
@ -62,7 +68,7 @@ class Application(sprockets_postgres.ApplicationMixin,
|
|||
pass
|
||||
|
||||
|
||||
class ExecuteTestCase(testing.SprocketsHttpTestCase):
|
||||
class TestCase(testing.SprocketsHttpTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
@ -143,15 +149,44 @@ class ExecuteTestCase(testing.SprocketsHttpTestCase):
|
|||
self.assertEqual(response.code, 500)
|
||||
self.assertIn(b'Database Error', response.body)
|
||||
|
||||
def test_postgres_programming_error(self):
|
||||
with mock.patch.object(self.app, '_postgres_query_results') as pqr:
|
||||
pqr.side_effect = psycopg2.ProgrammingError()
|
||||
response = self.fetch('/execute?value=1')
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertIsNone(json.loads(response.body)['value'])
|
||||
@mock.patch('aiopg.cursor.Cursor.fetchone')
|
||||
def test_postgres_programming_error(self, fetchone):
|
||||
fetchone.side_effect = psycopg2.ProgrammingError()
|
||||
response = self.fetch('/execute?value=1')
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertIsNone(json.loads(response.body)['value'])
|
||||
|
||||
@mock.patch('aiopg.connection.Connection.cursor')
|
||||
def test_postgres_cursor_raises(self, cursor):
|
||||
cursor.side_effect = asyncio.TimeoutError()
|
||||
response = self.fetch('/execute?value=1')
|
||||
self.assertEqual(response.code, 503)
|
||||
|
||||
"""
|
||||
class MissingURLTestCase(testing.SprocketsHttpTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open('build/test-environment') as f:
|
||||
for line in f:
|
||||
if line.startswith('export '):
|
||||
line = line[7:]
|
||||
name, _, value = line.strip().partition('=')
|
||||
if name != 'POSTGRES_URL':
|
||||
os.environ[name] = value
|
||||
if 'POSTGRES_URL' in os.environ:
|
||||
del os.environ['POSTGRES_URL']
|
||||
|
||||
def setUp(self):
|
||||
self.stop_mock = None
|
||||
super().setUp()
|
||||
|
||||
def get_app(self):
|
||||
self.app = Application()
|
||||
self.stop_mock = mock.Mock(
|
||||
wraps=self.app.stop, side_effect=RuntimeError)
|
||||
return self.app
|
||||
|
||||
def test_that_stop_is_invoked(self):
|
||||
self.stop_mock.assert_called_once_with(self.io_loop)
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue