diff --git a/tutor/cli.py b/tutor/cli.py index aa1de7c..e3c557f 100644 --- a/tutor/cli.py +++ b/tutor/cli.py @@ -73,11 +73,30 @@ def import_group(): @import_group.command("cards") @click.argument("filename", type=click.Path(dir_okay=False)) +@click.option( + "--mode", + type=click.Choice([mode.value for mode in tutor.csvimport.ImportMode]), + default=tutor.csvimport.ImportMode.add.value, +) @click.pass_context -def import_cards(ctx, filename): - tornado.ioloop.IOLoop.current().run_sync( - lambda: tutor.csvimport.load(ctx.obj, filename) - ) +def import_cards(ctx, filename, mode): + async def _import_cards(): + # count total lines + with open(filename, "r") as csvfile: + linecount = sum(1 for _ in csvfile) + last_line = 0 + with click.progressbar( + length=linecount, + label=f"Importing cards ({mode})", + ) as bar: + async for card, line in tutor.csvimport.load( + ctx.obj, filename, mode=tutor.csvimport.ImportMode[mode] + ): + bar.update(line - last_line) + last_line = line + bar.update(linecount - last_line) + + tornado.ioloop.IOLoop.current().run_sync(_import_cards) @import_group.command("deck") diff --git a/tutor/csvimport.py b/tutor/csvimport.py index 85180a9..f364805 100644 --- a/tutor/csvimport.py +++ b/tutor/csvimport.py @@ -1,4 +1,5 @@ import csv +import enum import logging import typing @@ -8,7 +9,14 @@ import tutor.database import tutor.models -async def load(settings: dict, filename: str) -> typing.List[tutor.models.Card]: +class ImportMode(enum.Enum): + add = "add" + replace = "replace" + + +async def load( + settings: dict, filename: str, mode: ImportMode = ImportMode.add +) -> typing.Iterable[typing.Tuple[tutor.models.Card, int]]: """Load cards from a CSV file. Currently supports the following formats: @@ -17,7 +25,9 @@ async def load(settings: dict, filename: str) -> typing.List[tutor.models.Card]: """ cards = [] async with aiosqlite.connect(settings["database"]) as db: - with open(filename) as csvfile: + if mode == ImportMode.replace: + await tutor.database.clear_copies(db) + with open(filename, "r") as csvfile: reader = csv.DictReader(csvfile) for row in reader: is_foil = "Foil" in row and row["Foil"].lower() == "foil" @@ -35,13 +45,15 @@ async def load(settings: dict, filename: str) -> typing.List[tutor.models.Card]: foil=is_foil or None, ) if not found: - logging.warning("Could not find card for row %s", row) + # logging.warning("Could not find card for row %s", row) + continue elif len(found) > 1: - logging.warning( - "Found %s possibilities for row %s", len(found), row - ) - for card in found: - logging.warning(card) + # logging.warning( + # "Found %s possibilities for row %s", len(found), row + # ) + # for card in found: + # logging.warning(card) + continue else: card = tutor.models.CardCopy( card=found[0], @@ -50,7 +62,7 @@ async def load(settings: dict, filename: str) -> typing.List[tutor.models.Card]: ) logging.info((quantity, card)) for i in range(quantity): - cards.append(card) await tutor.database.store_copy(db, card) + yield card, reader.line_num + await db.commit() - return cards diff --git a/tutor/database.py b/tutor/database.py index 7abf895..f62aedd 100644 --- a/tutor/database.py +++ b/tutor/database.py @@ -327,6 +327,15 @@ async def store_copy(db: aiosqlite.Connection, copy: tutor.models.CardCopy) -> N ) +async def clear_copies( + db: aiosqlite.Connection, collection: typing.Optional[str] = None +): + if collection: + await db.execute("DELETE FROM copies WHERE collection = ?", collection) + else: + await db.execute("DELETE FROM copies") + + async def store_deck(db: aiosqlite.Connection, name: str) -> None: cursor = await db.execute( "INSERT INTO `decks` (`name`) VALUES (:name)",