Start tackling mypy errors

This commit is contained in:
Correl Roush 2023-02-13 23:30:42 -05:00
parent fcb12ec630
commit 3a7b37e4c7
7 changed files with 821 additions and 523 deletions

1141
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -16,6 +16,7 @@ tqdm = "^4.64.0"
psycopg = {extras = ["c", "pool"], version = "^3.0"}
tornado-openapi3 = "^1.1.0"
PyYAML = "^6.0"
returns = "^0.19.0"
[tool.poetry.dev-dependencies]
pytest = "*"
@ -25,6 +26,24 @@ mypy = "*"
[tool.poetry.scripts]
tutor = 'tutor.cli:main'
[tool.poetry.group.dev.dependencies]
types-pyyaml = "^6.0.12.6"
[tool.mypy]
plugins = "returns.contrib.mypy.returns_plugin"
[[tool.mypy.overrides]]
module = "humanize"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "parsy"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "tqdm"
ignore_missing_imports = true
[build-system]
requires = ["poetry>=0.12"]
build-backend = "poetry.masonry.api"

View file

@ -16,14 +16,13 @@ class ImportMode(enum.Enum):
async def load(
settings: dict, stream: typing.IO, mode: ImportMode = ImportMode.add
) -> typing.Iterable[typing.Tuple[tutor.models.Card, int]]:
) -> typing.AsyncIterable[typing.Tuple[tutor.models.CardCopy, int]]:
"""Load cards from a CSV file.
Currently supports the following formats:
- Deckbox (set name match can fail)
- MTGStand (uses Scryfall ID)
"""
cards = []
async with await psycopg.AsyncConnection.connect(
settings["database"], autocommit=False
) as conn:

View file

