mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2024-11-25 11:19:51 +00:00
Add reconfiguration for SRV based connections
- Attempt to add graceful reconfiguration for SRV based connections. On connection failure, the pool will be closed and reopened after fetching to get new SRV records. - When using SRV, use all return hosts in the PostgreSQL URL - If multiple requests hit a disconnect error, the reconnect logic will allow the first one in to reconnect and the others to wait
This commit is contained in:
parent
9c448c11e3
commit
ecc070e974
4 changed files with 205 additions and 103 deletions
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
||||||
1.2.0
|
1.3.0
|
||||||
|
|
|
@ -27,10 +27,10 @@ details the configuration options and their defaults.
|
||||||
+---------------------------------+--------------------------------------------------+-----------+
|
+---------------------------------+--------------------------------------------------+-----------+
|
||||||
|
|
||||||
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
|
If ``POSTGRES_URL`` uses a scheme of ``postgresql+srv``, a SRV DNS lookup will be
|
||||||
performed and the lowest priority record with the highest weight will be selected
|
performed and the URL will be constructed containing all host and port combinations
|
||||||
for connecting to Postgres.
|
in priority and weighted order, utilizing `libpq's supoprt <https://www.postgresql.org/docs/12/libpq-connect.html>`_
|
||||||
|
for multiple hosts in a URL.
|
||||||
|
|
||||||
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
|
AWS's ECS service discovery does not follow the SRV standard, but creates SRV
|
||||||
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
|
records. If ``POSTGRES_URL`` uses a scheme of ``aws+srv``, a SRV DNS lookup will be
|
||||||
performed using the correct format for ECS service discovery. The lowest priority
|
performed using the correct format for ECS service discovery.
|
||||||
record with the highest weight will be selected for connecting to Postgres.
|
|
||||||
|
|
|
@ -276,10 +276,9 @@ class ConnectionException(Exception):
|
||||||
|
|
||||||
|
|
||||||
class ApplicationMixin:
|
class ApplicationMixin:
|
||||||
"""
|
""":class:`sprockets.http.app.Application` mixin for handling the
|
||||||
:class:`sprockets.http.app.Application` / :class:`tornado.web.Application`
|
connection to Postgres and exporting functions for querying the database,
|
||||||
mixin for handling the connection to Postgres and exporting functions for
|
getting the status, and proving a cursor.
|
||||||
querying the database, getting the status, and proving a cursor.
|
|
||||||
|
|
||||||
Automatically creates and shuts down :class:`aiopg.Pool` on startup
|
Automatically creates and shuts down :class:`aiopg.Pool` on startup
|
||||||
and shutdown by installing `on_start` and `shutdown` callbacks into the
|
and shutdown by installing `on_start` and `shutdown` callbacks into the
|
||||||
|
@ -291,7 +290,10 @@ class ApplicationMixin:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._postgres_pool: typing.Optional[pool.Pool] = None
|
self._postgres_pool: typing.Optional[pool.Pool] = None
|
||||||
self.runner_callbacks['on_start'].append(self._postgres_setup)
|
self._postgres_connected = asyncio.Event()
|
||||||
|
self._postgres_reconnect = asyncio.Lock()
|
||||||
|
self._postgres_srv: bool = False
|
||||||
|
self.runner_callbacks['on_start'].append(self._postgres_on_start)
|
||||||
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
|
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
|
@ -299,7 +301,8 @@ class ApplicationMixin:
|
||||||
on_error: typing.Callable,
|
on_error: typing.Callable,
|
||||||
on_duration: typing.Optional[
|
on_duration: typing.Optional[
|
||||||
typing.Callable] = None,
|
typing.Callable] = None,
|
||||||
timeout: Timeout = None) \
|
timeout: Timeout = None,
|
||||||
|
_attempt: int = 1) \
|
||||||
-> typing.AsyncContextManager[PostgresConnector]:
|
-> typing.AsyncContextManager[PostgresConnector]:
|
||||||
"""Asynchronous :ref:`context-manager <python:typecontextmanager>`
|
"""Asynchronous :ref:`context-manager <python:typecontextmanager>`
|
||||||
that returns a :class:`~sprockets_postgres.PostgresConnector` instance
|
that returns a :class:`~sprockets_postgres.PostgresConnector` instance
|
||||||
|
@ -338,7 +341,19 @@ class ApplicationMixin:
|
||||||
yield PostgresConnector(
|
yield PostgresConnector(
|
||||||
cursor, on_error, on_duration, timeout)
|
cursor, on_error, on_duration, timeout)
|
||||||
except (asyncio.TimeoutError, psycopg2.Error) as err:
|
except (asyncio.TimeoutError, psycopg2.Error) as err:
|
||||||
exc = on_error('postgres_connector', ConnectionException(str(err)))
|
message = str(err)
|
||||||
|
if isinstance(err, psycopg2.OperationalError) and _attempt == 1:
|
||||||
|
LOGGER.critical('Disconnected from Postgres: %s', err)
|
||||||
|
if not self._postgres_reconnect.locked():
|
||||||
|
async with self._postgres_reconnect:
|
||||||
|
if await self._postgres_connect():
|
||||||
|
async with self.postgres_connector(
|
||||||
|
on_error, on_duration, timeout,
|
||||||
|
_attempt + 1) as connector:
|
||||||
|
yield connector
|
||||||
|
return
|
||||||
|
message = 'disconnected'
|
||||||
|
exc = on_error('postgres_connector', ConnectionException(message))
|
||||||
if exc:
|
if exc:
|
||||||
raise exc
|
raise exc
|
||||||
else: # postgres_status.on_error does not return an exception
|
else: # postgres_status.on_error does not return an exception
|
||||||
|
@ -370,6 +385,13 @@ class ApplicationMixin:
|
||||||
}
|
}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if not self._postgres_connected.is_set():
|
||||||
|
return {
|
||||||
|
'available': False,
|
||||||
|
'pool_size': 0,
|
||||||
|
'pool_free': 0
|
||||||
|
}
|
||||||
|
|
||||||
LOGGER.debug('Querying postgres status')
|
LOGGER.debug('Querying postgres status')
|
||||||
query_error = asyncio.Event()
|
query_error = asyncio.Event()
|
||||||
|
|
||||||
|
@ -390,6 +412,89 @@ class ApplicationMixin:
|
||||||
'pool_free': self._postgres_pool.freesize
|
'pool_free': self._postgres_pool.freesize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _postgres_connect(self) -> bool:
|
||||||
|
"""Setup the Postgres pool of connections"""
|
||||||
|
self._postgres_connected.clear()
|
||||||
|
|
||||||
|
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
|
||||||
|
if parsed.scheme.endswith('+srv'):
|
||||||
|
self._postgres_srv = True
|
||||||
|
try:
|
||||||
|
url = await self._postgres_url_from_srv(parsed)
|
||||||
|
except RuntimeError as error:
|
||||||
|
LOGGER.critical(str(error))
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
url = os.environ['POSTGRES_URL']
|
||||||
|
|
||||||
|
if self._postgres_pool:
|
||||||
|
self._postgres_pool.close()
|
||||||
|
|
||||||
|
LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
|
||||||
|
try:
|
||||||
|
self._postgres_pool = await pool.Pool.from_pool_fill(
|
||||||
|
url,
|
||||||
|
maxsize=int(
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_MAX_POOL_SIZE',
|
||||||
|
DEFAULT_POSTGRES_MAX_POOL_SIZE)),
|
||||||
|
minsize=int(
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_MIN_POOL_SIZE',
|
||||||
|
DEFAULT_POSTGRES_MIN_POOL_SIZE)),
|
||||||
|
timeout=int(
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_CONNECT_TIMEOUT',
|
||||||
|
DEFAULT_POSTGRES_CONNECTION_TIMEOUT)),
|
||||||
|
enable_hstore=util.strtobool(
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)),
|
||||||
|
enable_json=util.strtobool(
|
||||||
|
os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)),
|
||||||
|
enable_uuid=util.strtobool(
|
||||||
|
os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)),
|
||||||
|
echo=False,
|
||||||
|
on_connect=None,
|
||||||
|
pool_recycle=int(
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_CONNECTION_TTL',
|
||||||
|
DEFAULT_POSTGRES_CONNECTION_TTL)))
|
||||||
|
except (psycopg2.OperationalError,
|
||||||
|
psycopg2.Error) as error: # pragma: nocover
|
||||||
|
LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
|
||||||
|
error)
|
||||||
|
return False
|
||||||
|
self._postgres_connected.set()
|
||||||
|
LOGGER.debug('Connected to Postgres')
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _postgres_on_start(self,
|
||||||
|
_app: web.Application,
|
||||||
|
loop: ioloop.IOLoop):
|
||||||
|
"""Invoked as a startup step for the application
|
||||||
|
|
||||||
|
This is invoked by the :class:`sprockets.http.app.Application` on start
|
||||||
|
callback mechanism.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if 'POSTGRES_URL' not in os.environ:
|
||||||
|
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
||||||
|
return self.stop(loop)
|
||||||
|
if not await self._postgres_connect():
|
||||||
|
LOGGER.critical('PostgreSQL failed to connect, shutting down')
|
||||||
|
return self.stop(loop)
|
||||||
|
|
||||||
|
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
|
||||||
|
"""Shutdown the Postgres connections and wait for them to close.
|
||||||
|
|
||||||
|
This is invoked by the :class:`sprockets.http.app.Application` shutdown
|
||||||
|
callback mechanism.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._postgres_pool is not None:
|
||||||
|
self._postgres_pool.close()
|
||||||
|
await self._postgres_pool.wait_closed()
|
||||||
|
|
||||||
async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str:
|
async def _postgres_url_from_srv(self, parsed: parse.ParseResult) -> str:
|
||||||
if parsed.scheme.startswith('postgresql+'):
|
if parsed.scheme.startswith('postgresql+'):
|
||||||
host_parts = parsed.hostname.split('.')
|
host_parts = parsed.hostname.split('.')
|
||||||
|
@ -405,86 +510,16 @@ class ApplicationMixin:
|
||||||
if not records:
|
if not records:
|
||||||
raise RuntimeError('No SRV records found')
|
raise RuntimeError('No SRV records found')
|
||||||
|
|
||||||
|
netloc = []
|
||||||
if parsed.username and not parsed.password:
|
if parsed.username and not parsed.password:
|
||||||
return 'postgresql://{}@{}:{}{}'.format(
|
netloc.append('{}@'.format(parsed.username))
|
||||||
parsed.username, records[0].host, records[0].port, parsed.path)
|
|
||||||
elif parsed.username and parsed.password:
|
elif parsed.username and parsed.password:
|
||||||
return 'postgresql://{}:{}@{}:{}{}'.format(
|
netloc.append('{}:{}@'.format(parsed.username, parsed.password))
|
||||||
parsed.username, parsed.password,
|
netloc.append(','.join([
|
||||||
records[0].host, records[0].port, parsed.path)
|
'{}:{}'.format(r.host, r.port) for r in records]))
|
||||||
return 'postgresql://{}:{}{}'.format(
|
return parse.urlunparse(
|
||||||
records[0].host, records[0].port, parsed.path)
|
('postgresql', ''.join(netloc), parsed.path,
|
||||||
|
parsed.params, parsed.query, ''))
|
||||||
async def _postgres_setup(self,
|
|
||||||
_app: web.Application,
|
|
||||||
loop: ioloop.IOLoop) -> None:
|
|
||||||
"""Setup the Postgres pool of connections and log if there is an error.
|
|
||||||
|
|
||||||
This is invoked by the :class:`sprockets.http.app.Application` on start
|
|
||||||
callback mechanism.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if 'POSTGRES_URL' not in os.environ:
|
|
||||||
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
|
||||||
return self.stop(loop)
|
|
||||||
|
|
||||||
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
|
|
||||||
if parsed.scheme.endswith('+srv'):
|
|
||||||
try:
|
|
||||||
url = await self._postgres_url_from_srv(parsed)
|
|
||||||
except RuntimeError as error:
|
|
||||||
LOGGER.critical(str(error))
|
|
||||||
return self.stop(loop)
|
|
||||||
else:
|
|
||||||
url = os.environ['POSTGRES_URL']
|
|
||||||
|
|
||||||
LOGGER.debug('Connecting to %s', os.environ['POSTGRES_URL'])
|
|
||||||
self._postgres_pool = pool.Pool(
|
|
||||||
url,
|
|
||||||
maxsize=int(
|
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_MAX_POOL_SIZE',
|
|
||||||
DEFAULT_POSTGRES_MAX_POOL_SIZE)),
|
|
||||||
minsize=int(
|
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_MIN_POOL_SIZE',
|
|
||||||
DEFAULT_POSTGRES_MIN_POOL_SIZE)),
|
|
||||||
timeout=int(
|
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_CONNECT_TIMEOUT',
|
|
||||||
DEFAULT_POSTGRES_CONNECTION_TIMEOUT)),
|
|
||||||
enable_hstore=util.strtobool(
|
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE)),
|
|
||||||
enable_json=util.strtobool(
|
|
||||||
os.environ.get('POSTGRES_JSON', DEFAULT_POSTGRES_JSON)),
|
|
||||||
enable_uuid=util.strtobool(
|
|
||||||
os.environ.get('POSTGRES_UUID', DEFAULT_POSTGRES_UUID)),
|
|
||||||
echo=False,
|
|
||||||
on_connect=None,
|
|
||||||
pool_recycle=int(
|
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_CONNECTION_TTL',
|
|
||||||
DEFAULT_POSTGRES_CONNECTION_TTL)))
|
|
||||||
try:
|
|
||||||
async with self._postgres_pool._cond:
|
|
||||||
await self._postgres_pool._fill_free_pool(False)
|
|
||||||
except (psycopg2.OperationalError,
|
|
||||||
psycopg2.Error) as error: # pragma: nocover
|
|
||||||
LOGGER.warning('Error connecting to PostgreSQL on startup: %s',
|
|
||||||
error)
|
|
||||||
LOGGER.debug('Connected to Postgres')
|
|
||||||
|
|
||||||
async def _postgres_shutdown(self, _ioloop: ioloop.IOLoop) -> None:
|
|
||||||
"""Shutdown the Postgres connections and wait for them to close.
|
|
||||||
|
|
||||||
This is invoked by the :class:`sprockets.http.app.Application` shutdown
|
|
||||||
callback mechanism.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if self._postgres_pool is not None:
|
|
||||||
self._postgres_pool.close()
|
|
||||||
await self._postgres_pool.wait_closed()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _resolve_srv(hostname: str) \
|
async def _resolve_srv(hostname: str) \
|
||||||
|
|
93
tests.py
93
tests.py
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import typing
|
import typing
|
||||||
|
@ -7,17 +8,21 @@ import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
|
import aiopg
|
||||||
import asynctest
|
import asynctest
|
||||||
import psycopg2
|
import psycopg2
|
||||||
import pycares
|
import pycares
|
||||||
from asynctest import mock
|
from asynctest import mock
|
||||||
from psycopg2 import errors
|
from psycopg2 import errors
|
||||||
from sprockets.http import app, testing
|
from sprockets.http import app, testing
|
||||||
from tornado import ioloop, web
|
from tornado import ioloop, testing as ttesting, web
|
||||||
|
|
||||||
import sprockets_postgres
|
import sprockets_postgres
|
||||||
|
|
||||||
|
|
||||||
|
test_postgres_cursor_oer_invocation = 0
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler(sprockets_postgres.RequestHandlerMixin,
|
class RequestHandler(sprockets_postgres.RequestHandlerMixin,
|
||||||
web.RequestHandler):
|
web.RequestHandler):
|
||||||
"""Base RequestHandler for test endpoints"""
|
"""Base RequestHandler for test endpoints"""
|
||||||
|
@ -223,6 +228,11 @@ class Application(sprockets_postgres.ApplicationMixin,
|
||||||
|
|
||||||
class TestCase(testing.SprocketsHttpTestCase):
|
class TestCase(testing.SprocketsHttpTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
asyncio.get_event_loop().run_until_complete(
|
||||||
|
self.app._postgres_connected.wait())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
with open('build/test-environment') as f:
|
with open('build/test-environment') as f:
|
||||||
|
@ -267,6 +277,12 @@ class RequestHandlerMixinTestCase(TestCase):
|
||||||
self.assertEqual(response.code, 503)
|
self.assertEqual(response.code, 503)
|
||||||
self.assertFalse(json.loads(response.body)['available'])
|
self.assertFalse(json.loads(response.body)['available'])
|
||||||
|
|
||||||
|
def test_postgres_status_not_connected(self):
|
||||||
|
self.app._postgres_connected.clear()
|
||||||
|
response = self.fetch('/status')
|
||||||
|
self.assertEqual(response.code, 503)
|
||||||
|
self.assertFalse(json.loads(response.body)['available'])
|
||||||
|
|
||||||
@mock.patch('aiopg.cursor.Cursor.execute')
|
@mock.patch('aiopg.cursor.Cursor.execute')
|
||||||
def test_postgres_status_error(self, execute):
|
def test_postgres_status_error(self, execute):
|
||||||
execute.side_effect = asyncio.TimeoutError()
|
execute.side_effect = asyncio.TimeoutError()
|
||||||
|
@ -384,6 +400,55 @@ class RequestHandlerMixinTestCase(TestCase):
|
||||||
response = self.fetch('/execute?value=1')
|
response = self.fetch('/execute?value=1')
|
||||||
self.assertEqual(response.code, 503)
|
self.assertEqual(response.code, 503)
|
||||||
|
|
||||||
|
def test_postgres_cursor_operational_error_reconnects(self):
|
||||||
|
original = aiopg.connection.Connection.cursor
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def mock_cursor(self, name=None, cursor_factory=None,
|
||||||
|
scrollable=None, withhold=False, timeout=None):
|
||||||
|
global test_postgres_cursor_oer_invocation
|
||||||
|
|
||||||
|
test_postgres_cursor_oer_invocation += 1
|
||||||
|
if test_postgres_cursor_oer_invocation == 1:
|
||||||
|
raise psycopg2.OperationalError()
|
||||||
|
async with original(self, name, cursor_factory, scrollable,
|
||||||
|
withhold, timeout) as value:
|
||||||
|
yield value
|
||||||
|
|
||||||
|
aiopg.connection.Connection.cursor = mock_cursor
|
||||||
|
|
||||||
|
with mock.patch.object(self.app, '_postgres_connect') as connect:
|
||||||
|
response = self.fetch('/execute?value=1')
|
||||||
|
self.assertEqual(response.code, 200)
|
||||||
|
self.assertEqual(json.loads(response.body)['value'], '1')
|
||||||
|
connect.assert_called_once()
|
||||||
|
|
||||||
|
aiopg.connection.Connection.cursor = original
|
||||||
|
|
||||||
|
@mock.patch('aiopg.connection.Connection.cursor')
|
||||||
|
def test_postgres_cursor_raises(self, cursor):
|
||||||
|
cursor.side_effect = psycopg2.OperationalError()
|
||||||
|
with mock.patch.object(self.app, '_postgres_connect') as connect:
|
||||||
|
connect.return_value = False
|
||||||
|
response = self.fetch('/execute?value=1')
|
||||||
|
self.assertEqual(response.code, 503)
|
||||||
|
connect.assert_called_once()
|
||||||
|
|
||||||
|
@ttesting.gen_test()
|
||||||
|
@mock.patch('aiopg.connection.Connection.cursor')
|
||||||
|
async def test_postgres_cursor_failure_concurrency(self, cursor):
|
||||||
|
cursor.side_effect = psycopg2.OperationalError()
|
||||||
|
|
||||||
|
def on_error(*args):
|
||||||
|
return RuntimeError
|
||||||
|
|
||||||
|
async def invoke_cursor():
|
||||||
|
async with self.app.postgres_connector(on_error) as connector:
|
||||||
|
await connector.execute('SELECT 1')
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
await asyncio.gather(invoke_cursor(), invoke_cursor())
|
||||||
|
|
||||||
|
|
||||||
class TransactionTestCase(TestCase):
|
class TransactionTestCase(TestCase):
|
||||||
|
|
||||||
|
@ -464,9 +529,9 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||||
resolve_srv.return_value = []
|
resolve_srv.return_value = []
|
||||||
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz'
|
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz'
|
||||||
await obj._postgres_setup(obj, loop)
|
await obj._postgres_on_start(obj, loop)
|
||||||
stop.assert_called_once_with(loop)
|
stop.assert_called_once_with(loop)
|
||||||
critical.assert_called_once_with('No SRV records found')
|
critical.assert_any_call('No SRV records found')
|
||||||
|
|
||||||
@mock.patch('sprockets.http.app.Application.stop')
|
@mock.patch('sprockets.http.app.Application.stop')
|
||||||
@mock.patch('sprockets_postgres.LOGGER.critical')
|
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||||
|
@ -476,9 +541,9 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||||
resolve_srv.return_value = []
|
resolve_srv.return_value = []
|
||||||
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz'
|
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz'
|
||||||
await obj._postgres_setup(obj, loop)
|
await obj._postgres_on_start(obj, loop)
|
||||||
stop.assert_called_once_with(loop)
|
stop.assert_called_once_with(loop)
|
||||||
critical.assert_called_once_with('No SRV records found')
|
critical.assert_any_call('No SRV records found')
|
||||||
|
|
||||||
@mock.patch('sprockets.http.app.Application.stop')
|
@mock.patch('sprockets.http.app.Application.stop')
|
||||||
@mock.patch('sprockets_postgres.LOGGER.critical')
|
@mock.patch('sprockets_postgres.LOGGER.critical')
|
||||||
|
@ -486,9 +551,9 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
obj = Application()
|
obj = Application()
|
||||||
loop = ioloop.IOLoop.current()
|
loop = ioloop.IOLoop.current()
|
||||||
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz'
|
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz'
|
||||||
await obj._postgres_setup(obj, loop)
|
await obj._postgres_on_start(obj, loop)
|
||||||
stop.assert_called_once_with(loop)
|
stop.assert_called_once_with(loop)
|
||||||
critical.assert_called_once_with(
|
critical.assert_any_call(
|
||||||
'Unsupported URI Scheme: postgres+srv')
|
'Unsupported URI Scheme: postgres+srv')
|
||||||
|
|
||||||
@mock.patch('aiodns.DNSResolver.query')
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
|
@ -504,7 +569,8 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
query.return_value = future
|
query.return_value = future
|
||||||
url = await obj._postgres_url_from_srv(parsed)
|
url = await obj._postgres_url_from_srv(parsed)
|
||||||
query.assert_called_once_with('bar.baz', 'SRV')
|
query.assert_called_once_with('bar.baz', 'SRV')
|
||||||
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
|
self.assertEqual(
|
||||||
|
url, 'postgresql://foo@foo1:5432,foo3:6432,foo2:5432/qux')
|
||||||
|
|
||||||
@mock.patch('aiodns.DNSResolver.query')
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
async def test_postgresql_url_from_srv_variation_1(self, query):
|
async def test_postgresql_url_from_srv_variation_1(self, query):
|
||||||
|
@ -518,7 +584,7 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
query.return_value = future
|
query.return_value = future
|
||||||
url = await obj._postgres_url_from_srv(parsed)
|
url = await obj._postgres_url_from_srv(parsed)
|
||||||
query.assert_called_once_with('_bar._postgresql.baz', 'SRV')
|
query.assert_called_once_with('_bar._postgresql.baz', 'SRV')
|
||||||
self.assertEqual(url, 'postgresql://foo@foo1:5432/qux')
|
self.assertEqual(url, 'postgresql://foo@foo1:5432,foo2:5432/qux')
|
||||||
|
|
||||||
@mock.patch('aiodns.DNSResolver.query')
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
async def test_postgresql_url_from_srv_variation_2(self, query):
|
async def test_postgresql_url_from_srv_variation_2(self, query):
|
||||||
|
@ -526,13 +592,14 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie')
|
parsed = parse.urlparse('postgresql+srv://foo:bar@baz.qux/corgie')
|
||||||
future = asyncio.Future()
|
future = asyncio.Future()
|
||||||
future.set_result([
|
future.set_result([
|
||||||
SRV('foo2', 5432, 2, 0, 32),
|
SRV('foo2', 5432, 1, 0, 32),
|
||||||
SRV('foo1', 5432, 1, 0, 32)
|
SRV('foo1', 5432, 2, 0, 32)
|
||||||
])
|
])
|
||||||
query.return_value = future
|
query.return_value = future
|
||||||
url = await obj._postgres_url_from_srv(parsed)
|
url = await obj._postgres_url_from_srv(parsed)
|
||||||
query.assert_called_once_with('_baz._postgresql.qux', 'SRV')
|
query.assert_called_once_with('_baz._postgresql.qux', 'SRV')
|
||||||
self.assertEqual(url, 'postgresql://foo:bar@foo1:5432/corgie')
|
self.assertEqual(
|
||||||
|
url, 'postgresql://foo:bar@foo2:5432,foo1:5432/corgie')
|
||||||
|
|
||||||
@mock.patch('aiodns.DNSResolver.query')
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
async def test_postgresql_url_from_srv_variation_3(self, query):
|
async def test_postgresql_url_from_srv_variation_3(self, query):
|
||||||
|
@ -547,7 +614,7 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
query.return_value = future
|
query.return_value = future
|
||||||
url = await obj._postgres_url_from_srv(parsed)
|
url = await obj._postgres_url_from_srv(parsed)
|
||||||
query.assert_called_once_with('_foo._postgresql.bar', 'SRV')
|
query.assert_called_once_with('_foo._postgresql.bar', 'SRV')
|
||||||
self.assertEqual(url, 'postgresql://foo3:5432/baz')
|
self.assertEqual(url, 'postgresql://foo3:5432,foo1:5432,foo2:5432/baz')
|
||||||
|
|
||||||
@mock.patch('aiodns.DNSResolver.query')
|
@mock.patch('aiodns.DNSResolver.query')
|
||||||
async def test_resolve_srv_sorted(self, query):
|
async def test_resolve_srv_sorted(self, query):
|
||||||
|
|
Loading…
Reference in a new issue