# testcontainers/dbtests/app.py
# 2024-10-09 23:08:40 -04:00
#
# 82 lines
# 2.4 KiB
# Python
import asyncio
import json
import os
import psycopg
import psycopg.abc
import psycopg_pool
import tornado
class DatabaseHandler(tornado.web.RequestHandler):
    """Base request handler offering a ``query`` helper for running SQL.

    Subclasses call ``await self.query(...)`` to execute a statement and
    receive all resulting rows.
    """

    async def query(self, query: psycopg.abc.Query, *args, **kwargs) -> list:
        """Execute *query* and return every row it produces.

        If the application exposes an ``initialize_db_pool`` coroutine (as
        ``MyApplication`` in this module does), a connection is borrowed from
        that shared pool. Otherwise a one-off connection is opened from the
        ``DBTESTS_DATABASE`` environment variable.

        Args:
            query: SQL statement using psycopg placeholders.
            *args: Forwarded to ``AsyncCursor.execute`` (e.g. the params).
            **kwargs: Forwarded to ``AsyncCursor.execute``.

        Returns:
            The rows returned by ``cursor.fetchall()``.

        Raises:
            ValueError: On the direct-connection path, if
                ``DBTESTS_DATABASE`` is unset or empty. On the pool path, any
                error raised by the application's ``initialize_db_pool``
                propagates unchanged.
        """
        get_pool = getattr(self.application, "initialize_db_pool", None)
        if get_pool is not None:
            # Reuse the application-wide pool instead of paying a new
            # connection handshake on every request.
            pool = await get_pool()
            # Like a plain connection context, the pool's connection context
            # commits on clean exit and rolls back on error.
            async with pool.connection() as connection:
                return await self._fetch_all(connection, query, *args, **kwargs)

        connection_string = os.environ.get("DBTESTS_DATABASE")
        if not connection_string:
            raise ValueError("Connection string is not set")
        async with await psycopg.AsyncConnection.connect(
            connection_string
        ) as connection:
            return await self._fetch_all(connection, query, *args, **kwargs)

    @staticmethod
    async def _fetch_all(
        connection: psycopg.AsyncConnection, query: psycopg.abc.Query, *args, **kwargs
    ) -> list:
        """Run *query* on *connection* through a fresh cursor and return all rows."""
        async with connection.cursor() as cursor:
            await cursor.execute(query, *args, **kwargs)
            return await cursor.fetchall()
class CounterHandler(DatabaseHandler):
    """Expose one named counter over HTTP.

    ``GET`` reports the stored value, or 0 when no row exists. ``POST``
    inserts the counter at 1 or bumps the existing value by one. Both
    responses are a bare JSON number.
    """

    # Fetch the current value of a single counter row.
    _SELECT_VALUE_SQL = """
SELECT value
FROM counters
WHERE name = %(name)s
"""

    # Upsert: create the counter at 1, or increment it, returning the new value.
    _INCREMENT_SQL = """
INSERT INTO counters (name, value)
VALUES (%(name)s, 1)
ON CONFLICT (name) DO UPDATE
SET value = counters.value + 1
RETURNING value
"""

    def _write_json_value(self, payload) -> None:
        """Mark the response as JSON and write *payload* encoded with json.dumps."""
        self.set_header("Content-Type", "application/json")
        self.write(json.dumps(payload))

    async def get(self, name: str):
        """Reply with the value of counter *name*, or 0 if it has no row."""
        params = {"name": name}
        rows = await self.query(self._SELECT_VALUE_SQL, params)
        if rows:
            current = rows[0][0]
        else:
            current = 0
        self._write_json_value(current)

    async def post(self, name: str):
        """Increment counter *name* (creating it at 1) and reply with the new value."""
        params = {"name": name}
        rows = await self.query(self._INCREMENT_SQL, params)
        first_row = rows[0]
        self._write_json_value(first_row[0])
class MyApplication(tornado.web.Application):
    """Tornado application serving the counter API.

    It owns a lazily created psycopg connection pool. The pool is built by
    ``initialize_db_pool`` on first use and released by ``cleanup``.
    """

    def __init__(self) -> None:
        # Shared pool; None until initialize_db_pool() first succeeds.
        self._db_pool: psycopg_pool.AsyncConnectionPool | None = None
        # Serialises pool creation. Without it, coroutines racing through the
        # first request could each build a pool, or see one before it is open.
        self._db_pool_lock = asyncio.Lock()
        handlers = [(r"/counters/(?P<name>.*)", CounterHandler)]
        super().__init__(handlers)

    async def initialize_db_pool(self) -> psycopg_pool.AsyncConnectionPool:
        """Return the shared connection pool, creating and opening it on first call.

        Returns:
            The open ``AsyncConnectionPool`` bound to ``DBTESTS_DATABASE``.

        Raises:
            ValueError: If ``DBTESTS_DATABASE`` is unset or empty.
        """
        if self._db_pool is None:
            async with self._db_pool_lock:
                # Re-check: another coroutine may have built the pool while
                # this one was waiting for the lock.
                if self._db_pool is None:
                    connection_string = os.environ.get("DBTESTS_DATABASE")
                    if not connection_string:
                        raise ValueError("Connection string is not set")
                    # Opening an async pool in its constructor is deprecated;
                    # create it closed, then open it explicitly.
                    pool = psycopg_pool.AsyncConnectionPool(
                        connection_string, open=False
                    )
                    await pool.open()
                    # Publish only after open() so callers never get a closed pool.
                    self._db_pool = pool
        return self._db_pool

    async def cleanup(self) -> None:
        """Close the connection pool if one was created; safe to call repeatedly.

        The cached pool is cleared before closing, so a later
        ``initialize_db_pool`` builds a new pool instead of returning a closed one.
        """
        if self._db_pool is not None:
            pool, self._db_pool = self._db_pool, None
            await pool.close()
async def main() -> None:
    """Serve ``MyApplication`` on port 8000 until cancelled, then release resources."""
    app = MyApplication()
    server = app.listen(8000)
    try:
        # Wait forever; this task is only cancelled at shutdown (for example
        # when asyncio.run is interrupted).
        await asyncio.Event().wait()
    finally:
        # Without try/finally this cleanup would be unreachable after the
        # endless wait, and the DB pool would never be closed.
        server.stop()
        await app.cleanup()
# Script entry point: run main() on a fresh event loop via asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())