diff --git a/sprockets_postgres.py b/sprockets_postgres.py index 02c895d..27a7cd3 100644 --- a/sprockets_postgres.py +++ b/sprockets_postgres.py @@ -110,7 +110,7 @@ class PostgresConnector: """ def __init__(self, cursor: aiopg.Cursor, - on_error: typing.Callable, + on_error: typing.Optional[typing.Callable] = None, on_duration: typing.Optional[typing.Callable] = None, timeout: Timeout = None): self.cursor = cursor @@ -246,9 +246,10 @@ class PostgresConnector: try: await method(**kwargs) except (asyncio.TimeoutError, psycopg2.Error) as err: - exc = self._on_error(metric_name, err) - if exc: - raise exc + if self._on_error: + err = self._on_error(metric_name, err) + if err: + raise err else: results = await self._query_results() if self._on_duration: @@ -313,7 +314,8 @@ class ApplicationMixin: @contextlib.asynccontextmanager async def postgres_connector(self, - on_error: typing.Callable, + on_error: typing.Optional[ + typing.Callable] = None, on_duration: typing.Optional[ typing.Callable] = None, timeout: Timeout = None,