mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2024-12-27 03:00:19 +00:00
Testing update and fixes found in testing
This commit is contained in:
parent
611dfd1ec7
commit
603eb4d6dd
6 changed files with 290 additions and 50 deletions
2
.github/workflows/testing.yaml
vendored
2
.github/workflows/testing.yaml
vendored
|
@ -48,7 +48,7 @@ jobs:
|
|||
- name: Install UUID extension
|
||||
run: |
|
||||
apk add postgresql-client \
|
||||
&& psql -q -h postgres -U postgres -d postgres -c 'CREATE EXTENSION "uuid-ossp";'
|
||||
&& psql -q -h postgres -U postgres -d postgres -f fixtures/testing.sql
|
||||
|
||||
- name: Run flake8 tests
|
||||
run: flake8
|
||||
|
|
|
@ -65,7 +65,9 @@ docker-compose up -d --quiet-pull
|
|||
|
||||
wait_for_healthy_containers 1
|
||||
|
||||
docker-compose exec postgres psql -q -o /dev/null -U postgres -d postgres -c 'CREATE EXTENSION "uuid-ossp";'
|
||||
printf "Loading fixture data ... "
|
||||
docker-compose exec postgres psql -q -o /dev/null -U postgres -d postgres -f /fixtures/testing.sql
|
||||
report_done
|
||||
|
||||
cat > build/test-environment<<EOF
|
||||
export ASYNC_TEST_TIMEOUT=5
|
||||
|
|
|
@ -11,3 +11,7 @@ services:
|
|||
retries: 3
|
||||
ports:
|
||||
- 5432
|
||||
volumes:
|
||||
- type: bind
|
||||
source: ./fixtures
|
||||
target: /fixtures
|
||||
|
|
30
fixtures/testing.sql
Normal file
30
fixtures/testing.sql
Normal file
|
@ -0,0 +1,30 @@
|
|||
CREATE EXTENSION "uuid-ossp";
|
||||
|
||||
CREATE TABLE public.test (
|
||||
id UUID NOT NULL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
last_modified_at TIMESTAMP WITH TIME ZONE,
|
||||
value TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE public.query_count(
|
||||
key TEXT NOT NULL PRIMARY KEY,
|
||||
last_updated_at TIMESTAMP WITH TIME ZONE,
|
||||
count INTEGER
|
||||
);
|
||||
|
||||
INSERT INTO public.query_count (key, last_updated_at, count)
|
||||
VALUES ('test', CURRENT_TIMESTAMP, 0);
|
||||
|
||||
CREATE TABLE public.test_rows (
|
||||
id INTEGER NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
last_modified_at TIMESTAMP WITH TIME ZONE,
|
||||
toggle BOOLEAN NOT NULL
|
||||
);
|
||||
|
||||
INSERT INTO public.test_rows (toggle) VALUES (FALSE);
|
||||
INSERT INTO public.test_rows (toggle) VALUES (FALSE);
|
||||
INSERT INTO public.test_rows (toggle) VALUES (FALSE);
|
||||
INSERT INTO public.test_rows (toggle) VALUES (FALSE);
|
||||
INSERT INTO public.test_rows (toggle) VALUES (FALSE);
|
|
@ -24,7 +24,7 @@ DEFAULT_POSTGRES_MIN_POOL_SIZE = 1
|
|||
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
|
||||
DEFAULT_POSTGRES_UUID = 'TRUE'
|
||||
|
||||
QueryParameters = typing.Union[list, tuple, None]
|
||||
QueryParameters = typing.Union[dict, list, tuple, None]
|
||||
Timeout = typing.Union[int, float, None]
|
||||
|
||||
|
||||
|
@ -92,18 +92,17 @@ class PostgresConnector:
|
|||
try:
|
||||
await method(**kwargs)
|
||||
except (asyncio.TimeoutError, psycopg2.Error) as err:
|
||||
LOGGER.error('Caught %r', err)
|
||||
exc = self._on_error(metric_name, err)
|
||||
if exc:
|
||||
raise exc
|
||||
finally:
|
||||
else:
|
||||
if self._record_duration:
|
||||
self._record_duration(
|
||||
metric_name, time.monotonic() - start_time)
|
||||
return await self._query_results()
|
||||
|
||||
async def _query_results(self) -> QueryResult:
|
||||
row, rows = None, None
|
||||
count, row, rows = self.cursor.rowcount, None, None
|
||||
if self.cursor.rowcount == 1:
|
||||
try:
|
||||
row = dict(await self.cursor.fetchone())
|
||||
|
@ -114,7 +113,7 @@ class PostgresConnector:
|
|||
rows = [dict(row) for row in await self.cursor.fetchall()]
|
||||
except psycopg2.ProgrammingError:
|
||||
pass
|
||||
return QueryResult(self.cursor.rowcount, row, rows)
|
||||
return QueryResult(count, row, rows)
|
||||
|
||||
|
||||
class ConnectionException(Exception):
|
||||
|
@ -247,7 +246,10 @@ class RequestHandlerMixin:
|
|||
metric_name: str = '',
|
||||
*,
|
||||
timeout: Timeout = None) -> QueryResult:
|
||||
async with self._postgres_connector(timeout) as connector:
|
||||
async with self.application.postgres_connector(
|
||||
self._on_postgres_error,
|
||||
self._on_postgres_timing,
|
||||
timeout) as connector:
|
||||
return await connector.callproc(
|
||||
name, parameters, metric_name, timeout=timeout)
|
||||
|
||||
|
@ -265,7 +267,10 @@ class RequestHandlerMixin:
|
|||
either with positional ``%s`` or named ``%({name})s`` placeholders.
|
||||
|
||||
"""
|
||||
async with self._postgres_connector(timeout) as connector:
|
||||
async with self.application.postgres_connector(
|
||||
self._on_postgres_error,
|
||||
self._on_postgres_timing,
|
||||
timeout) as connector:
|
||||
return await connector.execute(
|
||||
sql, parameters, metric_name, timeout=timeout)
|
||||
|
||||
|
@ -276,28 +281,20 @@ class RequestHandlerMixin:
|
|||
Will automatically commit or rollback based upon exception.
|
||||
|
||||
"""
|
||||
async with self._postgres_connector(timeout) as connector:
|
||||
async with self.application.postgres_connector(
|
||||
self._on_postgres_error,
|
||||
self._on_postgres_timing,
|
||||
timeout) as connector:
|
||||
async with connector.transaction():
|
||||
yield connector
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _postgres_connector(self, timeout: Timeout = None) \
|
||||
-> typing.AsyncContextManager[PostgresConnector]:
|
||||
async with self.application.postgres_connector(
|
||||
self.__on_postgres_error,
|
||||
self.__on_postgres_timing,
|
||||
timeout) as connector:
|
||||
yield connector
|
||||
|
||||
def __on_postgres_error(self,
|
||||
metric_name: str,
|
||||
exc: Exception) -> typing.Optional[Exception]:
|
||||
def _on_postgres_error(self,
|
||||
metric_name: str,
|
||||
exc: Exception) -> typing.Optional[Exception]:
|
||||
"""Override for different error handling behaviors"""
|
||||
LOGGER.error('%s in %s for %s (%s)',
|
||||
exc.__class__.__name__,
|
||||
self.__class__.__name__,
|
||||
metric_name,
|
||||
str(exc).split('\n')[0])
|
||||
exc.__class__.__name__, self.__class__.__name__,
|
||||
metric_name, str(exc).split('\n')[0])
|
||||
if isinstance(exc, ConnectionException):
|
||||
raise web.HTTPError(503, reason='Database Connection Error')
|
||||
elif isinstance(exc, asyncio.TimeoutError):
|
||||
|
@ -308,9 +305,9 @@ class RequestHandlerMixin:
|
|||
raise web.HTTPError(500, reason='Database Error')
|
||||
return exc
|
||||
|
||||
def __on_postgres_timing(self,
|
||||
metric_name: str,
|
||||
duration: float) -> None:
|
||||
def _on_postgres_timing(self,
|
||||
metric_name: str,
|
||||
duration: float) -> None:
|
||||
"""Override for custom metric recording"""
|
||||
if hasattr(self, 'influxdb'): # sprockets-influxdb
|
||||
self.influxdb.set_field(metric_name, duration)
|
||||
|
|
249
tests.py
249
tests.py
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
|
@ -12,8 +13,20 @@ from tornado import web
|
|||
import sprockets_postgres
|
||||
|
||||
|
||||
class CallprocRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
class RequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
"""Base RequestHandler for test endpoints"""
|
||||
|
||||
def cast_data(self, data: typing.Union[dict, list, None]) \
|
||||
-> typing.Union[dict, list, None]:
|
||||
if data is None:
|
||||
return None
|
||||
elif isinstance(data, list):
|
||||
return [self.cast_data(row) for row in data]
|
||||
return {k: str(v) for k, v in data.items()}
|
||||
|
||||
|
||||
class CallprocRequestHandler(RequestHandler):
|
||||
|
||||
async def get(self):
|
||||
result = await self.postgres_callproc(
|
||||
|
@ -21,51 +34,164 @@ class CallprocRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
|||
await self.finish({'value': str(result.row['uuid_generate_v4'])})
|
||||
|
||||
|
||||
class ExecuteRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
class CountRequestHandler(RequestHandler):
|
||||
|
||||
GET_SQL = """\
|
||||
SELECT last_updated_at, count
|
||||
FROM public.query_count
|
||||
WHERE key = 'test';"""
|
||||
|
||||
async def get(self):
|
||||
result = await self.postgres_execute(self.GET_SQL)
|
||||
await self.finish(self.cast_data(result.row))
|
||||
|
||||
|
||||
class ErrorRequestHandler(RequestHandler):
|
||||
|
||||
GET_SQL = """\
|
||||
SELECT last_updated_at, count
|
||||
FROM public.query_count
|
||||
WHERE key = 'test';"""
|
||||
|
||||
async def get(self):
|
||||
await self.postgres_execute(self.GET_SQL)
|
||||
self.set_status(204)
|
||||
|
||||
def _on_postgres_error(self,
|
||||
metric_name: str,
|
||||
exc: Exception) -> typing.Optional[Exception]:
|
||||
return RuntimeError()
|
||||
|
||||
|
||||
class ExecuteRequestHandler(RequestHandler):
|
||||
|
||||
GET_SQL = 'SELECT %s::TEXT AS value;'
|
||||
|
||||
async def get(self):
|
||||
timeout = self.get_argument('timeout', None)
|
||||
if timeout is not None:
|
||||
timeout = int(timeout)
|
||||
result = await self.postgres_execute(
|
||||
self.GET_SQL, [self.get_argument('value')], 'get')
|
||||
self.GET_SQL, [self.get_argument('value')], timeout=timeout)
|
||||
await self.finish({
|
||||
'value': result.row['value'] if result.row else None})
|
||||
|
||||
|
||||
class MultiRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
class InfluxDBRequestHandler(ExecuteRequestHandler):
|
||||
|
||||
GET_SQL = 'SELECT * FROM information_schema.enabled_roles;'
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.influxdb = self.application.influxdb
|
||||
self.influxdb.add_field = mock.Mock()
|
||||
|
||||
|
||||
class MetricsMixinRequestHandler(ExecuteRequestHandler):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.record_timing = self.application.record_timing
|
||||
|
||||
|
||||
class MultiRowRequestHandler(RequestHandler):
|
||||
|
||||
GET_SQL = 'SELECT * FROM public.test_rows;'
|
||||
|
||||
UPDATE_SQL = """\
|
||||
UPDATE public.test_rows
|
||||
SET toggle = %(to_value)s,
|
||||
last_modified_at = CURRENT_TIMESTAMP
|
||||
WHERE toggle IS %(from_value)s"""
|
||||
|
||||
async def get(self):
|
||||
result = await self.postgres_execute(self.GET_SQL)
|
||||
await self.finish({'rows': [row['role_name'] for row in result.rows]})
|
||||
await self.finish({
|
||||
'count': result.row_count,
|
||||
'rows': self.cast_data(result.rows)})
|
||||
|
||||
async def post(self):
|
||||
body = json.loads(self.request.body.decode('utf-8'))
|
||||
result = await self.postgres_execute(
|
||||
self.UPDATE_SQL, {
|
||||
'to_value': body['value'], 'from_value': not body['value']})
|
||||
await self.finish({
|
||||
'count': result.row_count,
|
||||
'rows': self.cast_data(result.rows)})
|
||||
|
||||
|
||||
class NoRowRequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||
web.RequestHandler):
|
||||
class NoRowRequestHandler(RequestHandler):
|
||||
|
||||
GET_SQL = """\
|
||||
SELECT * FROM information_schema.tables WHERE table_schema = 'public';"""
|
||||
SELECT * FROM information_schema.tables WHERE table_schema = 'foo';"""
|
||||
|
||||
async def get(self):
|
||||
result = await self.postgres_execute(self.GET_SQL)
|
||||
await self.finish({'rows': result.rows})
|
||||
await self.finish({
|
||||
'count': result.row_count,
|
||||
'rows': self.cast_data(result.rows)})
|
||||
|
||||
|
||||
class StatusRequestHandler(web.RequestHandler):
|
||||
class StatusRequestHandler(RequestHandler):
|
||||
|
||||
async def get(self):
|
||||
result = await self.application.postgres_status()
|
||||
if not result['available']:
|
||||
status = await self.application.postgres_status()
|
||||
if not status['available']:
|
||||
self.set_status(503, 'Database Unavailable')
|
||||
await self.finish(dict(result))
|
||||
await self.finish(status)
|
||||
|
||||
|
||||
class TransactionRequestHandler(RequestHandler):
|
||||
|
||||
GET_SQL = """\
|
||||
SELECT id, created_at, last_modified_at, value
|
||||
FROM public.test
|
||||
WHERE id = %(id)s;"""
|
||||
|
||||
POST_SQL = """\
|
||||
INSERT INTO public.test (id, created_at, value)
|
||||
VALUES (%(id)s, CURRENT_TIMESTAMP, %(value)s)
|
||||
RETURNING id, created_at, value;"""
|
||||
|
||||
UPDATE_COUNT_SQL = """\
|
||||
UPDATE public.query_count
|
||||
SET count = count + 1,
|
||||
last_updated_at = CURRENT_TIMESTAMP
|
||||
WHERE key = 'test'
|
||||
RETURNING last_updated_at, count;"""
|
||||
|
||||
async def get(self, test_id):
|
||||
result = await self.postgres_execute(self.GET_SQL, {'id': test_id})
|
||||
if not result.row_count:
|
||||
raise web.HTTPError(404, 'Not Found')
|
||||
await self.finish(self.cast_data(result.row))
|
||||
|
||||
async def post(self):
|
||||
body = json.loads(self.request.body.decode('utf-8'))
|
||||
async with self.postgres_transaction() as postgres:
|
||||
|
||||
# This should roll back on the second call to this endopoint
|
||||
self.application.first_txn = await postgres.execute(
|
||||
self.POST_SQL, {'id': str(uuid.uuid4()),
|
||||
'value': str(uuid.uuid4())})
|
||||
|
||||
# This should roll back on the second call to this endopoint
|
||||
count = await postgres.execute(self.UPDATE_COUNT_SQL)
|
||||
|
||||
# This will trigger an error on the second call to this endpoint
|
||||
user = await postgres.execute(self.POST_SQL, body)
|
||||
|
||||
await self.finish({
|
||||
'count': self.cast_data(count.row),
|
||||
'user': self.cast_data(user.row)})
|
||||
|
||||
|
||||
class Application(sprockets_postgres.ApplicationMixin,
|
||||
app.Application):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.influxdb = mock.Mock()
|
||||
self.record_timing = mock.Mock()
|
||||
self.first_txn: typing.Optional[sprockets_postgres.QueryResult] = None
|
||||
|
||||
|
||||
class TestCase(testing.SprocketsHttpTestCase):
|
||||
|
@ -82,13 +208,22 @@ class TestCase(testing.SprocketsHttpTestCase):
|
|||
def get_app(self):
|
||||
self.app = Application(handlers=[
|
||||
web.url('/callproc', CallprocRequestHandler),
|
||||
web.url('/count', CountRequestHandler),
|
||||
web.url('/error', ErrorRequestHandler),
|
||||
web.url('/execute', ExecuteRequestHandler),
|
||||
web.url('/influxdb', InfluxDBRequestHandler),
|
||||
web.url('/metrics-mixin', MetricsMixinRequestHandler),
|
||||
web.url('/multi-row', MultiRowRequestHandler),
|
||||
web.url('/no-row', NoRowRequestHandler),
|
||||
web.url('/status', StatusRequestHandler)
|
||||
web.url('/status', StatusRequestHandler),
|
||||
web.url('/transaction', TransactionRequestHandler),
|
||||
web.url('/transaction/(?P<test_id>.*)', TransactionRequestHandler)
|
||||
])
|
||||
return self.app
|
||||
|
||||
|
||||
class RequestHandlerMixinTestCase(TestCase):
|
||||
|
||||
def test_postgres_status(self):
|
||||
response = self.fetch('/status')
|
||||
data = json.loads(response.body)
|
||||
|
@ -109,23 +244,63 @@ class TestCase(testing.SprocketsHttpTestCase):
|
|||
self.assertIsInstance(
|
||||
uuid.UUID(json.loads(response.body)['value']), uuid.UUID)
|
||||
|
||||
@mock.patch('aiopg.cursor.Cursor.execute')
|
||||
def test_postgres_error_passthrough(self, execute):
|
||||
execute.side_effect = asyncio.TimeoutError
|
||||
response = self.fetch('/error')
|
||||
self.assertEqual(response.code, 500)
|
||||
self.assertIn(b'Internal Server Error', response.body)
|
||||
|
||||
def test_postgres_execute(self):
|
||||
expectation = str(uuid.uuid4())
|
||||
response = self.fetch('/execute?value={}'.format(expectation))
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(json.loads(response.body)['value'], expectation)
|
||||
|
||||
def test_postgres_multirow(self):
|
||||
def test_postgres_execute_with_timeout(self):
|
||||
expectation = str(uuid.uuid4())
|
||||
response = self.fetch(
|
||||
'/execute?value={}&timeout=5'.format(expectation))
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(json.loads(response.body)['value'], expectation)
|
||||
|
||||
def test_postgres_influxdb(self):
|
||||
expectation = str(uuid.uuid4())
|
||||
response = self.fetch(
|
||||
'/influxdb?value={}'.format(expectation))
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(json.loads(response.body)['value'], expectation)
|
||||
self.app.influxdb.set_field.assert_called_once()
|
||||
|
||||
def test_postgres_metrics_mixin(self):
|
||||
expectation = str(uuid.uuid4())
|
||||
response = self.fetch(
|
||||
'/metrics-mixin?value={}'.format(expectation))
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(json.loads(response.body)['value'], expectation)
|
||||
self.app.record_timing.assert_called_once()
|
||||
|
||||
def test_postgres_multirow_get(self):
|
||||
response = self.fetch('/multi-row')
|
||||
self.assertEqual(response.code, 200)
|
||||
body = json.loads(response.body)
|
||||
self.assertEqual(body['count'], 5)
|
||||
self.assertIsInstance(body['rows'], list)
|
||||
self.assertIn('postgres', body['rows'])
|
||||
|
||||
def test_postgres_multirow_no_data(self):
|
||||
for value in [True, False]:
|
||||
response = self.fetch(
|
||||
'/multi-row', method='POST', body=json.dumps({'value': value}))
|
||||
self.assertEqual(response.code, 200)
|
||||
body = json.loads(response.body)
|
||||
self.assertEqual(body['count'], 5)
|
||||
self.assertIsNone(body['rows'])
|
||||
|
||||
def test_postgres_norow(self):
|
||||
response = self.fetch('/no-row')
|
||||
self.assertEqual(response.code, 200)
|
||||
body = json.loads(response.body)
|
||||
self.assertEqual(body['count'], 0)
|
||||
self.assertIsNone(body['rows'])
|
||||
|
||||
@mock.patch('aiopg.cursor.Cursor.execute')
|
||||
|
@ -162,6 +337,38 @@ class TestCase(testing.SprocketsHttpTestCase):
|
|||
response = self.fetch('/execute?value=1')
|
||||
self.assertEqual(response.code, 503)
|
||||
|
||||
|
||||
class TransactionTestCase(TestCase):
|
||||
|
||||
def test_transactions(self):
|
||||
test_body = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'value': str(uuid.uuid4())
|
||||
}
|
||||
response = self.fetch(
|
||||
'/transaction', method='POST', body=json.dumps(test_body))
|
||||
self.assertEqual(response.code, 200)
|
||||
record = json.loads(response.body.decode('utf-8'))
|
||||
self.assertEqual(record['user']['id'], test_body['id'])
|
||||
self.assertEqual(record['user']['value'], test_body['value'])
|
||||
count = record['count']['count']
|
||||
last_updated = record['count']['last_updated_at']
|
||||
|
||||
response = self.fetch(
|
||||
'/transaction', method='POST', body=json.dumps(test_body))
|
||||
self.assertEqual(response.code, 409)
|
||||
|
||||
response = self.fetch(
|
||||
'/transaction/{}'.format(self.app.first_txn.row['id']))
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
response = self.fetch('/count')
|
||||
self.assertEqual(response.code, 200)
|
||||
record = json.loads(response.body.decode('utf-8'))
|
||||
self.assertEqual(record['count'], count)
|
||||
self.assertEqual(record['last_updated_at'], last_updated)
|
||||
|
||||
|
||||
"""
|
||||
class MissingURLTestCase(testing.SprocketsHttpTestCase):
|
||||
|
||||
|
|
Loading…
Reference in a new issue