diff --git a/README.md b/README.md index 7ec3ccd..4f8a858 100755 --- a/README.md +++ b/README.md @@ -44,8 +44,8 @@ Start by creating a Dejavu object with your configurations settings (Dejavu take ... "database": { ... "host": "127.0.0.1", ... "user": "root", -... "passwd": , -... "db": , +... "password": , +... "database": , ... } ... } >>> djv = Dejavu(config) @@ -91,8 +91,8 @@ An example configuration is as follows: ... "database": { ... "host": "127.0.0.1", ... "user": "root", -... "passwd": "Password123", -... "db": "dejavu_db", +... "password": "Password123", +... "database": "dejavu_db", ... }, ... "database_type" : "mysql", ... "fingerprint_limit" : 10 diff --git a/dejavu/config/config.py b/dejavu/config/config.py index 4e28bcc..c2f0edd 100644 --- a/dejavu/config/config.py +++ b/dejavu/config/config.py @@ -8,7 +8,8 @@ OFFSET_SECS = 'offset_seconds' # DATABASE CLASS INSTANCES: DATABASES = { - 'mysql': ("dejavu.database_handler.mysql_database", "MySQLDatabase") + 'mysql': ("dejavu.database_handler.mysql_database", "MySQLDatabase"), + 'postgres': ("dejavu.database_handler.postgres_database", "PostgreSQLDatabase") } # TABLE SONGS diff --git a/dejavu/database.py b/dejavu/database.py index bc7154c..91c56ca 100755 --- a/dejavu/database.py +++ b/dejavu/database.py @@ -153,10 +153,10 @@ class Database(object, metaclass=abc.ABCMeta): def get_database(database_type="mysql"): - path, db_class_name = DATABASES[database_type] try: + path, db_class_name = DATABASES[database_type] db_module = importlib.import_module(path) db_class = getattr(db_module, db_class_name) return db_class - except ImportError: + except (ImportError, KeyError): raise TypeError("Unsupported database type supplied.") diff --git a/dejavu/database_handler/postgres_database.py b/dejavu/database_handler/postgres_database.py new file mode 100755 index 0000000..c73d01a --- /dev/null +++ b/dejavu/database_handler/postgres_database.py @@ -0,0 +1,238 @@ +import queue + +import psycopg2 +from psycopg2.extras import DictCursor + +import dejavu.database_handler.postgres_queries as queries +from dejavu.database import Database + + +class PostgreSQLDatabase(Database): + type = "postgres" + + def __init__(self, **options): + super().__init__() + self.cursor = cursor_factory(**options) + self._options = options + + def after_fork(self): + # Clear the cursor cache, we don't want any stale connections from + # the previous process. + Cursor.clear_cache() + + def setup(self): + """ + Creates any non-existing tables required for dejavu to function. + + This also removes all songs that have been added but have no + fingerprints associated with them. + """ + with self.cursor() as cur: + cur.execute(queries.CREATE_SONGS_TABLE) + cur.execute(queries.CREATE_FINGERPRINTS_TABLE) + cur.execute(queries.DELETE_UNFINGERPRINTED) + + def empty(self): + """ + Drops tables created by dejavu and then creates them again + by calling `SQLDatabase.setup`. + + .. warning: + This will result in a loss of data + """ + with self.cursor() as cur: + cur.execute(queries.DROP_FINGERPRINTS) + cur.execute(queries.DROP_SONGS) + + self.setup() + + def delete_unfingerprinted_songs(self): + """ + Removes all songs that have no fingerprints associated with them. + """ + with self.cursor() as cur: + cur.execute(queries.DELETE_UNFINGERPRINTED) + + def get_num_songs(self): + """ + Returns number of songs the database has fingerprinted. + """ + with self.cursor() as cur: + cur.execute(queries.SELECT_UNIQUE_SONG_IDS) + count = cur.fetchone()[0] if cur.rowcount != 0 else 0 + + return count + + def get_num_fingerprints(self): + """ + Returns number of fingerprints the database has fingerprinted. + """ + with self.cursor() as cur: + cur.execute(queries.SELECT_NUM_FINGERPRINTS) + count = cur.fetchone()[0] if cur.rowcount != 0 else 0 + cur.close() + + return count + + def set_song_fingerprinted(self, sid): + """ + Set the fingerprinted flag to TRUE (1) once a song has been completely + fingerprinted in the database. + """ + with self.cursor() as cur: + cur.execute(queries.UPDATE_SONG_FINGERPRINTED, (sid,)) + + def get_songs(self): + """ + Return songs that have the fingerprinted flag set TRUE (1). + """ + with self.cursor(dictionary=True) as cur: + cur.execute(queries.SELECT_SONGS) + for row in cur: + yield row + + def get_song_by_id(self, sid): + """ + Returns song by its ID. + """ + with self.cursor(dictionary=True) as cur: + cur.execute(queries.SELECT_SONG, (sid,)) + return cur.fetchone() + + def insert(self, hash, sid, offset): + """ + Insert a (sha1, song_id, offset) row into database. + """ + with self.cursor() as cur: + cur.execute(queries.INSERT_FINGERPRINT, (hash, sid, offset)) + + def insert_song(self, song_name, file_hash): + """ + Inserts song in the database and returns the ID of the inserted record. + """ + with self.cursor() as cur: + cur.execute(queries.INSERT_SONG, (song_name, file_hash)) + return cur.fetchone()[0] + + def query(self, hash): + """ + Return all tuples associated with hash. + + If hash is None, returns all entries in the + database (be careful with that one!). + """ + if hash: + with self.cursor() as cur: + cur.execute(queries.SELECT, (hash,)) + for sid, offset in cur: + yield (sid, offset) + else: # select all if no key + with self.cursor() as cur: + cur.execute(queries.SELECT_ALL) + for sid, offset in cur: + yield (sid, offset) + + def get_iterable_kv_pairs(self): + """ + Returns all tuples in database. + """ + return self.query(None) + + def insert_hashes(self, sid, hashes, batch=1000): + """ + Insert series of hash => song_id, offset + values into the database. + """ + values = [(sid, hash, int(offset)) for hash, offset in hashes] + + with self.cursor() as cur: + for index in range(0, len(hashes), batch): + cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch]) + + def return_matches(self, hashes, batch=1000): + """ + Return the (song_id, offset_diff) tuples associated with + a list of (sha1, sample_offset) values. + """ + # Create a dictionary of hash => offset pairs for later lookups + mapper = {} + for hash, offset in hashes: + mapper[hash.upper()] = offset + + # Get an iterable of all the hashes we need + values = list(mapper.keys()) + + with self.cursor() as cur: + for index in range(0, len(values), batch): + # Create our IN part of the query + query = queries.SELECT_MULTIPLE + query = query % ', '.join(["decode(%s, 'hex')"] * len(values[index: index + batch])) + + cur.execute(query, values[index: index + batch]) + + for hash, sid, offset in cur: + # (sid, db_offset - song_sampled_offset) + yield (sid, offset - mapper[hash.upper()]) + + def __getstate__(self): + return self._options, + + def __setstate__(self, state): + self._options, = state + self.cursor = cursor_factory(**self._options) + + +def cursor_factory(**factory_options): + def cursor(**options): + options.update(factory_options) + return Cursor(**options) + return cursor + + +class Cursor(object): + """ + Establishes a connection to the database and returns an open cursor. + # Use as context manager + with Cursor() as cur: + cur.execute(query) + ... + """ + def __init__(self, dictionary=False, **options): + super().__init__() + + self._cache = queue.Queue(maxsize=5) + + try: + conn = self._cache.get_nowait() + # Ping the connection before using it from the cache. + conn.ping(True) + except queue.Empty: + conn = psycopg2.connect(**options) + + self.conn = conn + self.dictionary = dictionary + + @classmethod + def clear_cache(cls): + cls._cache = queue.Queue(maxsize=5) + + def __enter__(self): + if self.dictionary: + self.cursor = self.conn.cursor(cursor_factory=DictCursor) + else: + self.cursor = self.conn.cursor() + return self.cursor + + def __exit__(self, extype, exvalue, traceback): + # if we had a PostgreSQL related error we try to rollback the cursor. + if extype is psycopg2.DatabaseError: + self.cursor.rollback() + + self.cursor.close() + self.conn.commit() + + # Put it back on the queue + try: + self._cache.put_nowait(self.conn) + except queue.Full: + self.conn.close() diff --git a/dejavu/database_handler/postgres_queries.py b/dejavu/database_handler/postgres_queries.py new file mode 100644 index 0000000..df64332 --- /dev/null +++ b/dejavu/database_handler/postgres_queries.py @@ -0,0 +1,104 @@ +from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, + FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, + FIELD_SONGNAME, FINGERPRINTS_TABLENAME, + SONGS_TABLENAME) + +# creates +CREATE_SONGS_TABLE = f""" + CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" ( + "{FIELD_SONG_ID}" SERIAL + , "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL + , "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0 + , "{FIELD_FILE_SHA1}" BYTEA + , "date_created" TIMESTAMP NOT NULL DEFAULT now() + , "date_modified" TIMESTAMP NOT NULL DEFAULT now() + , CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}") + , CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}") + ); +""" + +CREATE_FINGERPRINTS_TABLE = f""" + CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" ( + "{FIELD_HASH}" BYTEA NOT NULL + , "{FIELD_SONG_ID}" INT NOT NULL + , "{FIELD_OFFSET}" INT NOT NULL + , "date_created" TIMESTAMP NOT NULL DEFAULT now() + , "date_modified" TIMESTAMP NOT NULL DEFAULT now() + , CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}") + , CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}") + REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}"); +""" + +CREATE_FINGERPRINTS_TABLE_INDEX = f""" + CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}"); +""" + +# inserts (ignores duplicates) +INSERT_FINGERPRINT = f""" + INSERT INTO "{FINGERPRINTS_TABLENAME}" ( + "{FIELD_SONG_ID}" + , "{FIELD_HASH}" + , "{FIELD_OFFSET}") + VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING; +""" + +INSERT_SONG = f""" + INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}") + VALUES (%s, decode(%s, 'hex')) + RETURNING "{FIELD_SONG_ID}"; +""" + +# selects +SELECT = f""" + SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" + FROM "{FINGERPRINTS_TABLENAME}" + WHERE "{FIELD_HASH}" = decode(%s, 'hex'); +""" + +SELECT_MULTIPLE = f""" + SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}" + FROM "{FINGERPRINTS_TABLENAME}" + WHERE "{FIELD_HASH}" IN (%s); +""" + +SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";' + +SELECT_SONG = f""" + SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" + FROM "{SONGS_TABLENAME}" + WHERE "{FIELD_SONG_ID}" = %s; +""" + +SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";' + +SELECT_UNIQUE_SONG_IDS = f""" + SELECT COUNT("{FIELD_SONG_ID}") AS n + FROM "{SONGS_TABLENAME}" + WHERE "{FIELD_FINGERPRINTED}" = 1; +""" + +SELECT_SONGS = f""" + SELECT + "{FIELD_SONG_ID}" + , "{FIELD_SONGNAME}" + , upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" + FROM "{SONGS_TABLENAME}" + WHERE "{FIELD_FINGERPRINTED}" = 1; +""" + +# drops +DROP_FINGERPRINTS = f'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";' +DROP_SONGS = f'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";' + +# update +UPDATE_SONG_FINGERPRINTED = f""" + UPDATE "{SONGS_TABLENAME}" SET "{FIELD_FINGERPRINTED}" = 1, "date_modified" = now() WHERE "{FIELD_SONG_ID}" = %s; +""" + +# delete +DELETE_UNFINGERPRINTED = f""" + DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0; +"""