mirror of
https://github.com/sprockets/sprockets-postgres.git
synced 2024-12-28 19:29:21 +00:00
Missed some abstractions from envvars in 1.6
This commit is contained in:
parent
ca362e2bc5
commit
3d6345a882
3 changed files with 77 additions and 60 deletions
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
||||||
1.6.0
|
1.6.1
|
||||||
|
|
|
@ -32,6 +32,9 @@ DEFAULT_POSTGRES_MIN_POOL_SIZE = '1'
|
||||||
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
|
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
|
||||||
DEFAULT_POSTGRES_UUID = 'TRUE'
|
DEFAULT_POSTGRES_UUID = 'TRUE'
|
||||||
|
|
||||||
|
OptionalCallable = typing.Optional[typing.Callable]
|
||||||
|
"""Type annocation of optional callable"""
|
||||||
|
|
||||||
QueryParameters = typing.Union[dict, list, tuple, None]
|
QueryParameters = typing.Union[dict, list, tuple, None]
|
||||||
"""Type annotation for query parameters"""
|
"""Type annotation for query parameters"""
|
||||||
|
|
||||||
|
@ -106,9 +109,7 @@ class PostgresConnector:
|
||||||
:param on_duration: The callback to invoke when a query is complete and all
|
:param on_duration: The callback to invoke when a query is complete and all
|
||||||
of the data has been returned.
|
of the data has been returned.
|
||||||
:param timeout: A timeout value in seconds for executing queries. If
|
:param timeout: A timeout value in seconds for executing queries. If
|
||||||
unspecified, defaults to the ``POSTGRES_QUERY_TIMEOUT`` environment
|
unspecified, defaults to the configured query timeout of `120` seconds.
|
||||||
variable and if that is not specified, to the
|
|
||||||
:const:`DEFAULT_POSTGRES_QUERY_TIMEOUT` value of ``120``
|
|
||||||
:type timeout: :data:`~sprockets_postgres.Timeout`
|
:type timeout: :data:`~sprockets_postgres.Timeout`
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -120,10 +121,7 @@ class PostgresConnector:
|
||||||
self.cursor = cursor
|
self.cursor = cursor
|
||||||
self._on_error = on_error
|
self._on_error = on_error
|
||||||
self._on_duration = on_duration
|
self._on_duration = on_duration
|
||||||
self._timeout = timeout or int(
|
self._timeout = timeout
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_QUERY_TIMEOUT',
|
|
||||||
DEFAULT_POSTGRES_QUERY_TIMEOUT))
|
|
||||||
|
|
||||||
async def callproc(self,
|
async def callproc(self,
|
||||||
name: str,
|
name: str,
|
||||||
|
@ -312,16 +310,15 @@ class ApplicationMixin:
|
||||||
self._postgres_pool: typing.Optional[pool.Pool] = None
|
self._postgres_pool: typing.Optional[pool.Pool] = None
|
||||||
self._postgres_connected: typing.Optional[asyncio.Event] = None
|
self._postgres_connected: typing.Optional[asyncio.Event] = None
|
||||||
self._postgres_reconnect: typing.Optional[asyncio.Lock] = None
|
self._postgres_reconnect: typing.Optional[asyncio.Lock] = None
|
||||||
|
self._postgres_settings = self._create_postgres_settings()
|
||||||
self._postgres_srv: bool = False
|
self._postgres_srv: bool = False
|
||||||
self.runner_callbacks['on_start'].append(self._postgres_on_start)
|
self.runner_callbacks['on_start'].append(self._postgres_on_start)
|
||||||
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
|
self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def postgres_connector(self,
|
async def postgres_connector(self,
|
||||||
on_error: typing.Optional[
|
on_error: OptionalCallable = None,
|
||||||
typing.Callable] = None,
|
on_duration: OptionalCallable = None,
|
||||||
on_duration: typing.Optional[
|
|
||||||
typing.Callable] = None,
|
|
||||||
timeout: Timeout = None,
|
timeout: Timeout = None,
|
||||||
_attempt: int = 1) \
|
_attempt: int = 1) \
|
||||||
-> typing.AsyncContextManager[PostgresConnector]:
|
-> typing.AsyncContextManager[PostgresConnector]:
|
||||||
|
@ -357,8 +354,11 @@ class ApplicationMixin:
|
||||||
try:
|
try:
|
||||||
async with self._postgres_pool.acquire() as conn:
|
async with self._postgres_pool.acquire() as conn:
|
||||||
async with conn.cursor(
|
async with conn.cursor(
|
||||||
cursor_factory=extras.RealDictCursor,
|
cursor_factory=extras.RealDictCursor,
|
||||||
timeout=timeout) as cursor:
|
timeout=(
|
||||||
|
timeout
|
||||||
|
or self._postgres_settings['query_timeout'])) \
|
||||||
|
as cursor:
|
||||||
yield PostgresConnector(
|
yield PostgresConnector(
|
||||||
cursor, on_error, on_duration, timeout)
|
cursor, on_error, on_duration, timeout)
|
||||||
except (asyncio.TimeoutError, psycopg2.OperationalError) as err:
|
except (asyncio.TimeoutError, psycopg2.OperationalError) as err:
|
||||||
|
@ -435,11 +435,57 @@ class ApplicationMixin:
|
||||||
'pool_free': self._postgres_pool.freesize
|
'pool_free': self._postgres_pool.freesize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _create_postgres_settings(self) -> dict:
|
||||||
|
return {
|
||||||
|
'url': self.settings.get(
|
||||||
|
'postgres_url', os.environ.get('POSTGRES_URL')),
|
||||||
|
'max_pool_size': int(self.settings.get(
|
||||||
|
'postgres_max_pool_size',
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_MAX_POOL_SIZE',
|
||||||
|
DEFAULT_POSTGRES_MAX_POOL_SIZE))),
|
||||||
|
'min_pool_size': int(self.settings.get(
|
||||||
|
'postgres_min_pool_size',
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_MIN_POOL_SIZE',
|
||||||
|
DEFAULT_POSTGRES_MIN_POOL_SIZE))),
|
||||||
|
'timeout': int(self.settings.get(
|
||||||
|
'postgres_connection_timeout',
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_CONNECT_TIMEOUT',
|
||||||
|
DEFAULT_POSTGRES_CONNECTION_TIMEOUT))),
|
||||||
|
'connection_ttl': int(self.settings.get(
|
||||||
|
'postgres_connection_ttl',
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_CONNECTION_TTL',
|
||||||
|
DEFAULT_POSTGRES_CONNECTION_TIMEOUT))),
|
||||||
|
'enable_hstore': self.settings.get(
|
||||||
|
'postgres_hstore',
|
||||||
|
util.strtobool(
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE))),
|
||||||
|
'enable_json': self.settings.get(
|
||||||
|
'postgres_json',
|
||||||
|
util.strtobool(
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_JSON', DEFAULT_POSTGRES_JSON))),
|
||||||
|
'enable_uuid': self.settings.get(
|
||||||
|
'postgres_uuid',
|
||||||
|
util.strtobool(
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_UUID', DEFAULT_POSTGRES_UUID))),
|
||||||
|
'query_timeout': int(self.settings.get(
|
||||||
|
'postgres_query_timeout',
|
||||||
|
os.environ.get(
|
||||||
|
'POSTGRES_QUERY_TIMEOUT',
|
||||||
|
DEFAULT_POSTGRES_QUERY_TIMEOUT))),
|
||||||
|
}
|
||||||
|
|
||||||
async def _postgres_connect(self) -> bool:
|
async def _postgres_connect(self) -> bool:
|
||||||
"""Setup the Postgres pool of connections"""
|
"""Setup the Postgres pool of connections"""
|
||||||
self._postgres_connected.clear()
|
self._postgres_connected.clear()
|
||||||
|
|
||||||
parsed = parse.urlparse(os.environ['POSTGRES_URL'])
|
parsed = parse.urlparse(self._postgres_settings['url'])
|
||||||
if parsed.scheme.endswith('+srv'):
|
if parsed.scheme.endswith('+srv'):
|
||||||
self._postgres_srv = True
|
self._postgres_srv = True
|
||||||
try:
|
try:
|
||||||
|
@ -448,7 +494,7 @@ class ApplicationMixin:
|
||||||
LOGGER.critical(str(error))
|
LOGGER.critical(str(error))
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
url = os.environ['POSTGRES_URL']
|
url = self._postgres_settings['url']
|
||||||
|
|
||||||
if self._postgres_pool:
|
if self._postgres_pool:
|
||||||
self._postgres_pool.close()
|
self._postgres_pool.close()
|
||||||
|
@ -459,45 +505,16 @@ class ApplicationMixin:
|
||||||
try:
|
try:
|
||||||
self._postgres_pool = await pool.Pool.from_pool_fill(
|
self._postgres_pool = await pool.Pool.from_pool_fill(
|
||||||
url,
|
url,
|
||||||
maxsize=self.settings.get(
|
maxsize=self._postgres_settings['max_pool_size'],
|
||||||
'postgres_max_pool_size',
|
minsize=self._postgres_settings['min_pool_size'],
|
||||||
int(os.environ.get(
|
timeout=self._postgres_settings['timeout'],
|
||||||
'POSTGRES_MAX_POOL_SIZE',
|
enable_hstore=self._postgres_settings['enable_hstore'],
|
||||||
DEFAULT_POSTGRES_MAX_POOL_SIZE))),
|
enable_json=self._postgres_settings['enable_json'],
|
||||||
minsize=self.settings.get(
|
enable_uuid=self._postgres_settings['enable_uuid'],
|
||||||
'postgres_min_pool_size',
|
|
||||||
int(os.environ.get(
|
|
||||||
'POSTGRES_MIN_POOL_SIZE',
|
|
||||||
DEFAULT_POSTGRES_MIN_POOL_SIZE))),
|
|
||||||
timeout=self.settings.get(
|
|
||||||
'postgres_connect_timeout',
|
|
||||||
int(os.environ.get(
|
|
||||||
'POSTGRES_CONNECT_TIMEOUT',
|
|
||||||
DEFAULT_POSTGRES_CONNECTION_TIMEOUT))),
|
|
||||||
enable_hstore=self.settings.get(
|
|
||||||
'postgres_hstore',
|
|
||||||
util.strtobool(
|
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE))),
|
|
||||||
enable_json=self.settings.get(
|
|
||||||
'enable_json',
|
|
||||||
util.strtobool(
|
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_JSON', DEFAULT_POSTGRES_JSON))),
|
|
||||||
enable_uuid=self.settings.get(
|
|
||||||
'postgres_uuid',
|
|
||||||
util.strtobool(
|
|
||||||
os.environ.get(
|
|
||||||
'POSTGRES_UUID', DEFAULT_POSTGRES_UUID))),
|
|
||||||
echo=False,
|
echo=False,
|
||||||
on_connect=None,
|
on_connect=None,
|
||||||
pool_recycle=self.settings.get(
|
pool_recycle=self._postgres_settings['connection_ttl'])
|
||||||
'postgres_connection_ttl',
|
except psycopg2.Error as error: # pragma: nocover
|
||||||
int(os.environ.get(
|
|
||||||
'POSTGRES_CONNECTION_TTL',
|
|
||||||
DEFAULT_POSTGRES_CONNECTION_TTL))))
|
|
||||||
except (psycopg2.OperationalError,
|
|
||||||
psycopg2.Error) as error: # pragma: nocover
|
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
'Error connecting to PostgreSQL on startup with %s: %s',
|
'Error connecting to PostgreSQL on startup with %s: %s',
|
||||||
safe_url, error)
|
safe_url, error)
|
||||||
|
@ -526,8 +543,8 @@ class ApplicationMixin:
|
||||||
callback mechanism.
|
callback mechanism.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 'POSTGRES_URL' not in os.environ:
|
if not self._postgres_settings['url']:
|
||||||
LOGGER.critical('Missing POSTGRES_URL environment variable')
|
LOGGER.critical('Missing required `postgres_url` setting')
|
||||||
return self.stop(loop)
|
return self.stop(loop)
|
||||||
|
|
||||||
self._postgres_connected = asyncio.Event()
|
self._postgres_connected = asyncio.Event()
|
||||||
|
@ -557,8 +574,8 @@ class ApplicationMixin:
|
||||||
elif parsed.scheme.startswith('aws+'):
|
elif parsed.scheme.startswith('aws+'):
|
||||||
records = await self._resolve_srv(parsed.hostname)
|
records = await self._resolve_srv(parsed.hostname)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('Unsupported URI Scheme: {}'.format(
|
raise RuntimeError(
|
||||||
parsed.scheme))
|
'Unsupported URI Scheme: {}'.format(parsed.scheme))
|
||||||
|
|
||||||
if not records:
|
if not records:
|
||||||
raise RuntimeError('No SRV records found')
|
raise RuntimeError('No SRV records found')
|
||||||
|
|
6
tests.py
6
tests.py
|
@ -715,7 +715,7 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
loop = ioloop.IOLoop.current()
|
loop = ioloop.IOLoop.current()
|
||||||
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||||
resolve_srv.return_value = []
|
resolve_srv.return_value = []
|
||||||
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz'
|
obj._postgres_settings['url'] = 'aws+srv://foo@bar/baz'
|
||||||
await obj._postgres_on_start(obj, loop)
|
await obj._postgres_on_start(obj, loop)
|
||||||
stop.assert_called_once_with(loop)
|
stop.assert_called_once_with(loop)
|
||||||
critical.assert_any_call('No SRV records found')
|
critical.assert_any_call('No SRV records found')
|
||||||
|
@ -727,7 +727,7 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
loop = ioloop.IOLoop.current()
|
loop = ioloop.IOLoop.current()
|
||||||
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
|
||||||
resolve_srv.return_value = []
|
resolve_srv.return_value = []
|
||||||
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz'
|
obj._postgres_settings['url'] = 'postgresql+srv://foo@bar/baz'
|
||||||
await obj._postgres_on_start(obj, loop)
|
await obj._postgres_on_start(obj, loop)
|
||||||
stop.assert_called_once_with(loop)
|
stop.assert_called_once_with(loop)
|
||||||
critical.assert_any_call('No SRV records found')
|
critical.assert_any_call('No SRV records found')
|
||||||
|
@ -737,7 +737,7 @@ class SRVTestCase(asynctest.TestCase):
|
||||||
async def test_unsupported_srv_uri(self, critical, stop):
|
async def test_unsupported_srv_uri(self, critical, stop):
|
||||||
obj = Application()
|
obj = Application()
|
||||||
loop = ioloop.IOLoop.current()
|
loop = ioloop.IOLoop.current()
|
||||||
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz'
|
obj._postgres_settings['url'] = 'postgres+srv://foo@bar/baz'
|
||||||
await obj._postgres_on_start(obj, loop)
|
await obj._postgres_on_start(obj, loop)
|
||||||
stop.assert_called_once_with(loop)
|
stop.assert_called_once_with(loop)
|
||||||
critical.assert_any_call(
|
critical.assert_any_call(
|
||||||
|
|
Loading…
Reference in a new issue