Testing update and fixes found in testing

This commit is contained in:
Gavin M. Roy 2020-04-07 16:59:06 -04:00
parent 611dfd1ec7
commit 603eb4d6dd
6 changed files with 290 additions and 50 deletions

View file

@ -48,7 +48,7 @@ jobs:
- name: Install UUID extension - name: Install UUID extension
run: | run: |
apk add postgresql-client \ apk add postgresql-client \
&& psql -q -h postgres -U postgres -d postgres -c 'CREATE EXTENSION "uuid-ossp";' && psql -q -h postgres -U postgres -d postgres -f fixtures/testing.sql
- name: Run flake8 tests - name: Run flake8 tests
run: flake8 run: flake8

View file

@ -65,7 +65,9 @@ docker-compose up -d --quiet-pull
wait_for_healthy_containers 1 wait_for_healthy_containers 1
docker-compose exec postgres psql -q -o /dev/null -U postgres -d postgres -c 'CREATE EXTENSION "uuid-ossp";' printf "Loading fixture data ... "
docker-compose exec postgres psql -q -o /dev/null -U postgres -d postgres -f /fixtures/testing.sql
report_done
cat > build/test-environment<<EOF cat > build/test-environment<<EOF
export ASYNC_TEST_TIMEOUT=5 export ASYNC_TEST_TIMEOUT=5

View file

@ -11,3 +11,7 @@ services:
retries: 3 retries: 3
ports: ports:
- 5432 - 5432
volumes:
- type: bind
source: ./fixtures
target: /fixtures

30
fixtures/testing.sql Normal file
View file

@ -0,0 +1,30 @@
CREATE EXTENSION "uuid-ossp";
CREATE TABLE public.test (
id UUID NOT NULL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_modified_at TIMESTAMP WITH TIME ZONE,
value TEXT NOT NULL
);
CREATE TABLE public.query_count(
key TEXT NOT NULL PRIMARY KEY,
last_updated_at TIMESTAMP WITH TIME ZONE,
count INTEGER
);
INSERT INTO public.query_count (key, last_updated_at, count)
VALUES ('test', CURRENT_TIMESTAMP, 0);
CREATE TABLE public.test_rows (
id INTEGER NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_modified_at TIMESTAMP WITH TIME ZONE,
toggle BOOLEAN NOT NULL
);
INSERT INTO public.test_rows (toggle) VALUES (FALSE);
INSERT INTO public.test_rows (toggle) VALUES (FALSE);
INSERT INTO public.test_rows (toggle) VALUES (FALSE);
INSERT INTO public.test_rows (toggle) VALUES (FALSE);
INSERT INTO public.test_rows (toggle) VALUES (FALSE);

View file

