From 3d6345a882516432bf5ca7a35502ce79dca3370f Mon Sep 17 00:00:00 2001 From: "Gavin M. Roy" Date: Fri, 8 Jan 2021 18:14:39 -0500 Subject: [PATCH] Missed some abstractions from envvars in 1.6 --- VERSION | 2 +- sprockets_postgres.py | 129 ++++++++++++++++++++++++------------------ tests.py | 6 +- 3 files changed, 77 insertions(+), 60 deletions(-) diff --git a/VERSION b/VERSION index dc1e644..9c6d629 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.6.0 +1.6.1 diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 905a749..6a1c6f8 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -32,6 +32,9 @@ DEFAULT_POSTGRES_MIN_POOL_SIZE = '1' DEFAULT_POSTGRES_QUERY_TIMEOUT = 120 DEFAULT_POSTGRES_UUID = 'TRUE' +OptionalCallable = typing.Optional[typing.Callable] +"""Type annocation of optional callable""" + QueryParameters = typing.Union[dict, list, tuple, None] """Type annotation for query parameters""" @@ -106,9 +109,7 @@ class PostgresConnector: :param on_duration: The callback to invoke when a query is complete and all of the data has been returned. :param timeout: A timeout value in seconds for executing queries. If - unspecified, defaults to the ``POSTGRES_QUERY_TIMEOUT`` environment - variable and if that is not specified, to the - :const:`DEFAULT_POSTGRES_QUERY_TIMEOUT` value of ``120`` + unspecified, defaults to the configured query timeout of `120` seconds. :type timeout: :data:`~sprockets_postgres.Timeout` """ @@ -120,10 +121,7 @@ class PostgresConnector: self.cursor = cursor self._on_error = on_error self._on_duration = on_duration - self._timeout = timeout or int( - os.environ.get( - 'POSTGRES_QUERY_TIMEOUT', - DEFAULT_POSTGRES_QUERY_TIMEOUT)) + self._timeout = timeout async def callproc(self, name: str, @@ -312,16 +310,15 @@ class ApplicationMixin: self._postgres_pool: typing.Optional[pool.Pool] = None self._postgres_connected: typing.Optional[asyncio.Event] = None self._postgres_reconnect: typing.Optional[asyncio.Lock] = None + self._postgres_settings = self._create_postgres_settings() self._postgres_srv: bool = False self.runner_callbacks['on_start'].append(self._postgres_on_start) self.runner_callbacks['shutdown'].append(self._postgres_shutdown) @contextlib.asynccontextmanager async def postgres_connector(self, - on_error: typing.Optional[ - typing.Callable] = None, - on_duration: typing.Optional[ - typing.Callable] = None, + on_error: OptionalCallable = None, + on_duration: OptionalCallable = None, timeout: Timeout = None, _attempt: int = 1) \ -> typing.AsyncContextManager[PostgresConnector]: @@ -357,8 +354,11 @@ class ApplicationMixin: try: async with self._postgres_pool.acquire() as conn: async with conn.cursor( - cursor_factory=extras.RealDictCursor, - timeout=timeout) as cursor: + cursor_factory=extras.RealDictCursor, + timeout=( + timeout + or self._postgres_settings['query_timeout'])) \ + as cursor: yield PostgresConnector( cursor, on_error, on_duration, timeout) except (asyncio.TimeoutError, psycopg2.OperationalError) as err: @@ -435,11 +435,57 @@ class ApplicationMixin: 'pool_free': self._postgres_pool.freesize } + def _create_postgres_settings(self) -> dict: + return { + 'url': self.settings.get( + 'postgres_url', os.environ.get('POSTGRES_URL')), + 'max_pool_size': int(self.settings.get( + 'postgres_max_pool_size', + os.environ.get( + 'POSTGRES_MAX_POOL_SIZE', + DEFAULT_POSTGRES_MAX_POOL_SIZE))), + 'min_pool_size': int(self.settings.get( + 'postgres_min_pool_size', + os.environ.get( + 'POSTGRES_MIN_POOL_SIZE', + DEFAULT_POSTGRES_MIN_POOL_SIZE))), + 'timeout': int(self.settings.get( + 'postgres_connection_timeout', + os.environ.get( + 'POSTGRES_CONNECT_TIMEOUT', + DEFAULT_POSTGRES_CONNECTION_TIMEOUT))), + 'connection_ttl': int(self.settings.get( + 'postgres_connection_ttl', + os.environ.get( + 'POSTGRES_CONNECTION_TTL', + DEFAULT_POSTGRES_CONNECTION_TIMEOUT))), + 'enable_hstore': self.settings.get( + 'postgres_hstore', + util.strtobool( + os.environ.get( + 'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE))), + 'enable_json': self.settings.get( + 'postgres_json', + util.strtobool( + os.environ.get( + 'POSTGRES_JSON', DEFAULT_POSTGRES_JSON))), + 'enable_uuid': self.settings.get( + 'postgres_uuid', + util.strtobool( + os.environ.get( + 'POSTGRES_UUID', DEFAULT_POSTGRES_UUID))), + 'query_timeout': int(self.settings.get( + 'postgres_query_timeout', + os.environ.get( + 'POSTGRES_QUERY_TIMEOUT', + DEFAULT_POSTGRES_QUERY_TIMEOUT))), + } + async def _postgres_connect(self) -> bool: """Setup the Postgres pool of connections""" self._postgres_connected.clear() - parsed = parse.urlparse(os.environ['POSTGRES_URL']) + parsed = parse.urlparse(self._postgres_settings['url']) if parsed.scheme.endswith('+srv'): self._postgres_srv = True try: @@ -448,7 +494,7 @@ class ApplicationMixin: LOGGER.critical(str(error)) return False else: - url = os.environ['POSTGRES_URL'] + url = self._postgres_settings['url'] if self._postgres_pool: self._postgres_pool.close() @@ -459,45 +505,16 @@ class ApplicationMixin: try: self._postgres_pool = await pool.Pool.from_pool_fill( url, - maxsize=self.settings.get( - 'postgres_max_pool_size', - int(os.environ.get( - 'POSTGRES_MAX_POOL_SIZE', - DEFAULT_POSTGRES_MAX_POOL_SIZE))), - minsize=self.settings.get( - 'postgres_min_pool_size', - int(os.environ.get( - 'POSTGRES_MIN_POOL_SIZE', - DEFAULT_POSTGRES_MIN_POOL_SIZE))), - timeout=self.settings.get( - 'postgres_connect_timeout', - int(os.environ.get( - 'POSTGRES_CONNECT_TIMEOUT', - DEFAULT_POSTGRES_CONNECTION_TIMEOUT))), - enable_hstore=self.settings.get( - 'postgres_hstore', - util.strtobool( - os.environ.get( - 'POSTGRES_HSTORE', DEFAULT_POSTGRES_HSTORE))), - enable_json=self.settings.get( - 'enable_json', - util.strtobool( - os.environ.get( - 'POSTGRES_JSON', DEFAULT_POSTGRES_JSON))), - enable_uuid=self.settings.get( - 'postgres_uuid', - util.strtobool( - os.environ.get( - 'POSTGRES_UUID', DEFAULT_POSTGRES_UUID))), + maxsize=self._postgres_settings['max_pool_size'], + minsize=self._postgres_settings['min_pool_size'], + timeout=self._postgres_settings['timeout'], + enable_hstore=self._postgres_settings['enable_hstore'], + enable_json=self._postgres_settings['enable_json'], + enable_uuid=self._postgres_settings['enable_uuid'], echo=False, on_connect=None, - pool_recycle=self.settings.get( - 'postgres_connection_ttl', - int(os.environ.get( - 'POSTGRES_CONNECTION_TTL', - DEFAULT_POSTGRES_CONNECTION_TTL)))) - except (psycopg2.OperationalError, - psycopg2.Error) as error: # pragma: nocover + pool_recycle=self._postgres_settings['connection_ttl']) + except psycopg2.Error as error: # pragma: nocover LOGGER.warning( 'Error connecting to PostgreSQL on startup with %s: %s', safe_url, error) @@ -526,8 +543,8 @@ class ApplicationMixin: callback mechanism. """ - if 'POSTGRES_URL' not in os.environ: - LOGGER.critical('Missing POSTGRES_URL environment variable') + if not self._postgres_settings['url']: + LOGGER.critical('Missing required `postgres_url` setting') return self.stop(loop) self._postgres_connected = asyncio.Event() @@ -557,8 +574,8 @@ class ApplicationMixin: elif parsed.scheme.startswith('aws+'): records = await self._resolve_srv(parsed.hostname) else: - raise RuntimeError('Unsupported URI Scheme: {}'.format( - parsed.scheme)) + raise RuntimeError( + 'Unsupported URI Scheme: {}'.format(parsed.scheme)) if not records: raise RuntimeError('No SRV records found') diff --git a/tests.py b/tests.py index b3e5041..ea1a2c9 100644 --- a/tests.py +++ b/tests.py @@ -715,7 +715,7 @@ class SRVTestCase(asynctest.TestCase): loop = ioloop.IOLoop.current() with mock.patch.object(obj, '_resolve_srv') as resolve_srv: resolve_srv.return_value = [] - os.environ['POSTGRES_URL'] = 'aws+srv://foo@bar/baz' + obj._postgres_settings['url'] = 'aws+srv://foo@bar/baz' await obj._postgres_on_start(obj, loop) stop.assert_called_once_with(loop) critical.assert_any_call('No SRV records found') @@ -727,7 +727,7 @@ class SRVTestCase(asynctest.TestCase): loop = ioloop.IOLoop.current() with mock.patch.object(obj, '_resolve_srv') as resolve_srv: resolve_srv.return_value = [] - os.environ['POSTGRES_URL'] = 'postgresql+srv://foo@bar/baz' + obj._postgres_settings['url'] = 'postgresql+srv://foo@bar/baz' await obj._postgres_on_start(obj, loop) stop.assert_called_once_with(loop) critical.assert_any_call('No SRV records found') @@ -737,7 +737,7 @@ class SRVTestCase(asynctest.TestCase): async def test_unsupported_srv_uri(self, critical, stop): obj = Application() loop = ioloop.IOLoop.current() - os.environ['POSTGRES_URL'] = 'postgres+srv://foo@bar/baz' + obj._postgres_settings['url'] = 'postgres+srv://foo@bar/baz' await obj._postgres_on_start(obj, loop) stop.assert_called_once_with(loop) critical.assert_any_call(