mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2024-11-11 03:00:20 +00:00
WIP refactor
This commit is contained in:
parent
b9c9495545
commit
611dfd1ec7
2 changed files with 277 additions and 159 deletions
|
@ -1,24 +1,16 @@
|
||||||
"""
|
|
||||||
:class:`sprockets.http.app.Application` mixin for handling the connection to
|
|
||||||
Postgres and exporting functions for querying the database, getting the status,
|
|
||||||
and proving a cursor.
|
|
||||||
|
|
||||||
Automatically creates and shuts down :class:`aio.pool.Pool` on startup
|
|
||||||
and shutdown.
|
|
||||||
|
|
||||||
"""
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
import typing
|
import typing
|
||||||
from distutils import util
|
from distutils import util
|
||||||
|
|
||||||
import aiopg
|
import aiopg
|
||||||
import psycopg2
|
import psycopg2
|
||||||
from aiopg import pool
|
from aiopg import pool
|
||||||
from psycopg2 import errors, extensions, extras
|
from psycopg2 import errors, extras
|
||||||
from tornado import ioloop, web
|
from tornado import ioloop, web
|
||||||
|
|
||||||
LOGGER = logging.getLogger('sprockets-postgres')
|
LOGGER = logging.getLogger('sprockets-postgres')
|
||||||
|
@ -30,13 +22,115 @@ DEFAULT_POSTGRES_JSON = 'FALSE'
|
||||||
DEFAULT_POSTGRES_MAX_POOL_SIZE = 0
|
DEFAULT_POSTGRES_MAX_POOL_SIZE = 0
|
||||||
DEFAULT_POSTGRES_MIN_POOL_SIZE = 1
|
DEFAULT_POSTGRES_MIN_POOL_SIZE = 1
|
||||||
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
|
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
|
||||||
DEFAULT_POSTGRES_URL = 'postgresql://localhost:5432'
|
|
||||||
DEFAULT_POSTGRES_UUID = 'TRUE'
|
DEFAULT_POSTGRES_UUID = 'TRUE'
|
||||||
|
|
||||||
|
QueryParameters = typing.Union[list, tuple, None]
|
||||||
|
Timeout = typing.Union[int, float, None]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class QueryResult:
|
||||||
|
row_count: int
|
||||||
|
row: typing.Optional[dict]
|
||||||
|
rows: typing.Optional[typing.List[dict]]
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresConnector:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
cursor: aiopg.Cursor,
|
||||||
|
on_error: typing.Callable,
|
||||||
|
record_duration: typing.Optional[typing.Callable] = None,
|
||||||
|
timeout: Timeout = None):
|
||||||
|
self.cursor = cursor
|
||||||
|
self._on_error = on_error
|
||||||
|
self._record_duration = record_duration
|
||||||
|
self._timeout = timeout or int(
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_QUERY_TIMEOUT',
|
||||||
|
DEFAULT_POSTGRES_QUERY_TIMEOUT))
|
||||||
|
|
||||||
|
async def callproc(self,
|
||||||
|
name: str,
|
||||||
|
parameters: QueryParameters = None,
|
||||||
|
metric_name: str = '',
|
||||||
|
*,
|
||||||
|
timeout: Timeout = None) -> QueryResult:
|
||||||
|
return await self._query(
|
||||||
|
self.cursor.callproc,
|
||||||
|
metric_name,
|
||||||
|
procname=name,
|
||||||
|
parameters=parameters,
|
||||||
|
timeout=timeout)
|
||||||
|
|
||||||
|
async def execute(self,
|
||||||
|
sql: str,
|
||||||
|
parameters: QueryParameters = None,
|
||||||
|
metric_name: str = '',
|
||||||
|
*,
|
||||||
|
timeout: Timeout = None) -> QueryResult:
|
||||||
|
return await self._query(
|
||||||
|
self.cursor.execute,
|
||||||
|
metric_name,
|
||||||
|
operation=sql,
|
||||||
|
parameters=parameters,
|
||||||
|
timeout=timeout)
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def transaction(self) \
|
||||||
|
-> typing.AsyncContextManager['PostgresConnector']:
|
||||||
|
async with self.cursor.begin():
|
||||||
|
yield self
|
||||||
|
|
||||||
|
async def _query(self,
|
||||||
|
method: typing.Callable,
|
||||||
|
metric_name: str,
|
||||||
|
**kwargs):
|
||||||
|
if kwargs['timeout'] is None:
|
||||||
|
kwargs['timeout'] = self._timeout
|
||||||
|
start_time = time.monotonic()
|
||||||
|
try:
|
||||||
|
await method(**kwargs)
|
||||||
|
except (asyncio.TimeoutError, psycopg2.Error) as err:
|
||||||
|
LOGGER.error('Caught %r', err)
|
||||||
|
exc = self._on_error(metric_name, err)
|
||||||
|
if exc:
|
||||||
|
raise exc
|
||||||
|
finally:
|
||||||
|
if self._record_duration:
|
||||||
|
self._record_duration(
|
||||||
|
metric_name, time.monotonic() - start_time)
|
||||||
|
return await self._query_results()
|
||||||
|
|
||||||
|
async def _query_results(self) -> QueryResult:
|
||||||
|
row, rows = None, None
|
||||||
|
if self.cursor.rowcount == 1:
|
||||||
|
try:
|
||||||
|
row = dict(await self.cursor.fetchone())
|
||||||
|
except psycopg2.ProgrammingError:
|
||||||
|
pass
|
||||||
|
elif self.cursor.rowcount > 1:
|
||||||
|
try:
|
||||||
|
rows = [dict(row) for row in await self.cursor.fetchall()]
|
||||||
|
except psycopg2.ProgrammingError:
|
||||||
|
pass
|
||||||
|
return QueryResult(self.cursor.rowcount, row, rows)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionException(Exception):
|
||||||
|
"""Raised when the connection to Postgres can not be established"""
|
||||||
|
|
||||||
|
|
||||||
class ApplicationMixin:
|
class ApplicationMixin:
|
||||||
"""Application mixin for setting up the PostgreSQL client pool"""
|
"""
|
||||||
|
:class:`sprockets.http.app.Application` mixin for handling the connection
|
||||||
|
to Postgres and exporting functions for querying the database,
|
||||||
|
getting the status, and proving a cursor.
|
||||||
|
|
||||||
|
Automatically creates and shuts down :class:`aio.pool.Pool` on startup
|
||||||
|
and shutdown.
|
||||||
|
|
||||||
|
"""
|
||||||
POSTGRES_STATUS_TIMEOUT = 3
|
POSTGRES_STATUS_TIMEOUT = 3
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
@ -46,155 +140,57 @@ class ApplicationMixin:
|
||||||
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
|
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def postgres_cursor(self,
|
async def postgres_connector(self,
|
||||||
timeout: typing.Optional[int] = None,
|
on_error: typing.Callable,
|
||||||
raise_http_error: bool = True) \
|
record_duration: typing.Optional[
|
||||||
-> typing.AsyncContextManager[extensions.cursor]:
|
typing.Callable] = None,
|
||||||
"""Return a Postgres cursor for the pool"""
|
timeout: Timeout = None) \
|
||||||
|
-> typing.AsyncContextManager[PostgresConnector]:
|
||||||
try:
|
try:
|
||||||
async with self._postgres_pool.acquire() as conn:
|
async with self._postgres_pool.acquire() as conn:
|
||||||
async with conn.cursor(
|
async with conn.cursor(
|
||||||
cursor_factory=extras.RealDictCursor,
|
cursor_factory=extras.RealDictCursor,
|
||||||
timeout=self._postgres_query_timeout(timeout)) as pgc:
|
timeout=timeout) as cursor:
|
||||||
yield pgc
|
yield PostgresConnector(
|
||||||
except (asyncio.TimeoutError,
|
cursor, on_error, record_duration, timeout)
|
||||||
psycopg2.OperationalError,
|
except (asyncio.TimeoutError, psycopg2.Error) as err:
|
||||||
psycopg2.Error) as error:
|
on_error('postgres_connector', ConnectionException(str(err)))
|
||||||
LOGGER.critical('Error connecting to Postgres: %s', error)
|
|
||||||
if raise_http_error:
|
|
||||||
raise web.HTTPError(503, 'Database Unavailable')
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def postgres_callproc(self,
|
|
||||||
name: str,
|
|
||||||
params: typing.Union[list, tuple, None] = None,
|
|
||||||
timeout: typing.Optional[int] = None) \
|
|
||||||
-> typing.Union[dict, list, None]:
|
|
||||||
"""Execute a stored procedure, specifying the name, SQL, passing in
|
|
||||||
optional parameters.
|
|
||||||
|
|
||||||
:param name: The stored-proc to call
|
|
||||||
:param params: Optional parameters to pass into the function
|
|
||||||
:param timeout: Optional timeout to override the default query timeout
|
|
||||||
|
|
||||||
"""
|
|
||||||
async with self.postgres_cursor(timeout) as cursor:
|
|
||||||
return await self._postgres_query(
|
|
||||||
cursor, cursor.callproc, name, name, params)
|
|
||||||
|
|
||||||
async def postgres_execute(self, name: str, sql: str,
|
|
||||||
*args,
|
|
||||||
timeout: typing.Optional[int] = None) \
|
|
||||||
-> typing.Union[dict, list, None]:
|
|
||||||
"""Execute a query, specifying a name for the query, the SQL statement,
|
|
||||||
and optional positional arguments to pass in with the query.
|
|
||||||
|
|
||||||
Parameters may be provided as sequence or mapping and will be
|
|
||||||
bound to variables in the operation. Variables are specified
|
|
||||||
either with positional ``%s`` or named ``%({name})s`` placeholders.
|
|
||||||
|
|
||||||
:param name: The stored-proc to call
|
|
||||||
:param sql: The SQL statement to execute
|
|
||||||
:param timeout: Optional timeout to override the default query timeout
|
|
||||||
|
|
||||||
"""
|
|
||||||
async with self.postgres_cursor(timeout) as cursor:
|
|
||||||
return await self._postgres_query(
|
|
||||||
cursor, cursor.execute, name, sql, args)
|
|
||||||
|
|
||||||
async def postgres_status(self) -> dict:
|
async def postgres_status(self) -> dict:
|
||||||
"""Invoke from the ``/status`` RequestHandler to check that there is
|
"""Invoke from the ``/status`` RequestHandler to check that there is
|
||||||
a Postgres connection handler available and return info about the
|
a Postgres connection handler available and return info about the
|
||||||
pool.
|
pool.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
available = True
|
query_error = asyncio.Event()
|
||||||
try:
|
|
||||||
async with self.postgres_cursor(
|
def on_error(_metric_name, _exc) -> None:
|
||||||
self.POSTGRES_STATUS_TIMEOUT, False) as cursor:
|
query_error.set()
|
||||||
await cursor.execute('SELECT 1')
|
return None
|
||||||
except (asyncio.TimeoutError, psycopg2.OperationalError):
|
|
||||||
available = False
|
async with self.postgres_connector(
|
||||||
|
on_error,
|
||||||
|
timeout=self.POSTGRES_STATUS_TIMEOUT) as connector:
|
||||||
|
await connector.execute('SELECT 1')
|
||||||
return {
|
return {
|
||||||
'available': available,
|
'available': not query_error.is_set(),
|
||||||
'pool_size': self._postgres_pool.size,
|
'pool_size': self._postgres_pool.size,
|
||||||
'pool_free': self._postgres_pool.freesize
|
'pool_free': self._postgres_pool.freesize
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _postgres_query(self,
|
|
||||||
cursor: aiopg.Cursor,
|
|
||||||
method: typing.Callable,
|
|
||||||
name: str,
|
|
||||||
sql: str,
|
|
||||||
parameters: typing.Union[dict, list, tuple]) \
|
|
||||||
-> typing.Union[dict, list, None]:
|
|
||||||
"""Execute a query, specifying the name, SQL, passing in
|
|
||||||
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await method(sql, parameters)
|
|
||||||
except asyncio.TimeoutError as err:
|
|
||||||
LOGGER.error('Query timeout for %s: %s',
|
|
||||||
name, str(err).split('\n')[0])
|
|
||||||
raise web.HTTPError(500, reason='Query Timeout')
|
|
||||||
except errors.UniqueViolation as err:
|
|
||||||
LOGGER.error('Database error for %s: %s',
|
|
||||||
name, str(err).split('\n')[0])
|
|
||||||
raise web.HTTPError(409, reason='Unique Violation')
|
|
||||||
except psycopg2.Error as err:
|
|
||||||
LOGGER.error('Database error for %s: %s',
|
|
||||||
name, str(err).split('\n')[0])
|
|
||||||
raise web.HTTPError(500, reason='Database Error')
|
|
||||||
try:
|
|
||||||
return await self._postgres_query_results(cursor)
|
|
||||||
except psycopg2.ProgrammingError:
|
|
||||||
return
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _postgres_query_results(cursor: aiopg.Cursor) \
|
|
||||||
-> typing.Union[dict, list, None]:
|
|
||||||
"""Invoked by self.postgres_query to return all of the query results
|
|
||||||
as either a ``dict`` or ``list`` depending on the quantity of rows.
|
|
||||||
|
|
||||||
This can raise a ``psycopg2.ProgrammingError`` for an INSERT/UPDATE
|
|
||||||
without RETURNING or a DELETE. That exception is caught by the caller.
|
|
||||||
|
|
||||||
:raises psycopg2.ProgrammingError: when there are no rows to fetch
|
|
||||||
even though the rowcount is > 0
|
|
||||||
|
|
||||||
"""
|
|
||||||
if cursor.rowcount == 1:
|
|
||||||
return await cursor.fetchone()
|
|
||||||
elif cursor.rowcount > 1:
|
|
||||||
return await cursor.fetchall()
|
|
||||||
return None
|
|
||||||
|
|
||||||
@functools.lru_cache(2)
|
|
||||||
def _postgres_query_timeout(self,
|
|
||||||
timeout: typing.Optional[int] = None) -> int:
|
|
||||||
"""Return query timeout, either from the specified value or
|
|
||||||
``POSTGRES_QUERY_TIMEOUT`` environment variable, if set.
|
|
||||||
|
|
||||||
Defaults to sprockets_postgres.DEFAULT_POSTGRES_QUERY_TIMEOUT.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return timeout if timeout else int(
|
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_QUERY_TIMEOUT',
|
|
||||||
DEFAULT_POSTGRES_QUERY_TIMEOUT))
|
|
||||||
|
|
||||||
async def _postgres_setup(self,
|
async def _postgres_setup(self,
|
||||||
_app: web.Application,
|
_app: web.Application,
|
||||||
_ioloop: ioloop.IOLoop) -> None:
|
loop: ioloop.IOLoop) -> None:
|
||||||
"""Setup the Postgres pool of connections and log if there is an error.
|
"""Setup the Postgres pool of connections and log if there is an error.
|
||||||
|
|
||||||
This is invoked by the Application on start callback mechanism.
|
This is invoked by the Application on start callback mechanism.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
url = os.environ.get('POSTGRES_URL', DEFAULT_POSTGRES_URL)
|
if 'POSTGRES_URL' not in os.environ:
|
||||||
LOGGER.debug('Connecting to PostgreSQL: %s', url)
|
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
||||||
|
return self.stop(loop)
|
||||||
self._postgres_pool = pool.Pool(
|
self._postgres_pool = pool.Pool(
|
||||||
url,
|
os.environ['POSTGRES_URL'],
|
||||||
minsize=int(
|
minsize=int(
|
||||||
os.environ.get(
|
os.environ.get(
|
||||||
'POSTGRES_MIN_POOL_SIZE',
|
'POSTGRES_MIN_POOL_SIZE',
|
||||||
|
@ -236,3 +232,90 @@ class ApplicationMixin:
|
||||||
"""
|
"""
|
||||||
self._postgres_pool.close()
|
self._postgres_pool.close()
|
||||||
await self._postgres_pool.wait_closed()
|
await self._postgres_pool.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
|
class RequestHandlerMixin:
|
||||||
|
"""
|
||||||
|
RequestHandler mixin class exposing functions for querying the database,
|
||||||
|
recording the duration to either `sprockets-influxdb` or
|
||||||
|
`sprockets.mixins.metrics`, and handling exceptions.
|
||||||
|
|
||||||
|
"""
|
||||||
|
async def postgres_callproc(self,
|
||||||
|
name: str,
|
||||||
|
parameters: QueryParameters = None,
|
||||||
|
metric_name: str = '',
|
||||||
|
*,
|
||||||
|
timeout: Timeout = None) -> QueryResult:
|
||||||
|
async with self._postgres_connector(timeout) as connector:
|
||||||
|
return await connector.callproc(
|
||||||
|
name, parameters, metric_name, timeout=timeout)
|
||||||
|
|
||||||
|
async def postgres_execute(self,
|
||||||
|
sql: str,
|
||||||
|
parameters: QueryParameters = None,
|
||||||
|
metric_name: str = '',
|
||||||
|
*,
|
||||||
|
timeout: Timeout = None) -> QueryResult:
|
||||||
|
"""Execute a query, specifying a name for the query, the SQL statement,
|
||||||
|
and optional positional arguments to pass in with the query.
|
||||||
|
|
||||||
|
Parameters may be provided as sequence or mapping and will be
|
||||||
|
bound to variables in the operation. Variables are specified
|
||||||
|
either with positional ``%s`` or named ``%({name})s`` placeholders.
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self._postgres_connector(timeout) as connector:
|
||||||
|
return await connector.execute(
|
||||||
|
sql, parameters, metric_name, timeout=timeout)
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def postgres_transaction(self, timeout: Timeout = None) \
|
||||||
|
-> typing.AsyncContextManager[PostgresConnector]:
|
||||||
|
"""Yields a :class:`PostgresConnector` instance in a transaction.
|
||||||
|
Will automatically commit or rollback based upon exception.
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self._postgres_connector(timeout) as connector:
|
||||||
|
async with connector.transaction():
|
||||||
|
yield connector
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _postgres_connector(self, timeout: Timeout = None) \
|
||||||
|
-> typing.AsyncContextManager[PostgresConnector]:
|
||||||
|
async with self.application.postgres_connector(
|
||||||
|
self.__on_postgres_error,
|
||||||
|
self.__on_postgres_timing,
|
||||||
|
timeout) as connector:
|
||||||
|
yield connector
|
||||||
|
|
||||||
|
def __on_postgres_error(self,
|
||||||
|
metric_name: str,
|
||||||
|
exc: Exception) -> typing.Optional[Exception]:
|
||||||
|
"""Override for different error handling behaviors"""
|
||||||
|
LOGGER.error('%s in %s for %s (%s)',
|
||||||
|
exc.__class__.__name__,
|
||||||
|
self.__class__.__name__,
|
||||||
|
metric_name,
|
||||||
|
str(exc).split('\n')[0])
|
||||||
|
if isinstance(exc, ConnectionException):
|
||||||
|
raise web.HTTPError(503, reason='Database Connection Error')
|
||||||
|
elif isinstance(exc, asyncio.TimeoutError):
|
||||||
|
raise web.HTTPError(500, reason='Query Timeout')
|
||||||
|
elif isinstance(exc, errors.UniqueViolation):
|
||||||
|
raise web.HTTPError(409, reason='Unique Violation')
|
||||||
|
elif isinstance(exc, psycopg2.Error):
|
||||||
|
raise web.HTTPError(500, reason='Database Error')
|
||||||
|
return exc
|
||||||
|
|
||||||
|
def __on_postgres_timing(self,
|
||||||
|
metric_name: str,
|
||||||
|
duration: float) -> None:
|
||||||
|
"""Override for custom metric recording"""
|
||||||
|
if hasattr(self, 'influxdb'): # sprockets-influxdb
|
||||||
|
self.influxdb.set_field(metric_name, duration)
|
||||||
|
elif hasattr(self, 'record_timing'): # sprockets.mixins.metrics
|
||||||
|
self.record_timing(metric_name, duration)
|
||||||
|
else:
|
||||||
|
LOGGER.debug('Postgres query %s duration: %s',
|
||||||
|
metric_name, duration)
|
||||||
|
|
75
tests.py
75
tests.py
|
@ -12,40 +12,46 @@ from tornado import web
|
||||||
import sprockets_postgres
|
import sprockets_postgres
|
||||||
|
|
||||||
|
|
||||||
class CallprocRequestHandler(web.RequestHandler):
|
class CallprocRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||||
|
web.RequestHandler):
|
||||||
|
|
||||||
async def get(self):
|
async def get(self):
|
||||||
result = await self.application.postgres_callproc('uuid_generate_v4')
|
result = await self.postgres_callproc(
|
||||||
await self.finish({'value': str(result['uuid_generate_v4'])})
|
'uuid_generate_v4', metric_name='uuid')
|
||||||
|
await self.finish({'value': str(result.row['uuid_generate_v4'])})
|
||||||
|
|
||||||
|
|
||||||
class ExecuteRequestHandler(web.RequestHandler):
|
class ExecuteRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||||
|
web.RequestHandler):
|
||||||
|
|
||||||
GET_SQL = 'SELECT %s::TEXT AS value;'
|
GET_SQL = 'SELECT %s::TEXT AS value;'
|
||||||
|
|
||||||
async def get(self):
|
async def get(self):
|
||||||
result = await self.application.postgres_execute(
|
result = await self.postgres_execute(
|
||||||
'get', self.GET_SQL, self.get_argument('value'))
|
self.GET_SQL, [self.get_argument('value')], 'get')
|
||||||
await self.finish({'value': result['value'] if result else None})
|
await self.finish({
|
||||||
|
'value': result.row['value'] if result.row else None})
|
||||||
|
|
||||||
|
|
||||||
class MultiRowRequestHandler(web.RequestHandler):
|
class MultiRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||||
|
web.RequestHandler):
|
||||||
|
|
||||||
GET_SQL = 'SELECT * FROM information_schema.enabled_roles;'
|
GET_SQL = 'SELECT * FROM information_schema.enabled_roles;'
|
||||||
|
|
||||||
async def get(self):
|
async def get(self):
|
||||||
rows = await self.application.postgres_execute('get', self.GET_SQL)
|
result = await self.postgres_execute(self.GET_SQL)
|
||||||
await self.finish({'rows': [row['role_name'] for row in rows]})
|
await self.finish({'rows': [row['role_name'] for row in result.rows]})
|
||||||
|
|
||||||
|
|
||||||
class NoRowRequestHandler(web.RequestHandler):
|
class NoRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||||
|
web.RequestHandler):
|
||||||
|
|
||||||
GET_SQL = """\
|
GET_SQL = """\
|
||||||
SELECT * FROM information_schema.tables WHERE table_schema = 'public';"""
|
SELECT * FROM information_schema.tables WHERE table_schema = 'public';"""
|
||||||
|
|
||||||
async def get(self):
|
async def get(self):
|
||||||
rows = await self.application.postgres_execute('get', self.GET_SQL)
|
result = await self.postgres_execute(self.GET_SQL)
|
||||||
await self.finish({'rows': rows})
|
await self.finish({'rows': result.rows})
|
||||||
|
|
||||||
|
|
||||||
class StatusRequestHandler(web.RequestHandler):
|
class StatusRequestHandler(web.RequestHandler):
|
||||||
|
@ -62,7 +68,7 @@ class Application(sprockets_postgres.ApplicationMixin,
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ExecuteTestCase(testing.SprocketsHttpTestCase):
|
class TestCase(testing.SprocketsHttpTestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
@ -143,15 +149,44 @@ class ExecuteTestCase(testing.SprocketsHttpTestCase):
|
||||||
self.assertEqual(response.code, 500)
|
self.assertEqual(response.code, 500)
|
||||||
self.assertIn(b'Database Error', response.body)
|
self.assertIn(b'Database Error', response.body)
|
||||||
|
|
||||||
def test_postgres_programming_error(self):
|
@mock.patch('aiopg.cursor.Cursor.fetchone')
|
||||||
with mock.patch.object(self.app, '_postgres_query_results') as pqr:
|
def test_postgres_programming_error(self, fetchone):
|
||||||
pqr.side_effect = psycopg2.ProgrammingError()
|
fetchone.side_effect = psycopg2.ProgrammingError()
|
||||||
response = self.fetch('/execute?value=1')
|
response = self.fetch('/execute?value=1')
|
||||||
self.assertEqual(response.code, 200)
|
self.assertEqual(response.code, 200)
|
||||||
self.assertIsNone(json.loads(response.body)['value'])
|
self.assertIsNone(json.loads(response.body)['value'])
|
||||||
|
|
||||||
@mock.patch('aiopg.connection.Connection.cursor')
|
@mock.patch('aiopg.connection.Connection.cursor')
|
||||||
def test_postgres_cursor_raises(self, cursor):
|
def test_postgres_cursor_raises(self, cursor):
|
||||||
cursor.side_effect = asyncio.TimeoutError()
|
cursor.side_effect = asyncio.TimeoutError()
|
||||||
response = self.fetch('/execute?value=1')
|
response = self.fetch('/execute?value=1')
|
||||||
self.assertEqual(response.code, 503)
|
self.assertEqual(response.code, 503)
|
||||||
|
|
||||||
|
"""
|
||||||
|
class MissingURLTestCase(testing.SprocketsHttpTestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
with open('build/test-environment') as f:
|
||||||
|
for line in f:
|
||||||
|
if line.startswith('export '):
|
||||||
|
line = line[7:]
|
||||||
|
name, _, value = line.strip().partition('=')
|
||||||
|
if name != 'POSTGRES_URL':
|
||||||
|
os.environ[name] = value
|
||||||
|
if 'POSTGRES_URL' in os.environ:
|
||||||
|
del os.environ['POSTGRES_URL']
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.stop_mock = None
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
def get_app(self):
|
||||||
|
self.app = Application()
|
||||||
|
self.stop_mock = mock.Mock(
|
||||||
|
wraps=self.app.stop, side_effect=RuntimeError)
|
||||||
|
return self.app
|
||||||
|
|
||||||
|
def test_that_stop_is_invoked(self):
|
||||||
|
self.stop_mock.assert_called_once_with(self.io_loop)
|
||||||
|
"""
|
||||||
|
|
Loading…
Reference in a new issue