@ -24,7 +24,7 @@ DEFAULT_POSTGRES_MIN_POOL_SIZE = 1
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120 DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
DEFAULT_POSTGRES_UUID = 'TRUE' DEFAULT_POSTGRES_UUID = 'TRUE'
QueryParameters = typing.Union[list, tuple, None] QueryParameters = typing.Union[dict, list, tuple, None]
Timeout = typing.Union[int, float, None] Timeout = typing.Union[int, float, None]
@ -92,18 +92,17 @@ class PostgresConnector:
try: try:
await method(**kwargs) await method(**kwargs)
except (asyncio.TimeoutError, psycopg2.Error) as err: except (asyncio.TimeoutError, psycopg2.Error) as err:
LOGGER.error('Caught %r', err)
exc = self._on_error(metric_name, err) exc = self._on_error(metric_name, err)
if exc: if exc:
raise exc raise exc
finally: else:
if self._record_duration: if self._record_duration:
self._record_duration( self._record_duration(
metric_name, time.monotonic() - start_time) metric_name, time.monotonic() - start_time)
return await self._query_results() return await self._query_results()
async def _query_results(self) -> QueryResult: async def _query_results(self) -> QueryResult:
row, rows = None, None count, row, rows = self.cursor.rowcount, None, None
if self.cursor.rowcount == 1: if self.cursor.rowcount == 1:
try: try:
row = dict(await self.cursor.fetchone()) row = dict(await self.cursor.fetchone())
@ -114,7 +113,7 @@ class PostgresConnector:
rows = [dict(row) for row in await self.cursor.fetchall()] rows = [dict(row) for row in await self.cursor.fetchall()]
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
pass pass
return QueryResult(self.cursor.rowcount, row, rows) return QueryResult(count, row, rows)
class ConnectionException(Exception): class ConnectionException(Exception):
@ -247,7 +246,10 @@ class RequestHandlerMixin:
metric_name: str = '', metric_name: str = '',
*, *,
timeout: Timeout = None) -> QueryResult: timeout: Timeout = None) -> QueryResult:
async with self._postgres_connector(timeout) as connector: async with self.application.postgres_connector(
self._on_postgres_error,
self._on_postgres_timing,
timeout) as connector:
return await connector.callproc( return await connector.callproc(
name, parameters, metric_name, timeout=timeout) name, parameters, metric_name, timeout=timeout)
@ -265,7 +267,10 @@ class RequestHandlerMixin:
either with positional ``%s`` or named ``%({name})s`` placeholders. either with positional ``%s`` or named ``%({name})s`` placeholders.
""" """
async with self._postgres_connector(timeout) as connector: async with self.application.postgres_connector(
self._on_postgres_error,
self._on_postgres_timing,
timeout) as connector:
return await connector.execute( return await connector.execute(
sql, parameters, metric_name, timeout=timeout) sql, parameters, metric_name, timeout=timeout)
@ -276,28 +281,20 @@ class RequestHandlerMixin:
Will automatically commit or rollback based upon exception. Will automatically commit or rollback based upon exception.
""" """
async with self._postgres_connector(timeout) as connector: async with self.application.postgres_connector(
self._on_postgres_error,
self._on_postgres_timing,
timeout) as connector:
async with connector.transaction(): async with connector.transaction():
yield connector yield connector
@contextlib.asynccontextmanager def _on_postgres_error(self,
async def _postgres_connector(self, timeout: Timeout = None) \ metric_name: str,
-> typing.AsyncContextManager[PostgresConnector]: exc: Exception) -> typing.Optional[Exception]:
async with self.application.postgres_connector(
self.__on_postgres_error,
self.__on_postgres_timing,
timeout) as connector:
yield connector
def __on_postgres_error(self,
metric_name: str,
exc: Exception) -> typing.Optional[Exception]:
"""Override for different error handling behaviors""" """Override for different error handling behaviors"""
LOGGER.error('%s in %s for %s (%s)', LOGGER.error('%s in %s for %s (%s)',
exc.__class__.__name__, exc.__class__.__name__, self.__class__.__name__,
self.__class__.__name__, metric_name, str(exc).split('\n')[0])
metric_name,
str(exc).split('\n')[0])
if isinstance(exc, ConnectionException): if isinstance(exc, ConnectionException):
raise web.HTTPError(503, reason='Database Connection Error') raise web.HTTPError(503, reason='Database Connection Error')
elif isinstance(exc, asyncio.TimeoutError): elif isinstance(exc, asyncio.TimeoutError):
@ -308,9 +305,9 @@ class RequestHandlerMixin:
raise web.HTTPError(500, reason='Database Error') raise web.HTTPError(500, reason='Database Error')
return exc return exc
def __on_postgres_timing(self, def _on_postgres_timing(self,
metric_name: str, metric_name: str,
duration: float) -> None: duration: float) -> None:
"""Override for custom metric recording""" """Override for custom metric recording"""
if hasattr(self, 'influxdb'): # sprockets-influxdb if hasattr(self, 'influxdb'): # sprockets-influxdb
self.influxdb.set_field(metric_name, duration) self.influxdb.set_field(metric_name, duration)

249
tests.py
View file

