Add import modes with progress reporting

This commit is contained in:
Correl Roush 2022-02-10 15:55:40 -05:00
parent f326dd3ecd
commit eee0edf2da
3 changed files with 54 additions and 14 deletions

View file

@ -73,11 +73,30 @@ def import_group():
@import_group.command("cards")
@click.argument("filename", type=click.Path(dir_okay=False))
@click.option(
"--mode",
type=click.Choice([mode.value for mode in tutor.csvimport.ImportMode]),
default=tutor.csvimport.ImportMode.add.value,
)
@click.pass_context
def import_cards(ctx, filename):
tornado.ioloop.IOLoop.current().run_sync(
lambda: tutor.csvimport.load(ctx.obj, filename)
)
def import_cards(ctx, filename, mode):
async def _import_cards():
# count total lines
with open(filename, "r") as csvfile:
linecount = sum(1 for _ in csvfile)
last_line = 0
with click.progressbar(
length=linecount,
label=f"Importing cards ({mode})",
) as bar:
async for card, line in tutor.csvimport.load(
ctx.obj, filename, mode=tutor.csvimport.ImportMode[mode]
):
bar.update(line - last_line)
last_line = line
bar.update(linecount - last_line)
tornado.ioloop.IOLoop.current().run_sync(_import_cards)
@import_group.command("deck")

View file

@ -1,4 +1,5 @@
import csv
import enum
import logging
import typing
@ -8,7 +9,14 @@ import tutor.database
import tutor.models
async def load(settings: dict, filename: str) -> typing.List[tutor.models.Card]:
class ImportMode(enum.Enum):
add = "add"
replace = "replace"
async def load(
settings: dict, filename: str, mode: ImportMode = ImportMode.add
) -> typing.Iterable[typing.Tuple[tutor.models.Card, int]]:
"""Load cards from a CSV file.
Currently supports the following formats:
@ -17,7 +25,9 @@ async def load(settings: dict, filename: str) -> typing.List[tutor.models.Card]:
"""
cards = []
async with aiosqlite.connect(settings["database"]) as db:
with open(filename) as csvfile:
if mode == ImportMode.replace:
await tutor.database.clear_copies(db)
with open(filename, "r") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
is_foil = "Foil" in row and row["Foil"].lower() == "foil"
@ -35,13 +45,15 @@ async def load(settings: dict, filename: str) -> typing.List[tutor.models.Card]:
foil=is_foil or None,
)
if not found:
logging.warning("Could not find card for row %s", row)
# logging.warning("Could not find card for row %s", row)
continue
elif len(found) > 1:
logging.warning(
"Found %s possibilities for row %s", len(found), row
)
for card in found:
logging.warning(card)
# logging.warning(
# "Found %s possibilities for row %s", len(found), row
# )
# for card in found:
# logging.warning(card)
continue
else:
card = tutor.models.CardCopy(
card=found[0],
@ -50,7 +62,7 @@ async def load(settings: dict, filename: str) -> typing.List[tutor.models.Card]:
)
logging.info((quantity, card))
for i in range(quantity):
cards.append(card)
await tutor.database.store_copy(db, card)
yield card, reader.line_num
await db.commit()
return cards

View file

@ -327,6 +327,15 @@ async def store_copy(db: aiosqlite.Connection, copy: tutor.models.CardCopy) -> N
)
async def clear_copies(
db: aiosqlite.Connection, collection: typing.Optional[str] = None
):
if collection:
await db.execute("DELETE FROM copies WHERE collection = ?", collection)
else:
await db.execute("DELETE FROM copies")
async def store_deck(db: aiosqlite.Connection, name: str) -> None:
cursor = await db.execute(
"INSERT INTO `decks` (`name`) VALUES (:name)",