Missed some abstractions from envvars in 1.6

This commit is contained in:
Gavin M. Roy 2021-01-08 18:14:39 -05:00
parent ca362e2bc5
commit 3d6345a882
3 changed files with 77 additions and 60 deletions

View file

@ -1 +1 @@
1.6.0 1.6.1

View file

@ -32,6 +32,9 @@ DEFAULT_POSTGRES_MIN_POOL_SIZE = '1'
DEFAULT_POSTGRES_QUERY_TIMEOUT = 120 DEFAULT_POSTGRES_QUERY_TIMEOUT = 120
DEFAULT_POSTGRES_UUID = 'TRUE' DEFAULT_POSTGRES_UUID = 'TRUE'
OptionalCallable = typing.Optional[typing.Callable]
"""Type annocation of optional callable"""
QueryParameters = typing.Union[dict, list, tuple, None] QueryParameters = typing.Union[dict, list, tuple, None]
"""Type annotation for query parameters""" """Type annotation for query parameters"""
@ -106,9 +109,7 @@ class PostgresConnector:
:param on_duration: The callback to invoke when a query is complete and all :param on_duration: The callback to invoke when a query is complete and all
of the data has been returned. of the data has been returned.
:param timeout: A timeout value in seconds for executing queries. If :param timeout: A timeout value in seconds for executing queries. If
unspecified, defaults to the ``POSTGRES_QUERY_TIMEOUT`` environment unspecified, defaults to the configured query timeout of `120` seconds.
variable and if that is not specified, to the
:const:`DEFAULT_POSTGRES_QUERY_TIMEOUT` value of ``120``
:type timeout: :data:`~sprockets_postgres.Timeout` :type timeout: :data:`~sprockets_postgres.Timeout`
""" """
@ -120,10 +121,7 @@ class PostgresConnector:
self.cursor = cursor self.cursor = cursor
self._on_error = on_error self._on_error = on_error
self._on_duration = on_duration self._on_duration = on_duration
self._timeout = timeout or int( self._timeout = timeout
os.environ.get(
'POSTGRES_QUERY_TIMEOUT',
DEFAULT_POSTGRES_QUERY_TIMEOUT))
async def callproc(self, async def callproc(self,
name: str, name: str,
@ -312,16 +310,15 @@ class ApplicationMixin:
self._postgres_pool: typing.Optional[pool.Pool] = None self._postgres_pool: typing.Optional[pool.Pool] = None
self._postgres_connected: typing.Optional[asyncio.Event] = None self._postgres_connected: typing.Optional[asyncio.Event] = None
self._postgres_reconnect: typing.Optional[asyncio.Lock] = None self._postgres_reconnect: typing.Optional[asyncio.Lock] = None
self._postgres_settings = self._create_postgres_settings()
self._postgres_srv: bool = False self._postgres_srv: bool = False
self.runner_callbacks['on_start'].append(self._postgres_on_start) self.runner_callbacks['on_start'].append(self._postgres_on_start)
self.runner_callbacks['shutdown'].append(self._postgres_shutdown) self.runner_callbacks['shutdown'].append(self._postgres_shutdown)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def postgres_connector(self, async def postgres_connector(self,
on_error: typing.Optional[ on_error: OptionalCallable = None,
typing.Callable] = None, on_duration: OptionalCallable = None,
on_duration: typing.Optional[
typing.Callable] = None,
timeout: Timeout = None, timeout: Timeout = None,
_attempt: int = 1) \ _attempt: int = 1) \
-> typing.AsyncContextManager[PostgresConnector]: -> typing.AsyncContextManager[PostgresConnector]:
@ -357,8 +354,11 @@ class ApplicationMixin:
try: try:
async with self._postgres_pool.acquire() as conn: async with self._postgres_pool.acquire() as conn:
async with conn.cursor( async with conn.cursor(
cursor_factory=extras.RealDictCursor, cursor_factory=extras.RealDictCursor,
timeout=timeout) as cursor: timeout=(
timeout
or self._postgres_settings['query_timeout'])) \
as cursor:
yield PostgresConnector( yield PostgresConnector(
cursor, on_error, on_duration, timeout) cursor, on_error, on_duration, timeout)
except (asyncio.TimeoutError, psycopg2.OperationalError) as err: except (asyncio.TimeoutError, psycopg2.OperationalError) as err:
@ -435,11 +435,57 @@ class ApplicationMixin:
'pool_free': self._postgres_pool.freesize 'pool_free': self._postgres_pool.freesize
} }
def _create_postgres_settings(self) -> dict:
return {
'url': self.settings.get(
'postgres_url', os.environ.get('POSTGRES_URL')),
'max_pool_size': int(self.settings.get(
'postgres_max_pool_size',
os.environ.get(
'POSTGRES_MAX_POOL_SIZE',
DEFAULT_POSTGRES_MAX_POOL_SIZE))),
'min_pool_size': int(self.settings.get(
'postgres_min_pool_size',
os.environ.get(
'POSTGRES_MIN_POOL_SIZE',
DEFAULT_POSTGRES_MIN_POOL_SIZE))),
'timeout': int(self.settings.get(
'postgres_connection_timeout',
os.environ.get(
'POSTGRES_CONNECT_TIMEOUT',
DEFAULT_POSTGRES_CONNECTION_TIMEOUT))),
'connection_ttl': int(self.settings.get(
'postgres_connection_ttl',
os.environ.get(
'POSTGRES_CONNECTION_TTL',
DEFAULT_POSTGRES_CONNECTION_TIMEOUT))),
'enable_hstore': self.settings.get(
'postgres_hstore',
util.strtobool(
os.environ.get(
'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE))),
'enable_json': self.settings.get(
'postgres_json',
util.strtobool(
os.environ.get(
'POSTGRES_JSON', DEFAULT_POSTGRES_JSON))),
'enable_uuid': self.settings.get(
'postgres_uuid',
util.strtobool(
os.environ.get(
'POSTGRES_UUID', DEFAULT_POSTGRES_UUID))),
'query_timeout': int(self.settings.get(
'postgres_query_timeout',
os.environ.get(
'POSTGRES_QUERY_TIMEOUT',
DEFAULT_POSTGRES_QUERY_TIMEOUT))),
}
async def _postgres_connect(self) -> bool: async def _postgres_connect(self) -> bool:
"""Setup the Postgres pool of connections""" """Setup the Postgres pool of connections"""
self._postgres_connected.clear() self._postgres_connected.clear()
parsed = parse.urlparse(os.environ['POSTGRES_URL']) parsed = parse.urlparse(self._postgres_settings['url'])
if parsed.scheme.endswith('+srv'): if parsed.scheme.endswith('+srv'):
self._postgres_srv = True self._postgres_srv = True
try: try:
@ -448,7 +494,7 @@ class ApplicationMixin:
LOGGER.critical(str(error)) LOGGER.critical(str(error))
return False return False
else: else:
url = os.environ['POSTGRES_URL'] url = self._postgres_settings['url']
if self._postgres_pool: if self._postgres_pool:
self._postgres_pool.close() self._postgres_pool.close()
@ -459,45 +505,16 @@ class ApplicationMixin:
try: try:
self._postgres_pool = await pool.Pool.from_pool_fill( self._postgres_pool = await pool.Pool.from_pool_fill(
url, url,
maxsize=self.settings.get( maxsize=self._postgres_settings['max_pool_size'],
'postgres_max_pool_size', minsize=self._postgres_settings['min_pool_size'],
int(os.environ.get( timeout=self._postgres_settings['timeout'],
'POSTGRES_MAX_POOL_SIZE', enable_hstore=self._postgres_settings['enable_hstore'],
DEFAULT_POSTGRES_MAX_POOL_SIZE))), enable_json=self._postgres_settings['enable_json'],
minsize=self.settings.get( enable_uuid=self._postgres_settings['enable_uuid'],
'postgres_min_pool_size',
int(os.environ.get(
'POSTGRES_MIN_POOL_SIZE',
DEFAULT_POSTGRES_MIN_POOL_SIZE))),
timeout=self.settings.get(
'postgres_connect_timeout',
int(os.environ.get(
'POSTGRES_CONNECT_TIMEOUT',
DEFAULT_POSTGRES_CONNECTION_TIMEOUT))),
enable_hstore=self.settings.get(
'postgres_hstore',
util.strtobool(
os.environ.get(
'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE))),
enable_json=self.settings.get(
'enable_json',
util.strtobool(
os.environ.get(
'POSTGRES_JSON', DEFAULT_POSTGRES_JSON))),
enable_uuid=self.settings.get(
'postgres_uuid',
util.strtobool(
os.environ.get(
'POSTGRES_UUID', DEFAULT_POSTGRES_UUID))),
echo=False, echo=False,
on_connect=None, on_connect=None,
pool_recycle=self.settings.get( pool_recycle=self._postgres_settings['connection_ttl'])
'postgres_connection_ttl', except psycopg2.Error as error: # pragma: nocover
int(os.environ.get(
'POSTGRES_CONNECTION_TTL',
DEFAULT_POSTGRES_CONNECTION_TTL))))
except (psycopg2.OperationalError,
psycopg2.Error) as error: # pragma: nocover
LOGGER.warning( LOGGER.warning(
'Error connecting to PostgreSQL on startup with %s: %s', 'Error connecting to PostgreSQL on startup with %s: %s',
safe_url, error) safe_url, error)
@ -526,8 +543,8 @@ class ApplicationMixin:
callback mechanism. callback mechanism.
""" """
if 'POSTGRES_URL' not in os.environ: if not self._postgres_settings['url']:
LOGGER.critical('Missing POSTGRES_URL environment variable') LOGGER.critical('Missing required `postgres_url` setting')
return self.stop(loop) return self.stop(loop)
self._postgres_connected = asyncio.Event() self._postgres_connected = asyncio.Event()
@ -557,8 +574,8 @@ class ApplicationMixin:
elif parsed.scheme.startswith('aws+'): elif parsed.scheme.startswith('aws+'):
records = await self._resolve_srv(parsed.hostname) records = await self._resolve_srv(parsed.hostname)
else: else:
raise RuntimeError('Unsupported URI Scheme: {}'.format( raise RuntimeError(
parsed.scheme)) 'Unsupported URI Scheme: {}'.format(parsed.scheme))
if not records: if not records:
raise RuntimeError('No SRV records found') raise RuntimeError('No SRV records found')

View file

@ -715,7 +715,7 @@ class SRVTestCase(asynctest.TestCase):
loop = ioloop.IOLoop.current() loop = ioloop.IOLoop.current()
with mock.patch.object(obj, '_resolve_srv') as resolve_srv: with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
resolve_srv.return_value = [] resolve_srv.return_value = []
os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz' obj._postgres_settings['url'] = 'aws+srv://foo@bar/baz'
await obj._postgres_on_start(obj, loop) await obj._postgres_on_start(obj, loop)
stop.assert_called_once_with(loop) stop.assert_called_once_with(loop)
critical.assert_any_call('No SRV records found') critical.assert_any_call('No SRV records found')
@ -727,7 +727,7 @@ class SRVTestCase(asynctest.TestCase):
loop = ioloop.IOLoop.current() loop = ioloop.IOLoop.current()
with mock.patch.object(obj, '_resolve_srv') as resolve_srv: with mock.patch.object(obj, '_resolve_srv') as resolve_srv:
resolve_srv.return_value = [] resolve_srv.return_value = []
os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz' obj._postgres_settings['url'] = 'postgresql+srv://foo@bar/baz'
await obj._postgres_on_start(obj, loop) await obj._postgres_on_start(obj, loop)
stop.assert_called_once_with(loop) stop.assert_called_once_with(loop)
critical.assert_any_call('No SRV records found') critical.assert_any_call('No SRV records found')
@ -737,7 +737,7 @@ class SRVTestCase(asynctest.TestCase):
async def test_unsupported_srv_uri(self, critical, stop): async def test_unsupported_srv_uri(self, critical, stop):
obj = Application() obj = Application()
loop = ioloop.IOLoop.current() loop = ioloop.IOLoop.current()
os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz' obj._postgres_settings['url'] = 'postgres+srv://foo@bar/baz'
await obj._postgres_on_start(obj, loop) await obj._postgres_on_start(obj, loop)
stop.assert_called_once_with(loop) stop.assert_called_once_with(loop)
critical.assert_any_call( critical.assert_any_call(