@ -1,6 +1,7 @@
import asyncio import asyncio
import json import json
import os import os
import typing
import uuid import uuid
from unittest import mock from unittest import mock
@ -12,8 +13,20 @@ from tornado import web
import sprockets_postgres import sprockets_postgres
class CallprocRequestHandler(sprockets_postgres.RequestHandlerMixin, class RequestHandler(sprockets_postgres.RequestHandlerMixin,
web.RequestHandler): web.RequestHandler):
"""Base RequestHandler for test endpoints"""
def cast_data(self, data: typing.Union[dict, list, None]) \
-> typing.Union[dict, list, None]:
if data is None:
return None
elif isinstance(data, list):
return [self.cast_data(row) for row in data]
return {k: str(v) for k, v in data.items()}
class CallprocRequestHandler(RequestHandler):
async def get(self): async def get(self):
result = await self.postgres_callproc( result = await self.postgres_callproc(
@ -21,51 +34,164 @@ class CallprocRequestHandler(sprockets_postgres.RequestHandlerMixin,
await self.finish({'value': str(result.row['uuid_generate_v4'])}) await self.finish({'value': str(result.row['uuid_generate_v4'])})
class ExecuteRequestHandler(sprockets_postgres.RequestHandlerMixin, class CountRequestHandler(RequestHandler):
web.RequestHandler):
GET_SQL = """\
SELECT last_updated_at, count
FROM public.query_count
WHERE key = 'test';"""
async def get(self):
result = await self.postgres_execute(self.GET_SQL)
await self.finish(self.cast_data(result.row))
class ErrorRequestHandler(RequestHandler):
GET_SQL = """\
SELECT last_updated_at, count
FROM public.query_count
WHERE key = 'test';"""
async def get(self):
await self.postgres_execute(self.GET_SQL)
self.set_status(204)
def _on_postgres_error(self,
metric_name: str,
exc: Exception) -> typing.Optional[Exception]:
return RuntimeError()
class ExecuteRequestHandler(RequestHandler):
GET_SQL = 'SELECT %s::TEXT AS value;' GET_SQL = 'SELECT %s::TEXT AS value;'
async def get(self): async def get(self):
timeout = self.get_argument('timeout', None)
if timeout is not None:
timeout = int(timeout)
result = await self.postgres_execute( result = await self.postgres_execute(
self.GET_SQL, [self.get_argument('value')], 'get') self.GET_SQL, [self.get_argument('value')], timeout=timeout)
await self.finish({ await self.finish({
'value': result.row['value'] if result.row else None}) 'value': result.row['value'] if result.row else None})
class MultiRowRequestHandler(sprockets_postgres.RequestHandlerMixin, class InfluxDBRequestHandler(ExecuteRequestHandler):
web.RequestHandler):
GET_SQL = 'SELECT * FROM information_schema.enabled_roles;' def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.influxdb = self.application.influxdb
self.influxdb.add_field = mock.Mock()
class MetricsMixinRequestHandler(ExecuteRequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.record_timing = self.application.record_timing
class MultiRowRequestHandler(RequestHandler):
GET_SQL = 'SELECT * FROM public.test_rows;'
UPDATE_SQL = """\
UPDATE public.test_rows
SET toggle = %(to_value)s,
last_modified_at = CURRENT_TIMESTAMP
WHERE toggle IS %(from_value)s"""
async def get(self): async def get(self):
result = await self.postgres_execute(self.GET_SQL) result = await self.postgres_execute(self.GET_SQL)
await self.finish({'rows': [row['role_name'] for row in result.rows]}) await self.finish({
'count': result.row_count,
'rows': self.cast_data(result.rows)})
async def post(self):
body = json.loads(self.request.body.decode('utf-8'))
result = await self.postgres_execute(
self.UPDATE_SQL, {
'to_value': body['value'], 'from_value': not body['value']})
await self.finish({
'count': result.row_count,
'rows': self.cast_data(result.rows)})
class NoRowRequestHandler(sprockets_postgres.RequestHandlerMixin, class NoRowRequestHandler(RequestHandler):
web.RequestHandler):
GET_SQL = """\ GET_SQL = """\
SELECT * FROM information_schema.tables WHERE table_schema = 'public';""" SELECT * FROM information_schema.tables WHERE table_schema = 'foo';"""
async def get(self): async def get(self):
result = await self.postgres_execute(self.GET_SQL) result = await self.postgres_execute(self.GET_SQL)
await self.finish({'rows': result.rows}) await self.finish({
'count': result.row_count,
'rows': self.cast_data(result.rows)})
class StatusRequestHandler(web.RequestHandler): class StatusRequestHandler(RequestHandler):
async def get(self): async def get(self):
result = await self.application.postgres_status() status = await self.application.postgres_status()
if not result['available']: if not status['available']:
self.set_status(503, 'Database Unavailable') self.set_status(503, 'Database Unavailable')
await self.finish(dict(result)) await self.finish(status)
class TransactionRequestHandler(RequestHandler):
GET_SQL = """\
SELECT id, created_at, last_modified_at, value
FROM public.test
WHERE id = %(id)s;"""
POST_SQL = """\
INSERT INTO public.test (id, created_at, value)
VALUES (%(id)s, CURRENT_TIMESTAMP, %(value)s)
RETURNING id, created_at, value;"""
UPDATE_COUNT_SQL = """\
UPDATE public.query_count
SET count = count + 1,
last_updated_at = CURRENT_TIMESTAMP
WHERE key = 'test'
RETURNING last_updated_at, count;"""
async def get(self, test_id):
result = await self.postgres_execute(self.GET_SQL, {'id': test_id})
if not result.row_count:
raise web.HTTPError(404, 'Not Found')
await self.finish(self.cast_data(result.row))
async def post(self):
body = json.loads(self.request.body.decode('utf-8'))
async with self.postgres_transaction() as postgres:
# This should roll back on the second call to this endopoint
self.application.first_txn = await postgres.execute(
self.POST_SQL, {'id': str(uuid.uuid4()),
'value': str(uuid.uuid4())})
# This should roll back on the second call to this endopoint
count = await postgres.execute(self.UPDATE_COUNT_SQL)
# This will trigger an error on the second call to this endpoint
user = await postgres.execute(self.POST_SQL, body)
await self.finish({
'count': self.cast_data(count.row),
'user': self.cast_data(user.row)})
class Application(sprockets_postgres.ApplicationMixin, class Application(sprockets_postgres.ApplicationMixin,
app.Application): app.Application):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.influxdb = mock.Mock()
self.record_timing = mock.Mock()
self.first_txn: typing.Optional[sprockets_postgres.QueryResult] = None
class TestCase(testing.SprocketsHttpTestCase): class TestCase(testing.SprocketsHttpTestCase):
@ -82,13 +208,22 @@ class TestCase(testing.SprocketsHttpTestCase):
def get_app(self): def get_app(self):
self.app = Application(handlers=[ self.app = Application(handlers=[
web.url('/callproc', CallprocRequestHandler), web.url('/callproc', CallprocRequestHandler),
web.url('/count', CountRequestHandler),
web.url('/error', ErrorRequestHandler),
web.url('/execute', ExecuteRequestHandler), web.url('/execute', ExecuteRequestHandler),
web.url('/influxdb', InfluxDBRequestHandler),
web.url('/metrics-mixin', MetricsMixinRequestHandler),
web.url('/multi-row', MultiRowRequestHandler), web.url('/multi-row', MultiRowRequestHandler),
web.url('/no-row', NoRowRequestHandler), web.url('/no-row', NoRowRequestHandler),
web.url('/status', StatusRequestHandler) web.url('/status', StatusRequestHandler),
web.url('/transaction', TransactionRequestHandler),
web.url('/transaction/(?P<test_id>.*)', TransactionRequestHandler)
]) ])
return self.app return self.app
class RequestHandlerMixinTestCase(TestCase):
def test_postgres_status(self): def test_postgres_status(self):
response = self.fetch('/status') response = self.fetch('/status')
data = json.loads(response.body) data = json.loads(response.body)
@ -109,23 +244,63 @@ class TestCase(testing.SprocketsHttpTestCase):
self.assertIsInstance( self.assertIsInstance(
uuid.UUID(json.loads(response.body)['value']), uuid.UUID) uuid.UUID(json.loads(response.body)['value']), uuid.UUID)
@mock.patch('aiopg.cursor.Cursor.execute')
def test_postgres_error_passthrough(self, execute):
execute.side_effect = asyncio.TimeoutError
response = self.fetch('/error')
self.assertEqual(response.code, 500)
self.assertIn(b'Internal Server Error', response.body)
def test_postgres_execute(self): def test_postgres_execute(self):
expectation = str(uuid.uuid4()) expectation = str(uuid.uuid4())
response = self.fetch('/execute?value={}'.format(expectation)) response = self.fetch('/execute?value={}'.format(expectation))
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
self.assertEqual(json.loads(response.body)['value'], expectation) self.assertEqual(json.loads(response.body)['value'], expectation)
def test_postgres_multirow(self): def test_postgres_execute_with_timeout(self):
expectation = str(uuid.uuid4())
response = self.fetch(
'/execute?value={}&timeout=5'.format(expectation))
self.assertEqual(response.code, 200)
self.assertEqual(json.loads(response.body)['value'], expectation)
def test_postgres_influxdb(self):
expectation = str(uuid.uuid4())
response = self.fetch(
'/influxdb?value={}'.format(expectation))
self.assertEqual(response.code, 200)
self.assertEqual(json.loads(response.body)['value'], expectation)
self.app.influxdb.set_field.assert_called_once()
def test_postgres_metrics_mixin(self):
expectation = str(uuid.uuid4())
response = self.fetch(
'/metrics-mixin?value={}'.format(expectation))
self.assertEqual(response.code, 200)
self.assertEqual(json.loads(response.body)['value'], expectation)
self.app.record_timing.assert_called_once()
def test_postgres_multirow_get(self):
response = self.fetch('/multi-row') response = self.fetch('/multi-row')
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
body = json.loads(response.body) body = json.loads(response.body)
self.assertEqual(body['count'], 5)
self.assertIsInstance(body['rows'], list) self.assertIsInstance(body['rows'], list)
self.assertIn('postgres', body['rows'])
def test_postgres_multirow_no_data(self):
for value in [True, False]:
response = self.fetch(
'/multi-row', method='POST', body=json.dumps({'value': value}))
self.assertEqual(response.code, 200)
body = json.loads(response.body)
self.assertEqual(body['count'], 5)
self.assertIsNone(body['rows'])
def test_postgres_norow(self): def test_postgres_norow(self):
response = self.fetch('/no-row') response = self.fetch('/no-row')
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
body = json.loads(response.body) body = json.loads(response.body)
self.assertEqual(body['count'], 0)
self.assertIsNone(body['rows']) self.assertIsNone(body['rows'])
@mock.patch('aiopg.cursor.Cursor.execute') @mock.patch('aiopg.cursor.Cursor.execute')
@ -162,6 +337,38 @@ class TestCase(testing.SprocketsHttpTestCase):
response = self.fetch('/execute?value=1') response = self.fetch('/execute?value=1')
self.assertEqual(response.code, 503) self.assertEqual(response.code, 503)
class TransactionTestCase(TestCase):
def test_transactions(self):
test_body = {
'id': str(uuid.uuid4()),
'value': str(uuid.uuid4())
}
response = self.fetch(
'/transaction', method='POST', body=json.dumps(test_body))
self.assertEqual(response.code, 200)
record = json.loads(response.body.decode('utf-8'))
self.assertEqual(record['user']['id'], test_body['id'])
self.assertEqual(record['user']['value'], test_body['value'])
count = record['count']['count']
last_updated = record['count']['last_updated_at']
response = self.fetch(
'/transaction', method='POST', body=json.dumps(test_body))
self.assertEqual(response.code, 409)
response = self.fetch(
'/transaction/{}'.format(self.app.first_txn.row['id']))
self.assertEqual(response.code, 404)
response = self.fetch('/count')
self.assertEqual(response.code, 200)
record = json.loads(response.body.decode('utf-8'))
self.assertEqual(record['count'], count)
self.assertEqual(record['last_updated_at'], last_updated)
""" """
class MissingURLTestCase(testing.SprocketsHttpTestCase): class MissingURLTestCase(testing.SprocketsHttpTestCase):