@ -24,7 +24,7 @@ def convert_price(price: typing.Optional[str]) -> typing.Optional[decimal.Decima
async def search(
db: psycopg.Cursor,
db: psycopg.AsyncCursor,
name: typing.Optional[str] = None,
collector_number: typing.Optional[str] = None,
set_code: typing.Optional[str] = None,
@ -39,7 +39,7 @@ async def search(
db.row_factory = psycopg.rows.dict_row
joins = []
constraints = []
params = {}
params: typing.Dict[str, typing.Any] = {}
if name is not None:
constraints.append("cards.name LIKE %(name)s")
params["name"] = name
@ -102,7 +102,7 @@ async def search(
async def advanced_search(
db: psycopg.Cursor,
db: psycopg.AsyncCursor,
search: tutor.search.Search,
limit: int = 10,
offset: int = 0,
@ -112,8 +112,7 @@ async def advanced_search(
db.row_factory = psycopg.rows.dict_row
joins = []
constraints = []
params = {}
sets = []
params: typing.Dict[str, typing.Any] = {}
logger.debug("Performing search for: %s", search)
for i, criterion in enumerate(search.criteria):
@ -244,7 +243,7 @@ async def advanced_search(
async def oracle_id_by_name(
db: psycopg.Cursor, name: str
db: psycopg.AsyncCursor, name: str
) -> typing.Optional[uuid.UUID]:
db.row_factory = psycopg.rows.dict_row
await db.execute(
@ -253,9 +252,11 @@ async def oracle_id_by_name(
row = await db.fetchone()
if row:
return row["oracle_id"]
else:
return None
async def store_card(db: psycopg.Cursor, card: tutor.models.Card) -> None:
async def store_card(db: psycopg.AsyncCursor, card: tutor.models.Card) -> None:
await db.execute(
"""
INSERT INTO tmp_cards
@ -286,7 +287,7 @@ async def store_card(db: psycopg.Cursor, card: tutor.models.Card) -> None:
async def store_price(
db: psycopg.Cursor, date: datetime.date, card: tutor.models.Card
db: psycopg.AsyncCursor, date: datetime.date, card: tutor.models.Card
) -> None:
await db.execute(
"""
@ -306,7 +307,7 @@ async def store_price(
)
async def store_set(db: psycopg.Cursor, set_code: str, name: str) -> None:
async def store_set(db: psycopg.AsyncCursor, set_code: str, name: str) -> None:
await db.execute(
"""
INSERT INTO "sets" ("set_code", "name")
@ -317,7 +318,7 @@ async def store_set(db: psycopg.Cursor, set_code: str, name: str) -> None:
)
async def store_copy(db: psycopg.Cursor, copy: tutor.models.CardCopy) -> None:
async def store_copy(db: psycopg.AsyncCursor, copy: tutor.models.CardCopy) -> None:
await db.execute(
"""
INSERT INTO copies ("scryfall_id", "isFoil", "collection", "condition")
@ -332,24 +333,29 @@ async def store_copy(db: psycopg.Cursor, copy: tutor.models.CardCopy) -> None:
)
async def clear_copies(db: psycopg.Cursor, collection: typing.Optional[str] = None):
async def clear_copies(
db: psycopg.AsyncCursor, collection: typing.Optional[str] = None
):
if collection:
await db.execute("DELETE FROM copies WHERE collection = %s", collection)
else:
await db.execute("DELETE FROM copies")
async def store_deck(db: psycopg.Cursor, name: str) -> None:
async def store_deck(db: psycopg.AsyncCursor, name: str) -> typing.Optional[int]:
await db.execute(
'INSERT INTO "decks" ("name") VALUES (%(name)s) RETURNING "deck_id"',
{"name": name},
)
result = await db.fetchone()
if result:
return result[0]
else:
return None
async def store_deck_card(
db: psycopg.Cursor, deck_id: int, oracle_id: uuid.UUID, quantity: int = 1
db: psycopg.AsyncCursor, deck_id: int, oracle_id: uuid.UUID, quantity: int = 1
) -> None:
await db.execute(
"""
@ -361,7 +367,7 @@ async def store_deck_card(
async def get_decks(
db: psycopg.Cursor, limit: int = 10, offset: int = 0
db: psycopg.AsyncCursor, limit: int = 10, offset: int = 0
) -> typing.List[tutor.models.Deck]:
db.row_factory = psycopg.rows.dict_row
await db.execute(
@ -443,7 +449,7 @@ async def get_decks(
async def get_deck(
db: psycopg.Cursor, deck_id: int
db: psycopg.AsyncCursor, deck_id: int
) -> typing.Optional[tutor.models.Deck]:
db.row_factory = psycopg.rows.dict_row
await db.execute(
@ -519,9 +525,11 @@ async def get_deck(
if card and card.get("oracle_id")
],
)
else:
return None
async def store_var(db: psycopg.Cursor, key: str, value: str) -> None:
async def store_var(db: psycopg.AsyncCursor, key: str, value: str) -> None:
await db.execute(
"""
INSERT INTO "vars" ("key", "value")
@ -533,7 +541,7 @@ async def store_var(db: psycopg.Cursor, key: str, value: str) -> None:
)
async def collection_stats(db: psycopg.Cursor) -> dict:
async def collection_stats(db: psycopg.AsyncCursor) -> typing.Optional[dict]:
db.row_factory = psycopg.rows.dict_row
await db.execute(
"""

View file

@ -15,10 +15,11 @@ class Color(enum.IntEnum):
Red = 5
def __str__(self) -> str:
return dict(zip(Color, "CWUBGR")).get(self.value).replace("C", "")
colors: typing.Dict[int, str] = dict(zip(Color, "CWUBGR"))
return colors.get(self.value, "").replace("C", "")
@staticmethod
def to_string(colors: typing.List["Color"]) -> str:
def to_string(colors: typing.Iterable["Color"]) -> str:
return "".join(map(str, sorted(colors)))
@staticmethod
@ -84,7 +85,7 @@ class Card:
set_code: str
set_name: str
collector_number: str
rarity: str
rarity: Rarity
color_identity: typing.List[Color]
cmc: decimal.Decimal
type_line: str

View file

@ -1,20 +1,23 @@
import datetime
import decimal
import uuid
from returns.maybe import Maybe
import tutor.models
def to_card(data: dict) -> tutor.models.Card:
def to_card(data: dict) -> Maybe[tutor.models.Card]:
prices = {
k: decimal.Decimal(v) if v else None for k, v in data.get("prices", {}).items()
}
return tutor.models.Card(
return Maybe.do(
tutor.models.Card(
scryfall_id=data["id"],
oracle_id=data.get("oracle_id"),
oracle_id=oracle_id,
name=data["name"],
set_code=data["set"].upper() if "set" in data else None,
set_name=data.get("set_name"),
collector_number=data.get("collector_number"),
set_code=set_code,
set_name=set_name,
collector_number=collector_number,
rarity=tutor.models.Rarity.from_string(data.get("rarity", "n/a")),
color_identity=tutor.models.Color.from_string(
"".join(data.get("color_identity", []))
@ -22,21 +25,19 @@ def to_card(data: dict) -> tutor.models.Card:
cmc=decimal.Decimal(data.get("cmc", "0")),
mana_cost=data.get("mana_cost"),
type_line=data.get("type_line", ""),
release_date=datetime.date.fromisoformat(data["released_at"])
if "released_at" in data
else None,
release_date=release_date,
games={
game for game in tutor.models.Game if game.value in data.get("games", [])
game
for game in tutor.models.Game
if game.value in data.get("games", [])
},
legalities={
game_format: {l.value: l for l in tutor.models.Legality}[legality]
for game_format, legality in data.get("legalities", {}).items()
},
edhrec_rank=(
int(data.get("edhrec_rank"))
if data.get("edhrec_rank") is not None
else None
),
edhrec_rank=Maybe.from_optional(data.get("edhrec_rank"))
.map(int)
.value_or(None),
oracle_text=data.get("oracle_text"),
price_usd=prices.get("usd"),
price_usd_foil=prices.get("usd_foil"),
@ -44,3 +45,13 @@ def to_card(data: dict) -> tutor.models.Card:
price_eur_foil=prices.get("eur_foil"),
price_tix=prices.get("tix"),
)
for oracle_id in Maybe.from_optional(data.get("oracle_id")).map(uuid.UUID)
for set_code in Maybe.from_optional(data.get("set")).map(str.upper)
for set_name in Maybe.from_optional(data.get("set_name")).map(str)
for collector_number in Maybe.from_optional(data.get("collector_number")).map(
str
)
for release_date in Maybe.from_optional(data.get("released_at")).map(
datetime.date.fromisoformat
)
)

View file

@ -92,7 +92,7 @@ class JSONEncoder(json.JSONEncoder):
},
}
def _card_copy(self, copy: tutor.models.Card) -> dict:
def _card_copy(self, copy: tutor.models.CardCopy) -> dict:
return {
"card": self._card(copy.card),
"foil": copy.foil,
@ -133,16 +133,6 @@ class OpenAPIRequestHandler(tornado_openapi3.handler.OpenAPIRequestHandler):
class RequestHandler(tornado.web.RequestHandler):
def set_links(self, **links) -> None:
self.set_header(
"Link",
", ".join(
[f'<{self.url(url)}>; rel="{rel}"' for rel, url in links.items()]
),
)
class SearchHandler(RequestHandler):
def url(self, url: str) -> str:
scheme_override = self.application.settings["scheme"]
if not scheme_override:
@ -158,13 +148,30 @@ class SearchHandler(RequestHandler):
)
)
def set_links(self, **links) -> None:
self.set_header(
"Link",
", ".join(
[f'<{self.url(url)}>; rel="{rel}"' for rel, url in links.items()]
),
)
@property
def pool(self) -> psycopg_pool.AsyncConnectionPool:
pool = getattr(self.application, "pool", None)
if not pool:
raise RuntimeError("Database pool not initialized")
return pool
class SearchHandler(RequestHandler):
async def get(self) -> None:
async with self.application.pool.connection() as conn:
async with self.pool.connection() as conn:
async with conn.cursor() as cursor:
query = self.get_argument("q", "")
in_collection = self.get_argument("in_collection", None)
page = max(1, int(self.get_argument("page", 1)))
limit = int(self.get_argument("limit", 10))
page = max(1, int(self.get_argument("page", "1")))
limit = int(self.get_argument("limit", "10"))
sort_by = self.get_argument("sort_by", "rarity")
search = tutor.search.search.parse(query)
copies = await tutor.database.advanced_search(
@ -192,7 +199,7 @@ class SearchHandler(RequestHandler):
class CollectionHandler(RequestHandler):
async def get(self) -> None:
async with self.application.pool.connection() as conn:
async with self.pool.connection() as conn:
async with conn.cursor() as cursor:
self.write(
json.dumps(
@ -203,9 +210,9 @@ class CollectionHandler(RequestHandler):
class DecksHandler(RequestHandler):
async def get(self) -> None:
page = max(1, int(self.get_argument("page", 1)))
limit = int(self.get_argument("limit", 10))
async with self.application.pool.connection() as conn:
page = max(1, int(self.get_argument("page", "1")))
limit = int(self.get_argument("limit", "10"))
async with self.pool.connection() as conn:
async with conn.cursor() as cursor:
decks = await tutor.database.get_decks(
cursor, limit=limit + 1, offset=limit * (page - 1)
@ -229,7 +236,7 @@ class DeckHandler(RequestHandler):
async def get(self, deck_id) -> None:
self.set_header("Content-Type", "application/json")
self.set_header("Access-Control-Allow-Origin", "*")
async with self.application.pool.connection() as conn:
async with self.pool.connection() as conn:
async with conn.cursor() as cursor:
deck = await tutor.database.get_deck(cursor, deck_id)
if not deck:
@ -253,7 +260,7 @@ class TemplateHandler(RequestHandler):
async def get(self) -> None:
self.set_header("Content-Type", self.content_type)
return self.render(self.path, **self.vars)
self.render(self.path, **self.vars)
class StaticFileHandler(tornado.web.StaticFileHandler):