diff --git a/.gitignore b/.gitignore index 8f8bc7e..01bd861 100755 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,3 @@ *.pyc -wav -mp3 -*.wav -*.mp3 .DS_Store *.cnf diff --git a/dejavu.cnf.SAMPLE b/dejavu.cnf.SAMPLE index 9a89e25..e161192 100755 --- a/dejavu.cnf.SAMPLE +++ b/dejavu.cnf.SAMPLE @@ -4,5 +4,6 @@ "user": "root", "password": "rootpass", "database": "dejavu" - } + }, + "database_type": "mysql" } diff --git a/dejavu.py b/dejavu.py index 387021e..84ea42a 100755 --- a/dejavu.py +++ b/dejavu.py @@ -1,16 +1,12 @@ -#!/usr/bin/python - import argparse import json -import os +from os.path import isdir import sys -import warnings from argparse import RawTextHelpFormatter from dejavu import Dejavu -from dejavu.recognize import FileRecognizer, MicrophoneRecognizer - -warnings.filterwarnings("ignore") +from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer +from dejavu.logic.recognizer.file_recognizer import FileRecognizer DEFAULT_CONFIG_FILE = "dejavu.cnf.SAMPLE" @@ -58,7 +54,6 @@ if __name__ == '__main__': config_file = args.config if config_file is None: config_file = DEFAULT_CONFIG_FILE - # print ("Using default config file: {config_file}") djv = init(config_file) if args.fingerprint: @@ -71,22 +66,19 @@ if __name__ == '__main__': elif len(args.fingerprint) == 1: filepath = args.fingerprint[0] - if os.path.isdir(filepath): + if isdir(filepath): print("Please specify an extension if you'd like to fingerprint a directory!") sys.exit(1) djv.fingerprint_file(filepath) elif args.recognize: # Recognize audio source - song = None + songs = None source = args.recognize[0] opt_arg = args.recognize[1] if source in ('mic', 'microphone'): - song = djv.recognize(MicrophoneRecognizer, seconds=opt_arg) + songs = djv.recognize(MicrophoneRecognizer, seconds=opt_arg) elif source == 'file': - song = djv.recognize(FileRecognizer, opt_arg) - decoded_song = repr(song).decode('string_escape') - print(decoded_song) - - sys.exit(0) + songs = djv.recognize(FileRecognizer, opt_arg) + print(songs) diff --git a/dejavu/__init__.py b/dejavu/__init__.py index bf89d3a..ade6232 100755 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -3,13 +3,13 @@ import os import sys import traceback -import dejavu.decoder as decoder -from dejavu.config.config import (CONFIDENCE, DEFAULT_FS, - DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE, - FIELD_FILE_SHA1, OFFSET, OFFSET_SECS, - SONG_ID, SONG_NAME, TOPN) -from dejavu.database import get_database -from dejavu.fingerprint import fingerprint +import dejavu.logic.decoder as decoder +from dejavu.base_classes.base_database import get_database +from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS, + DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE, + FIELD_FILE_SHA1, OFFSET, OFFSET_SECS, + SONG_ID, SONG_NAME, TOPN) +from dejavu.logic.fingerprint import fingerprint class Dejavu: @@ -71,7 +71,7 @@ class Dejavu: continue except StopIteration: break - except: + except Exception: print("Failed fingerprinting") # Print traceback because we can't reraise it here traceback.print_exc(file=sys.stdout) @@ -119,6 +119,7 @@ class Dejavu: diff_counter = {} largest_count = 0 + # TODO: review logic to get topn results. for tup in matches: sid, diff = tup if diff not in diff_counter: diff --git a/dejavu/base_classes/__init__.py b/dejavu/base_classes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/database.py b/dejavu/base_classes/base_database.py similarity index 89% rename from dejavu/database.py rename to dejavu/base_classes/base_database.py index 91c56ca..77b118b 100755 --- a/dejavu/database.py +++ b/dejavu/base_classes/base_database.py @@ -1,9 +1,11 @@ import abc import importlib -from dejavu.config.config import DATABASES +from typing import Dict + +from dejavu.config.settings import DATABASES -class Database(object, metaclass=abc.ABCMeta): +class BaseDatabase(object, metaclass=abc.ABCMeta): # Name of your Database subclass, this is used in configuration # to refer to your class type = None @@ -70,17 +72,16 @@ class Database(object, metaclass=abc.ABCMeta): pass @abc.abstractmethod - def get_songs(self): + def get_songs(self) -> Dict[str, str]: """ - Returns all fully fingerprinted songs in the database. + Returns all fully fingerprinted songs in the database. Result must be a Dictionary. """ pass @abc.abstractmethod - def get_song_by_id(self, sid): + def get_song_by_id(self, sid) -> Dict[str, str]: """ - Return a song by its identifier - + Return a song by its identifier. Result must be a Dictionary. sid: Song identifier """ pass @@ -124,7 +125,7 @@ class Database(object, metaclass=abc.ABCMeta): pass @abc.abstractmethod - def insert_hashes(self, sid, hashes): + def insert_hashes(self, sid, hashes, batch=1000): """ Insert a multitude of fingerprints. @@ -133,7 +134,6 @@ class Database(object, metaclass=abc.ABCMeta): - hash: Part of a sha1 hash, in hexadecimal format - offset: Offset this hash was created from/at. """ - pass @abc.abstractmethod def return_matches(self, hashes): diff --git a/dejavu/base_classes/base_recognizer.py b/dejavu/base_classes/base_recognizer.py new file mode 100644 index 0000000..cd07c01 --- /dev/null +++ b/dejavu/base_classes/base_recognizer.py @@ -0,0 +1,19 @@ +import abc + +from dejavu.config.settings import DEFAULT_FS + + +class BaseRecognizer(object, metaclass=abc.ABCMeta): + def __init__(self, dejavu): + self.dejavu = dejavu + self.Fs = DEFAULT_FS + + def _recognize(self, *data): + matches = [] + for d in data: + matches.extend(self.dejavu.find_matches(d, Fs=self.Fs)) + return self.dejavu.align_matches(matches) + + @abc.abstractmethod + def recognize(self): + pass # base class does nothing diff --git a/dejavu/base_classes/common_database.py b/dejavu/base_classes/common_database.py new file mode 100644 index 0000000..699385e --- /dev/null +++ b/dejavu/base_classes/common_database.py @@ -0,0 +1,194 @@ +import abc + +from dejavu.base_classes.base_database import BaseDatabase + + +class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): + # Since several methods across different databases are actually just the same + # I've built this class with the idea to reuse that logic instead of copy pasting + # over and over the same code. + + def __init__(self): + super().__init__() + + def before_fork(self): + """ + Called before the database instance is given to the new process + """ + pass + + def after_fork(self): + """ + Called after the database instance has been given to the new process + + This will be called in the new process. + """ + pass + + def setup(self): + """ + Called on creation or shortly afterwards. + """ + with self.cursor() as cur: + cur.execute(self.CREATE_SONGS_TABLE) + cur.execute(self.CREATE_FINGERPRINTS_TABLE) + cur.execute(self.DELETE_UNFINGERPRINTED) + + def empty(self): + """ + Called when the database should be cleared of all data. + """ + with self.cursor() as cur: + cur.execute(self.DROP_FINGERPRINTS) + cur.execute(self.DROP_SONGS) + + self.setup() + + def delete_unfingerprinted_songs(self): + """ + Called to remove any song entries that do not have any fingerprints + associated with them. + """ + with self.cursor() as cur: + cur.execute(self.DELETE_UNFINGERPRINTED) + + def get_num_songs(self): + """ + Returns the amount of songs in the database. + """ + with self.cursor() as cur: + cur.execute(self.SELECT_UNIQUE_SONG_IDS) + count = cur.fetchone()[0] if cur.rowcount != 0 else 0 + + return count + + def get_num_fingerprints(self): + """ + Returns the number of fingerprints in the database. + """ + with self.cursor() as cur: + cur.execute(self.SELECT_NUM_FINGERPRINTS) + count = cur.fetchone()[0] if cur.rowcount != 0 else 0 + cur.close() + + return count + + def set_song_fingerprinted(self, sid): + """ + Sets a specific song as having all fingerprints in the database. + + sid: Song identifier + """ + with self.cursor() as cur: + cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,)) + + def get_songs(self): + """ + Returns all fully fingerprinted songs in the database. Result must be a Dictionary. + """ + with self.cursor(dictionary=True) as cur: + cur.execute(self.SELECT_SONGS) + for row in cur: + yield row + + def get_song_by_id(self, sid): + """ + Return a song by its identifier. Result must be a Dictionary. + sid: Song identifier + """ + with self.cursor(dictionary=True) as cur: + cur.execute(self.SELECT_SONG, (sid,)) + return cur.fetchone() + + def insert(self, hash, sid, offset): + """ + Inserts a single fingerprint into the database. + + hash: Part of a sha1 hash, in hexadecimal format + sid: Song identifier this fingerprint is off + offset: The offset this hash is from + """ + with self.cursor() as cur: + cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset)) + + @abc.abstractmethod + def insert_song(self, song_name): + """ + Inserts a song name into the database, returns the new + identifier of the song. + + song_name: The name of the song. + """ + pass + + def query(self, fingerprint): + """ + Returns all matching fingerprint entries associated with + the given hash as parameter. + + hash: Part of a sha1 hash, in hexadecimal format + """ + if fingerprint: + with self.cursor() as cur: + cur.execute(self.SELECT, (fingerprint,)) + for sid, offset in cur: + yield (sid, offset) + else: # select all if no key + with self.cursor() as cur: + cur.execute(self.SELECT_ALL) + for sid, offset in cur: + yield (sid, offset) + + def get_iterable_kv_pairs(self): + """ + Returns all fingerprints in the database. + """ + return self.query(None) + + def insert_hashes(self, sid, hashes, batch=1000): + """ + Insert a multitude of fingerprints. + + sid: Song identifier the fingerprints belong to + hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + """ + values = [(sid, hsh, int(offset)) for hsh, offset in hashes] + + with self.cursor() as cur: + for index in range(0, len(hashes), batch): + cur.executemany(self.INSERT_FINGERPRINT, values[index: index + batch]) + + def return_matches(self, hashes, batch=1000): + """ + Searches the database for pairs of (hash, offset) values. + + hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + + Returns a sequence of (sid, offset_difference) tuples. + + sid: Song identifier + offset_difference: (offset - database_offset) + """ + # Create a dictionary of hash => offset pairs for later lookups + mapper = {} + for hsh, offset in hashes: + mapper[hsh.upper()] = offset + + # Get an iterable of all the hashes we need + values = list(mapper.keys()) + + with self.cursor() as cur: + for index in range(0, len(values), batch): + # Create our IN part of the query + query = self.SELECT_MULTIPLE + query = query % ', '.join([self.IN_MATCH] * len(values[index: index + batch])) + + cur.execute(query, values[index: index + batch]) + + for hsh, sid, offset in cur: + # (sid, db_offset - song_sampled_offset) + yield (sid, offset - mapper[hsh]) diff --git a/dejavu/config/config.py b/dejavu/config/settings.py similarity index 99% rename from dejavu/config/config.py rename to dejavu/config/settings.py index c2f0edd..bc711b7 100644 --- a/dejavu/config/config.py +++ b/dejavu/config/settings.py @@ -72,4 +72,4 @@ PEAK_SORT = True FINGERPRINT_REDUCTION = 20 # Number of results being returned for file recognition -TOPN = 2 \ No newline at end of file +TOPN = 2 diff --git a/dejavu/database_handler/mysql_database.py b/dejavu/database_handler/mysql_database.py index 4fb9bea..e1e4257 100755 --- a/dejavu/database_handler/mysql_database.py +++ b/dejavu/database_handler/mysql_database.py @@ -3,13 +3,114 @@ import queue import mysql.connector from mysql.connector.errors import DatabaseError -import dejavu.database_handler.mysql_queries as queries -from dejavu.database import Database +from dejavu.base_classes.common_database import CommonDatabase +from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, + FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, + FIELD_SONGNAME, FINGERPRINTS_TABLENAME, + SONGS_TABLENAME) -class MySQLDatabase(Database): +class MySQLDatabase(CommonDatabase): type = "mysql" + # CREATES + CREATE_SONGS_TABLE = f""" + CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` ( + `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT + , `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL + , `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0 + , `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL + , `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + , `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + , CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`) + , CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`) + ) ENGINE=INNODB; + """ + + CREATE_FINGERPRINTS_TABLE = f""" + CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` ( + `{FIELD_HASH}` BINARY(10) NOT NULL + , `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL + , `{FIELD_OFFSET}` INT UNSIGNED NOT NULL + , `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + , `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + , INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`) + , CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}` + UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`) + , CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`) + REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE + ) ENGINE=INNODB; + """ + + # INSERTS (IGNORES DUPLICATES) + INSERT_FINGERPRINT = f""" + INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` ( + `{FIELD_SONG_ID}` + , `{FIELD_HASH}` + , `{FIELD_OFFSET}`) + VALUES (%s, UNHEX(%s), %s); + """ + + INSERT_SONG = f""" + INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`) + VALUES (%s, UNHEX(%s)); + """ + + # SELECTS + SELECT = f""" + SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` + FROM `{FINGERPRINTS_TABLENAME}` + WHERE `{FIELD_HASH}` = UNHEX(%s); + """ + + SELECT_MULTIPLE = f""" + SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` + FROM `{FINGERPRINTS_TABLENAME}` + WHERE `{FIELD_HASH}` IN (%s); + """ + + SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;" + + SELECT_SONG = f""" + SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}` + FROM `{SONGS_TABLENAME}` + WHERE `{FIELD_SONG_ID}` = %s; + """ + + SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;" + + SELECT_UNIQUE_SONG_IDS = f""" + SELECT COUNT(`{FIELD_SONG_ID}`) AS n + FROM `{SONGS_TABLENAME}` + WHERE `{FIELD_FINGERPRINTED}` = 1; + """ + + SELECT_SONGS = f""" + SELECT + `{FIELD_SONG_ID}` + , `{FIELD_SONGNAME}` + , HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}` + FROM `{SONGS_TABLENAME}` + WHERE `{FIELD_FINGERPRINTED}` = 1; + """ + + # DROPS + DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;" + DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;" + + # update + UPDATE_SONG_FINGERPRINTED = f""" + UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s; + """ + + # DELETE + DELETE_UNFINGERPRINTED = f""" + DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0; + """ + + # IN + IN_MATCH = f"UNHEX(%s)" + def __init__(self, **options): super().__init__() self.cursor = cursor_factory(**options) @@ -20,160 +121,14 @@ class MySQLDatabase(Database): # the previous process. Cursor.clear_cache() - def setup(self): - """ - Creates any non-existing tables required for dejavu to function. - - This also removes all songs that have been added but have no - fingerprints associated with them. - """ - with self.cursor() as cur: - cur.execute(queries.CREATE_SONGS_TABLE) - cur.execute(queries.CREATE_FINGERPRINTS_TABLE) - cur.execute(queries.DELETE_UNFINGERPRINTED) - - def empty(self): - """ - Drops tables created by dejavu and then creates them again - by calling `SQLDatabase.setup`. - - .. warning: - This will result in a loss of data - """ - with self.cursor() as cur: - cur.execute(queries.DROP_FINGERPRINTS) - cur.execute(queries.DROP_SONGS) - - self.setup() - - def delete_unfingerprinted_songs(self): - """ - Removes all songs that have no fingerprints associated with them. - """ - with self.cursor() as cur: - cur.execute(queries.DELETE_UNFINGERPRINTED) - - def get_num_songs(self): - """ - Returns number of songs the database has fingerprinted. - """ - with self.cursor() as cur: - cur.execute(queries.SELECT_UNIQUE_SONG_IDS) - count = cur.fetchone()[0] if cur.rowcount != 0 else 0 - - return count - - def get_num_fingerprints(self): - """ - Returns number of fingerprints the database has fingerprinted. - """ - with self.cursor() as cur: - cur.execute(queries.SELECT_NUM_FINGERPRINTS) - count = cur.fetchone()[0] if cur.rowcount != 0 else 0 - cur.close() - - return count - - def set_song_fingerprinted(self, sid): - """ - Set the fingerprinted flag to TRUE (1) once a song has been completely - fingerprinted in the database. - """ - with self.cursor() as cur: - cur.execute(queries.UPDATE_SONG_FINGERPRINTED, (sid,)) - - def get_songs(self): - """ - Return songs that have the fingerprinted flag set TRUE (1). - """ - with self.cursor(dictionary=True) as cur: - cur.execute(queries.SELECT_SONGS) - for row in cur: - yield row - - def get_song_by_id(self, sid): - """ - Returns song by its ID. - """ - with self.cursor(dictionary=True) as cur: - cur.execute(queries.SELECT_SONG, (sid,)) - return cur.fetchone() - - def insert(self, hash, sid, offset): - """ - Insert a (sha1, song_id, offset) row into database. - """ - with self.cursor() as cur: - cur.execute(queries.INSERT_FINGERPRINT, (hash, sid, offset)) - def insert_song(self, song_name, file_hash): """ Inserts song in the database and returns the ID of the inserted record. """ with self.cursor() as cur: - cur.execute(queries.INSERT_SONG, (song_name, file_hash)) + cur.execute(self.INSERT_SONG, (song_name, file_hash)) return cur.lastrowid - def query(self, hash): - """ - Return all tuples associated with hash. - - If hash is None, returns all entries in the - database (be careful with that one!). - """ - if hash: - with self.cursor() as cur: - cur.execute(queries.SELECT, (hash,)) - for sid, offset in cur: - yield (sid, offset) - else: # select all if no key - with self.cursor() as cur: - cur.execute(queries.SELECT_ALL) - for sid, offset in cur: - yield (sid, offset) - - def get_iterable_kv_pairs(self): - """ - Returns all tuples in database. - """ - return self.query(None) - - def insert_hashes(self, sid, hashes, batch=1000): - """ - Insert series of hash => song_id, offset - values into the database. - """ - values = [(sid, hash, int(offset)) for hash, offset in hashes] - - with self.cursor() as cur: - for index in range(0, len(hashes), batch): - cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch]) - - def return_matches(self, hashes, batch=1000): - """ - Return the (song_id, offset_diff) tuples associated with - a list of (sha1, sample_offset) values. - """ - # Create a dictionary of hash => offset pairs for later lookups - mapper = {} - for hash, offset in hashes: - mapper[hash.upper()] = offset - - # Get an iterable of all the hashes we need - values = list(mapper.keys()) - - with self.cursor() as cur: - for index in range(0, len(values), batch): - # Create our IN part of the query - query = queries.SELECT_MULTIPLE - query = query % ', '.join(['UNHEX(%s)'] * len(values[index: index + batch])) - - cur.execute(query, values[index: index + batch]) - - for hash, sid, offset in cur: - # (sid, db_offset - song_sampled_offset) - yield (sid, offset - mapper[hash]) - def __getstate__(self): return self._options, diff --git a/dejavu/database_handler/mysql_queries.py b/dejavu/database_handler/mysql_queries.py deleted file mode 100644 index ee27f3e..0000000 --- a/dejavu/database_handler/mysql_queries.py +++ /dev/null @@ -1,134 +0,0 @@ -from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, - FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, - FIELD_SONGNAME, FINGERPRINTS_TABLENAME, - SONGS_TABLENAME) - -""" -Queries: - -1) Find duplicates (shouldn't be any, though): - - select `hash`, `song_id`, `offset`, count(*) cnt - from fingerprints - group by `hash`, `song_id`, `offset` - having cnt > 1 - order by cnt asc; - -2) Get number of hashes by song: - - select song_id, song_name, count(song_id) as num - from fingerprints - natural join songs - group by song_id - order by count(song_id) desc; - -3) get hashes with highest number of collisions - - select - hash, - count(distinct song_id) as n - from fingerprints - group by `hash` - order by n DESC; - -=> 26 different songs with same fingerprint (392 times): - - select songs.song_name, fingerprints.offset - from fingerprints natural join songs - where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73"; -""" - -# creates -CREATE_SONGS_TABLE = f""" - CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` ( - `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT - , `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL - , `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0 - , `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL - , `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP - , `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP - , CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`) - , CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`) - ) ENGINE=INNODB; -""" - -CREATE_FINGERPRINTS_TABLE = f""" - CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` ( - `{FIELD_HASH}` BINARY(10) NOT NULL - , `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL - , `{FIELD_OFFSET}` INT UNSIGNED NOT NULL - , `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP - , `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP - , INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`) - , CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}` - UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`) - , CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`) - REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE -) ENGINE=INNODB; -""" - -# inserts (ignores duplicates) -INSERT_FINGERPRINT = f""" - INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` ( - `{FIELD_SONG_ID}` - , `{FIELD_HASH}` - , `{FIELD_OFFSET}`) - VALUES (%s, UNHEX(%s), %s); -""" - -INSERT_SONG = f""" - INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`) - VALUES (%s, UNHEX(%s)); -""" - -# selects -SELECT = f""" - SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` - FROM `{FINGERPRINTS_TABLENAME}` - WHERE `{FIELD_HASH}` = UNHEX(%s); -""" - -SELECT_MULTIPLE = f""" - SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` - FROM `{FINGERPRINTS_TABLENAME}` - WHERE `{FIELD_HASH}` IN (%s); -""" - -SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;" - -SELECT_SONG = f""" - SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}` - FROM `{SONGS_TABLENAME}` - WHERE `{FIELD_SONG_ID}` = %s; -""" - -SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;" - -SELECT_UNIQUE_SONG_IDS = f""" - SELECT COUNT(`{FIELD_SONG_ID}`) AS n - FROM `{SONGS_TABLENAME}` - WHERE `{FIELD_FINGERPRINTED}` = 1; -""" - -SELECT_SONGS = f""" - SELECT - `{FIELD_SONG_ID}` - , `{FIELD_SONGNAME}` - , HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}` - FROM `{SONGS_TABLENAME}` - WHERE `{FIELD_FINGERPRINTED}` = 1; -""" - -# drops -DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;" -DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;" - -# update -UPDATE_SONG_FINGERPRINTED = f""" - UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s; -""" - -# delete -DELETE_UNFINGERPRINTED = f""" - DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0; -""" diff --git a/dejavu/database_handler/postgres_database.py b/dejavu/database_handler/postgres_database.py index c73d01a..341641b 100755 --- a/dejavu/database_handler/postgres_database.py +++ b/dejavu/database_handler/postgres_database.py @@ -3,13 +3,124 @@ import queue import psycopg2 from psycopg2.extras import DictCursor -import dejavu.database_handler.postgres_queries as queries -from dejavu.database import Database +from dejavu.base_classes.common_database import CommonDatabase +from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, + FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, + FIELD_SONGNAME, FINGERPRINTS_TABLENAME, + SONGS_TABLENAME) -class PostgreSQLDatabase(Database): +class PostgreSQLDatabase(CommonDatabase): type = "postgres" + # CREATES + CREATE_SONGS_TABLE = f""" + CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" ( + "{FIELD_SONG_ID}" SERIAL + , "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL + , "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0 + , "{FIELD_FILE_SHA1}" BYTEA + , "date_created" TIMESTAMP NOT NULL DEFAULT now() + , "date_modified" TIMESTAMP NOT NULL DEFAULT now() + , CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}") + , CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}") + ); + """ + + CREATE_FINGERPRINTS_TABLE = f""" + CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" ( + "{FIELD_HASH}" BYTEA NOT NULL + , "{FIELD_SONG_ID}" INT NOT NULL + , "{FIELD_OFFSET}" INT NOT NULL + , "date_created" TIMESTAMP NOT NULL DEFAULT now() + , "date_modified" TIMESTAMP NOT NULL DEFAULT now() + , CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}") + , CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}") + REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" + USING hash ("{FIELD_HASH}"); + """ + + CREATE_FINGERPRINTS_TABLE_INDEX = f""" + CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" + USING hash ("{FIELD_HASH}"); + """ + + # INSERTS (IGNORES DUPLICATES) + INSERT_FINGERPRINT = f""" + INSERT INTO "{FINGERPRINTS_TABLENAME}" ( + "{FIELD_SONG_ID}" + , "{FIELD_HASH}" + , "{FIELD_OFFSET}") + VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING; + """ + + INSERT_SONG = f""" + INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}") + VALUES (%s, decode(%s, 'hex')) + RETURNING "{FIELD_SONG_ID}"; + """ + + # SELECTS + SELECT = f""" + SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" + FROM "{FINGERPRINTS_TABLENAME}" + WHERE "{FIELD_HASH}" = decode(%s, 'hex'); + """ + + SELECT_MULTIPLE = f""" + SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}" + FROM "{FINGERPRINTS_TABLENAME}" + WHERE "{FIELD_HASH}" IN (%s); + """ + + SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";' + + SELECT_SONG = f""" + SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" + FROM "{SONGS_TABLENAME}" + WHERE "{FIELD_SONG_ID}" = %s; + """ + + SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";' + + SELECT_UNIQUE_SONG_IDS = f""" + SELECT COUNT("{FIELD_SONG_ID}") AS n + FROM "{SONGS_TABLENAME}" + WHERE "{FIELD_FINGERPRINTED}" = 1; + """ + + SELECT_SONGS = f""" + SELECT + "{FIELD_SONG_ID}" + , "{FIELD_SONGNAME}" + , upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" + FROM "{SONGS_TABLENAME}" + WHERE "{FIELD_FINGERPRINTED}" = 1; + """ + + # DROPS + DROP_FINGERPRINTS = F'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";' + DROP_SONGS = F'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";' + + # UPDATE + UPDATE_SONG_FINGERPRINTED = f""" + UPDATE "{SONGS_TABLENAME}" SET + "{FIELD_FINGERPRINTED}" = 1 + , "date_modified" = now() + WHERE "{FIELD_SONG_ID}" = %s; + """ + + # DELETE + DELETE_UNFINGERPRINTED = f""" + DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0; + """ + + # IN + IN_MATCH = f"decode(%s, 'hex')" + def __init__(self, **options): super().__init__() self.cursor = cursor_factory(**options) @@ -20,160 +131,14 @@ class PostgreSQLDatabase(Database): # the previous process. Cursor.clear_cache() - def setup(self): - """ - Creates any non-existing tables required for dejavu to function. - - This also removes all songs that have been added but have no - fingerprints associated with them. - """ - with self.cursor() as cur: - cur.execute(queries.CREATE_SONGS_TABLE) - cur.execute(queries.CREATE_FINGERPRINTS_TABLE) - cur.execute(queries.DELETE_UNFINGERPRINTED) - - def empty(self): - """ - Drops tables created by dejavu and then creates them again - by calling `SQLDatabase.setup`. - - .. warning: - This will result in a loss of data - """ - with self.cursor() as cur: - cur.execute(queries.DROP_FINGERPRINTS) - cur.execute(queries.DROP_SONGS) - - self.setup() - - def delete_unfingerprinted_songs(self): - """ - Removes all songs that have no fingerprints associated with them. - """ - with self.cursor() as cur: - cur.execute(queries.DELETE_UNFINGERPRINTED) - - def get_num_songs(self): - """ - Returns number of songs the database has fingerprinted. - """ - with self.cursor() as cur: - cur.execute(queries.SELECT_UNIQUE_SONG_IDS) - count = cur.fetchone()[0] if cur.rowcount != 0 else 0 - - return count - - def get_num_fingerprints(self): - """ - Returns number of fingerprints the database has fingerprinted. - """ - with self.cursor() as cur: - cur.execute(queries.SELECT_NUM_FINGERPRINTS) - count = cur.fetchone()[0] if cur.rowcount != 0 else 0 - cur.close() - - return count - - def set_song_fingerprinted(self, sid): - """ - Set the fingerprinted flag to TRUE (1) once a song has been completely - fingerprinted in the database. - """ - with self.cursor() as cur: - cur.execute(queries.UPDATE_SONG_FINGERPRINTED, (sid,)) - - def get_songs(self): - """ - Return songs that have the fingerprinted flag set TRUE (1). - """ - with self.cursor(dictionary=True) as cur: - cur.execute(queries.SELECT_SONGS) - for row in cur: - yield row - - def get_song_by_id(self, sid): - """ - Returns song by its ID. - """ - with self.cursor(dictionary=True) as cur: - cur.execute(queries.SELECT_SONG, (sid,)) - return cur.fetchone() - - def insert(self, hash, sid, offset): - """ - Insert a (sha1, song_id, offset) row into database. - """ - with self.cursor() as cur: - cur.execute(queries.INSERT_FINGERPRINT, (hash, sid, offset)) - def insert_song(self, song_name, file_hash): """ Inserts song in the database and returns the ID of the inserted record. """ with self.cursor() as cur: - cur.execute(queries.INSERT_SONG, (song_name, file_hash)) + cur.execute(self.INSERT_SONG, (song_name, file_hash)) return cur.fetchone()[0] - def query(self, hash): - """ - Return all tuples associated with hash. - - If hash is None, returns all entries in the - database (be careful with that one!). - """ - if hash: - with self.cursor() as cur: - cur.execute(queries.SELECT, (hash,)) - for sid, offset in cur: - yield (sid, offset) - else: # select all if no key - with self.cursor() as cur: - cur.execute(queries.SELECT_ALL) - for sid, offset in cur: - yield (sid, offset) - - def get_iterable_kv_pairs(self): - """ - Returns all tuples in database. - """ - return self.query(None) - - def insert_hashes(self, sid, hashes, batch=1000): - """ - Insert series of hash => song_id, offset - values into the database. - """ - values = [(sid, hash, int(offset)) for hash, offset in hashes] - - with self.cursor() as cur: - for index in range(0, len(hashes), batch): - cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch]) - - def return_matches(self, hashes, batch=1000): - """ - Return the (song_id, offset_diff) tuples associated with - a list of (sha1, sample_offset) values. - """ - # Create a dictionary of hash => offset pairs for later lookups - mapper = {} - for hash, offset in hashes: - mapper[hash.upper()] = offset - - # Get an iterable of all the hashes we need - values = list(mapper.keys()) - - with self.cursor() as cur: - for index in range(0, len(values), batch): - # Create our IN part of the query - query = queries.SELECT_MULTIPLE - query = query % ', '.join(["decode(%s, 'hex')"] * len(values[index: index + batch])) - - cur.execute(query, values[index: index + batch]) - - for hash, sid, offset in cur: - # (sid, db_offset - song_sampled_offset) - yield (sid, offset - mapper[hash.upper()]) - def __getstate__(self): return self._options, diff --git a/dejavu/database_handler/postgres_queries.py b/dejavu/database_handler/postgres_queries.py deleted file mode 100644 index df64332..0000000 --- a/dejavu/database_handler/postgres_queries.py +++ /dev/null @@ -1,104 +0,0 @@ -from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, - FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, - FIELD_SONGNAME, FINGERPRINTS_TABLENAME, - SONGS_TABLENAME) - -# creates -CREATE_SONGS_TABLE = f""" - CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" ( - "{FIELD_SONG_ID}" SERIAL - , "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL - , "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0 - , "{FIELD_FILE_SHA1}" BYTEA - , "date_created" TIMESTAMP NOT NULL DEFAULT now() - , "date_modified" TIMESTAMP NOT NULL DEFAULT now() - , CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}") - , CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}") - ); -""" - -CREATE_FINGERPRINTS_TABLE = f""" - CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" ( - "{FIELD_HASH}" BYTEA NOT NULL - , "{FIELD_SONG_ID}" INT NOT NULL - , "{FIELD_OFFSET}" INT NOT NULL - , "date_created" TIMESTAMP NOT NULL DEFAULT now() - , "date_modified" TIMESTAMP NOT NULL DEFAULT now() - , CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}") - , CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}") - REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE - ); - - CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}"); -""" - -CREATE_FINGERPRINTS_TABLE_INDEX = f""" - CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}"); -""" - -# inserts (ignores duplicates) -INSERT_FINGERPRINT = f""" - INSERT INTO "{FINGERPRINTS_TABLENAME}" ( - "{FIELD_SONG_ID}" - , "{FIELD_HASH}" - , "{FIELD_OFFSET}") - VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING; -""" - -INSERT_SONG = f""" - INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}") - VALUES (%s, decode(%s, 'hex')) - RETURNING "{FIELD_SONG_ID}"; -""" - -# selects -SELECT = f""" - SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" - FROM "{FINGERPRINTS_TABLENAME}" - WHERE "{FIELD_HASH}" = decode(%s, 'hex'); -""" - -SELECT_MULTIPLE = f""" - SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}" - FROM "{FINGERPRINTS_TABLENAME}" - WHERE "{FIELD_HASH}" IN (%s); -""" - -SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";' - -SELECT_SONG = f""" - SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" - FROM "{SONGS_TABLENAME}" - WHERE "{FIELD_SONG_ID}" = %s; -""" - -SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";' - -SELECT_UNIQUE_SONG_IDS = f""" - SELECT COUNT("{FIELD_SONG_ID}") AS n - FROM "{SONGS_TABLENAME}" - WHERE "{FIELD_FINGERPRINTED}" = 1; -""" - -SELECT_SONGS = f""" - SELECT - "{FIELD_SONG_ID}" - , "{FIELD_SONGNAME}" - , upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" - FROM "{SONGS_TABLENAME}" - WHERE "{FIELD_FINGERPRINTED}" = 1; -""" - -# drops -DROP_FINGERPRINTS = f'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";' -DROP_SONGS = f'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";' - -# update -UPDATE_SONG_FINGERPRINTED = f""" - UPDATE "{SONGS_TABLENAME}" SET "{FIELD_FINGERPRINTED}" = 1, "date_modified" = now() WHERE "{FIELD_SONG_ID}" = %s; -""" - -# delete -DELETE_UNFINGERPRINTED = f""" - DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0; -""" diff --git a/dejavu/logic/__init__.py b/dejavu/logic/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/decoder.py b/dejavu/logic/decoder.py similarity index 97% rename from dejavu/decoder.py rename to dejavu/logic/decoder.py index 9299068..245d8b1 100755 --- a/dejavu/decoder.py +++ b/dejavu/logic/decoder.py @@ -1,10 +1,12 @@ -import os import fnmatch +import os +from hashlib import sha1 + import numpy as np from pydub import AudioSegment from pydub.utils import audioop -from . import wavio -from hashlib import sha1 + +from dejavu.third_party import wavio def unique_hash(filepath, blocksize=2**20): @@ -12,7 +14,7 @@ def unique_hash(filepath, blocksize=2**20): a file. Inspired by MD5 version here: http://stackoverflow.com/a/1131255/712997 - Works with large files. + Works with large files. """ s = sha1() with open(filepath, "rb") as f: diff --git a/dejavu/fingerprint.py b/dejavu/logic/fingerprint.py similarity index 90% rename from dejavu/fingerprint.py rename to dejavu/logic/fingerprint.py index ce8d8db..0653965 100755 --- a/dejavu/fingerprint.py +++ b/dejavu/logic/fingerprint.py @@ -9,11 +9,11 @@ from scipy.ndimage.morphology import (binary_erosion, generate_binary_structure, iterate_structure) -from dejavu.config.config import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE, - DEFAULT_FS, DEFAULT_OVERLAP_RATIO, - DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION, - MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA, - PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT) +from dejavu.config.settings import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE, + DEFAULT_FS, DEFAULT_OVERLAP_RATIO, + DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION, + MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA, + PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT) IDX_FREQ_I = 0 IDX_TIME_J = 1 diff --git a/dejavu/logic/recognizer/__init__.py b/dejavu/logic/recognizer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/logic/recognizer/file_recognizer.py b/dejavu/logic/recognizer/file_recognizer.py new file mode 100644 index 0000000..8c019be --- /dev/null +++ b/dejavu/logic/recognizer/file_recognizer.py @@ -0,0 +1,24 @@ +import time + +import dejavu.logic.decoder as decoder +from dejavu.base_classes.base_recognizer import BaseRecognizer + + +class FileRecognizer(BaseRecognizer): + def __init__(self, dejavu): + super().__init__(dejavu) + + def recognize_file(self, filename): + frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit) + + t = time.time() + matches = self._recognize(*frames) + t = time.time() - t + + for match in matches: + match['match_time'] = t + + return matches + + def recognize(self, filename): + return self.recognize_file(filename) diff --git a/dejavu/recognize.py b/dejavu/logic/recognizer/microphone_recognizer.py similarity index 70% rename from dejavu/recognize.py rename to dejavu/logic/recognizer/microphone_recognizer.py index 3d6c622..3bfef0d 100755 --- a/dejavu/recognize.py +++ b/dejavu/logic/recognizer/microphone_recognizer.py @@ -1,45 +1,7 @@ -import time - import numpy as np import pyaudio -import dejavu.decoder as decoder -from dejavu.config.config import DEFAULT_FS - - -class BaseRecognizer(object): - def __init__(self, dejavu): - self.dejavu = dejavu - self.Fs = DEFAULT_FS - - def _recognize(self, *data): - matches = [] - for d in data: - matches.extend(self.dejavu.find_matches(d, Fs=self.Fs)) - return self.dejavu.align_matches(matches) - - def recognize(self): - pass # base class does nothing - - -class FileRecognizer(BaseRecognizer): - def __init__(self, dejavu): - super().__init__(dejavu) - - def recognize_file(self, filename): - frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit) - - t = time.time() - matches = self._recognize(*frames) - t = time.time() - t - - for match in matches: - match['match_time'] = t - - return matches - - def recognize(self, filename): - return self.recognize_file(filename) +from dejavu.base_classes.base_recognizer import BaseRecognizer class MicrophoneRecognizer(BaseRecognizer): diff --git a/dejavu/tests/__init__.py b/dejavu/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/testing.py b/dejavu/tests/dejavu_test.py similarity index 69% rename from dejavu/testing.py rename to dejavu/tests/dejavu_test.py index eb78578..3a301aa 100644 --- a/dejavu/testing.py +++ b/dejavu/tests/dejavu_test.py @@ -1,130 +1,26 @@ - -import ast import fnmatch +import json import logging -import os import random import re import subprocess import traceback +from os import listdir, makedirs, walk +from os.path import basename, exists, isfile, join, splitext +import matplotlib.pyplot as plt +import numpy as np from pydub import AudioSegment -from dejavu import Dejavu -from dejavu.decoder import path_to_songname -from dejavu.fingerprint import * +from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS, + DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE, + MATCH_TIME, OFFSET, SONG_NAME) +from dejavu.logic.decoder import path_to_songname -def set_seed(seed=None): - """ - `seed` as None means that the sampling will be random. - - Setting your own seed means that you can produce the - same experiment over and over. - """ - if seed != None: - random.seed(seed) - - -def get_files_recursive(src, fmt): - """ - `src` is the source directory. - `fmt` is the extension, ie ".mp3" or "mp3", etc. - """ - for root, dirnames, filenames in os.walk(src): - for filename in fnmatch.filter(filenames, '*' + fmt): - yield os.path.join(root, filename) - - -def get_length_audio(audiopath, extension): - """ - Returns length of audio in seconds. - Returns None if format isn't supported or in case of error. - """ - try: - audio = AudioSegment.from_file(audiopath, extension.replace(".", "")) - except: - print(f"Error in get_length_audio(): {traceback.format_exc()}") - return None - return int(len(audio) / 1000.0) - - -def get_starttime(length, nseconds, padding): - """ - `length` is total audio length in seconds - `nseconds` is amount of time to sample in seconds - `padding` is off-limits seconds at beginning and ending - """ - maximum = length - padding - nseconds - if padding > maximum: - return 0 - return random.randint(padding, maximum) - - -def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10): - """ - Generates a test file for each file recursively in `src` directory - of given format using `nseconds` sampled from the audio file. - - Results are written to `dest` directory. - - `padding` is the number of off-limit seconds and the beginning and - end of a track that won't be sampled in testing. Often you want to - avoid silence, etc. - """ - # create directories if necessary - for directory in [src, dest]: - try: - os.stat(directory) - except: - os.mkdir(directory) - - # find files recursively of a given file format - for fmt in fmts: - testsources = get_files_recursive(src, fmt) - for audiosource in testsources: - - print("audiosource:", audiosource) - - filename, extension = os.path.splitext(os.path.basename(audiosource)) - length = get_length_audio(audiosource, extension) - starttime = get_starttime(length, nseconds, padding) - - test_file_name = f"{os.path.join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}" - - subprocess.check_output([ - "ffmpeg", "-y", - "-ss", f"{starttime}", - '-t', f"{nseconds}", - "-i", audiosource, - test_file_name]) - - -def log_msg(msg, log=True, silent=False): - if log: - logging.debug(msg) - if not silent: - print(msg) - - -def autolabel(rects, ax): - # attach some text labels - for rect in rects: - height = rect.get_height() - ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{int(height)}', ha='center', va='bottom') - - -def autolabeldoubles(rects, ax): - # attach some text labels - for rect in rects: - height = rect.get_height() - ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{round(float(height), 3)}', - ha='center', va='bottom') - - -class DejavuTest(object): +class DejavuTest: def __init__(self, folder, seconds): - super(DejavuTest, self).__init__() + super().__init__() self.test_folder = folder self.test_seconds = seconds @@ -133,9 +29,10 @@ class DejavuTest(object): print("test_seconds", self.test_seconds) self.test_files = [ - f for f in os.listdir(self.test_folder) - if os.path.isfile(os.path.join(self.test_folder, f)) - and re.findall("[0-9]*sec", f)[0] in self.test_seconds] + f for f in listdir(self.test_folder) + if isfile(join(self.test_folder, f)) + and any([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds]) + ] print("test_files", self.test_files) @@ -147,27 +44,27 @@ class DejavuTest(object): print("lines:", self.n_lines) # variable match results (yes, no, invalid) - self.result_match = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] + self.result_match = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] - print("result_match matrix:", self.result_match) + print("result_match matrix:", self.result_match) # variable match precision (if matched in the corrected time) - self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] + self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] - # variable mahing time (query time) - self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] + # variable matching time (query time) + self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] # variable confidence - self.result_match_confidence = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] + self.result_match_confidence = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] self.begin() - def get_column_id (self, secs): + def get_column_id(self, secs): for i, sec in enumerate(self.test_seconds): if secs == sec: return i - def get_line_id (self, song): + def get_line_id(self, song): for i, s in enumerate(self.test_songs): if song == s: return i @@ -176,7 +73,7 @@ class DejavuTest(object): def create_plots(self, name, results, results_folder): for sec in range(0, len(self.test_seconds)): - ind = np.arange(self.n_lines) # + ind = np.arange(self.n_lines) width = 0.25 # the width of the bars fig = plt.figure() @@ -206,7 +103,7 @@ class DejavuTest(object): plt.grid() - fig_name = os.path.join(results_folder, f"{name}_{self.test_seconds[sec]}.png") + fig_name = join(results_folder, f"{name}_{self.test_seconds[sec]}.png") fig.savefig(fig_name) def begin(self): @@ -214,17 +111,19 @@ class DejavuTest(object): log_msg('--------------------------------------------------') log_msg(f'file: {f}') - # get column - col = self.get_column_id(re.findall("[0-9]*sec", f)[0]) - # format: XXXX_offset_length.mp3 - song = path_to_songname(f).split("_")[0] + # get column + col = self.get_column_id([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds][0]) + + # format: XXXX_offset_length.mp3, we also take into account underscores within XXXX + splits = path_to_songname(f).split("_") + song = "_".join(splits[0:len(path_to_songname(f).split("_")) - 2]) line = self.get_line_id(song) result = subprocess.check_output([ - "python", + "python", "dejavu.py", '-r', - 'file', - self.test_folder + "/" + f]) + 'file', + join(self.test_folder, f)]) if result.strip() == "None": log_msg('No match') @@ -232,17 +131,15 @@ class DejavuTest(object): self.result_matching_times[line][col] = 0 self.result_query_duration[line][col] = 0 self.result_match_confidence[line][col] = 0 - + else: result = result.strip() - result = result.replace(" \'", ' "') - result = result.replace("{\'", '{"') - result = result.replace("\':", '":') - result = result.replace("\',", '",') + # we parse the output song back to a json + result = json.loads(result.decode('utf-8').replace("'", '"').replace(': b"', ':"')) - # which song did we predict? - result = ast.literal_eval(result) - song_result = result["song_name"] + # which song did we predict? We consider only the first match. + result = result[0] + song_result = result[SONG_NAME] log_msg(f'song: {song}') log_msg(f'song_result: {song_result}') @@ -256,21 +153,22 @@ class DejavuTest(object): log_msg('correct match') print(self.result_match) self.result_match[line][col] = 'yes' - self.result_query_duration[line][col] = round(result[Dejavu.MATCH_TIME],3) - self.result_match_confidence[line][col] = result[Dejavu.CONFIDENCE] + self.result_query_duration[line][col] = round(result[MATCH_TIME], 3) + self.result_match_confidence[line][col] = result[CONFIDENCE] - song_start_time = re.findall("_[^_]+", f) + # using replace in f for getting rid of underscores in name + song_start_time = re.findall("_[^_]+", f.replace(song, "")) song_start_time = song_start_time[0].lstrip("_ ") - result_start_time = round((result[Dejavu.OFFSET] * DEFAULT_WINDOW_SIZE * + result_start_time = round((result[OFFSET] * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0) self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time) if abs(self.result_matching_times[line][col]) == 1: self.result_matching_times[line][col] = 0 - log_msg(f'query duration: {round(result[Dejavu.MATCH_TIME], 3)}') - log_msg(f'confidence: {result[Dejavu.CONFIDENCE]}') + log_msg(f'query duration: {round(result[MATCH_TIME], 3)}') + log_msg(f'confidence: {result[CONFIDENCE]}') log_msg(f'song start_time: {song_start_time}') log_msg(f'result start time: {result_start_time}') @@ -279,3 +177,110 @@ class DejavuTest(object): else: log_msg('inaccurate match') log_msg('--------------------------------------------------\n') + + +def set_seed(seed=None): + """ + `seed` as None means that the sampling will be random. + + Setting your own seed means that you can produce the + same experiment over and over. + """ + if seed is not None: + random.seed(seed) + + +def get_files_recursive(src, fmt): + """ + `src` is the source directory. + `fmt` is the extension, ie ".mp3" or "mp3", etc. + """ + files = [] + for root, dirnames, filenames in walk(src): + for filename in fnmatch.filter(filenames, '*' + fmt): + files.append(join(root, filename)) + + return files + + +def get_length_audio(audiopath, extension): + """ + Returns length of audio in seconds. + Returns None if format isn't supported or in case of error. + """ + try: + audio = AudioSegment.from_file(audiopath, extension.replace(".", "")) + except Exception: + print(f"Error in get_length_audio(): {traceback.format_exc()}") + return None + return int(len(audio) / 1000.0) + + +def get_starttime(length, nseconds, padding): + """ + `length` is total audio length in seconds + `nseconds` is amount of time to sample in seconds + `padding` is off-limits seconds at beginning and ending + """ + maximum = length - padding - nseconds + if padding > maximum: + return 0 + return random.randint(padding, maximum) + + +def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10): + """ + Generates a test file for each file recursively in `src` directory + of given format using `nseconds` sampled from the audio file. + + Results are written to `dest` directory. + + `padding` is the number of off-limit seconds and the beginning and + end of a track that won't be sampled in testing. Often you want to + avoid silence, etc. + """ + # create directories if necessary + if not exists(dest): + makedirs(dest) + + # find files recursively of a given file format + for fmt in fmts: + testsources = get_files_recursive(src, fmt) + for audiosource in testsources: + + print("audiosource:", audiosource) + + filename, extension = splitext(basename(audiosource)) + length = get_length_audio(audiosource, extension) + starttime = get_starttime(length, nseconds, padding) + + test_file_name = f"{join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}" + + subprocess.check_output([ + "ffmpeg", "-y", + "-ss", f"{starttime}", + '-t', f"{nseconds}", + "-i", audiosource, + test_file_name]) + + +def log_msg(msg, log=True, silent=False): + if log: + logging.debug(msg) + if not silent: + print(msg) + + +def autolabel(rects, ax): + # attach some text labels + for rect in rects: + height = rect.get_height() + ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{int(height)}', ha='center', va='bottom') + + +def autolabeldoubles(rects, ax): + # attach some text labels + for rect in rects: + height = rect.get_height() + ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{round(float(height), 3)}', + ha='center', va='bottom') diff --git a/dejavu/third_party/__init__.py b/dejavu/third_party/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/third_party/wavio.py b/dejavu/third_party/wavio.py new file mode 100644 index 0000000..8c706b8 --- /dev/null +++ b/dejavu/third_party/wavio.py @@ -0,0 +1,357 @@ +# wavio.py +# Author: Warren Weckesser +# License: BSD 2-Clause (http://opensource.org/licenses/BSD-2-Clause) +# Synopsis: A Python module for reading and writing 24 bit WAV files. +# Github: github.com/WarrenWeckesser/wavio + +""" +The wavio module defines the functions: +read(file) + Read a WAV file and return a `wavio.Wav` object, with attributes + `data`, `rate` and `sampwidth`. +write(filename, data, rate, scale=None, sampwidth=None) + Write a numpy array to a WAV file. +----- +Author: Warren Weckesser +License: BSD 2-Clause: +Copyright (c) 2015, Warren Weckesser +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +""" + + +import wave as _wave + +import numpy as _np + +__version__ = "0.0.5.dev1" + + +def _wav2array(nchannels, sampwidth, data): + """data must be the string containing the bytes from the wav file.""" + num_samples, remainder = divmod(len(data), sampwidth * nchannels) + if remainder > 0: + raise ValueError('The length of data is not a multiple of ' + 'sampwidth * num_channels.') + if sampwidth > 4: + raise ValueError("sampwidth must not be greater than 4.") + + if sampwidth == 3: + a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8) + raw_bytes = _np.frombuffer(data, dtype=_np.uint8) + a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth) + a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255 + result = a.view('> _np.array([0, 8, 16])) & 255 + wavdata = a8.astype(_np.uint8).tostring() + else: + # Make sure the array is little-endian, and then convert using + # tostring() + a = a.astype('<' + a.dtype.str[1:], copy=False) + wavdata = a.tostring() + return wavdata + + +class Wav(object): + """ + Object returned by `wavio.read`. Attributes are: + data : numpy array + The array of data read from the WAV file. + rate : float + The sample rate of the WAV file. + sampwidth : int + The sample width (i.e. number of bytes per sample) of the WAV file. + For example, `sampwidth == 3` is a 24 bit WAV file. + """ + + def __init__(self, data, rate, sampwidth): + self.data = data + self.rate = rate + self.sampwidth = sampwidth + + def __repr__(self): + s = (f"Wav(data.shape={self.data.shape}, data.dtype={self.data.dtype}, " + f"rate={self.rate}, sampwidth={self.sampwidth})") + return s + + +def read(file): + """ + Read a WAV file. + Parameters + ---------- + file : string or file object + Either the name of a file or an open file pointer. + Returns + ------- + wav : wavio.Wav() instance + The return value is an instance of the class `wavio.Wav`, + with the following attributes: + data : numpy array + The array containing the data. The shape of the array + is (num_samples, num_channels). num_channels is the + number of audio channels (1 for mono, 2 for stereo). + rate : float + The sampling frequency (i.e. frame rate) + sampwidth : float + The sample width, in bytes. E.g. for a 24 bit WAV file, + sampwidth is 3. + Notes + ----- + This function uses the `wave` module of the Python standard libary + to read the WAV file, so it has the same limitations as that library. + In particular, the function does not read compressed WAV files, and + it does not read files with floating point data. + The array returned by `wavio.read` is always two-dimensional. If the + WAV data is mono, the array will have shape (num_samples, 1). + `wavio.read()` does not scale or normalize the data. The data in the + array `wav.data` is the data that was in the file. When the file + contains 24 bit samples, the resulting numpy array is 32 bit integers, + with values that have been sign-extended. + """ + wav = _wave.open(file) + rate = wav.getframerate() + nchannels = wav.getnchannels() + sampwidth = wav.getsampwidth() + nframes = wav.getnframes() + data = wav.readframes(nframes) + wav.close() + array = _wav2array(nchannels, sampwidth, data) + w = Wav(data=array, rate=rate, sampwidth=sampwidth) + return w + + +_sampwidth_dtypes = {1: _np.uint8, + 2: _np.int16, + 3: _np.int32, + 4: _np.int32} +_sampwidth_ranges = {1: (0, 256), + 2: (-2**15, 2**15), + 3: (-2**23, 2**23), + 4: (-2**31, 2**31)} + + +def _scale_to_sampwidth(data, sampwidth, vmin, vmax): + # Scale and translate the values to fit the range of the data type + # associated with the given sampwidth. + + data = data.clip(vmin, vmax) + + dt = _sampwidth_dtypes[sampwidth] + if vmax == vmin: + data = _np.zeros(data.shape, dtype=dt) + else: + outmin, outmax = _sampwidth_ranges[sampwidth] + if outmin != vmin or outmax != vmax: + vmin = float(vmin) + vmax = float(vmax) + data = (float(outmax - outmin) * (data - vmin) / + (vmax - vmin)).astype(_np.int64) + outmin + data[data == outmax] = outmax - 1 + data = data.astype(dt) + + return data + + +def write(file, data, rate, scale=None, sampwidth=None): + """ + Write the numpy array `data` to a WAV file. + The Python standard library "wave" is used to write the data + to the file, so this function has the same limitations as that + module. In particular, the Python library does not support + floating point data. When given a floating point array, this + function converts the values to integers. See below for the + conversion rules. + Parameters + ---------- + file : string, or file object open for writing in binary mode + Either the name of a file or an open file pointer. + data : numpy array, 1- or 2-dimensional, integer or floating point + If it is 2-d, the rows are the frames (i.e. samples) and the + columns are the channels. + rate : float + The sampling frequency (i.e. frame rate) of the data. + sampwidth : int, optional + The sample width, in bytes, of the output file. + If `sampwidth` is not given, it is inferred (if possible) from + the data type of `data`, as follows:: + data.dtype sampwidth + ---------- --------- + uint8, int8 1 + uint16, int16 2 + uint32, int32 4 + For any other data types, or to write a 24 bit file, `sampwidth` + must be given. + scale : tuple or str, optional + By default, the data written to the file is scaled up or down to + occupy the full range of the output data type. So, for example, + the unsigned 8 bit data [0, 1, 2, 15] would be written to the file + as [0, 17, 30, 255]. More generally, the default behavior is + (roughly):: + vmin = data.min() + vmax = data.max() + outmin = + outmax = + outdata = (outmax - outmin)*(data - vmin)/(vmax - vmin) + outmin + The `scale` argument allows the scaling of the output data to be + changed. `scale` can be a tuple of the form `(vmin, vmax)`, in which + case the given values override the use of `data.min()` and + `data.max()` for `vmin` and `vmax` shown above. (If either value + is `None`, the value shown above is used.) Data outside the + range (vmin, vmax) is clipped. If `vmin == vmax`, the output is + all zeros. + If `scale` is the string "none", then `vmin` and `vmax` are set to + `outmin` and `outmax`, respectively. This means the data is written + to the file with no scaling. (Note: `scale="none" is not the same + as `scale=None`. The latter means "use the default behavior", + which is to scale by the data minimum and maximum.) + If `scale` is the string "dtype-limits", then `vmin` and `vmax` + are set to the minimum and maximum integers of `data.dtype`. + The string "dtype-limits" is not allowed when the `data` is a + floating point array. + If using `scale` results in values that exceed the limits of the + output sample width, the data is clipped. For example, the + following code:: + >> x = np.array([-100, 0, 100, 200, 300, 325]) + >> wavio.write('foo.wav', x, 8000, scale='none', sampwidth=1) + will write the values [0, 0, 100, 200, 255, 255] to the file. + Example + ------- + Create a 3 second 440 Hz sine wave, and save it in a 24-bit WAV file. + >> import numpy as np + >> import wavio + >> rate = 22050 # samples per second + >> T = 3 # sample duration (seconds) + >> f = 440.0 # sound frequency (Hz) + >> t = np.linspace(0, T, T*rate, endpoint=False) + >> x = np.sin(2*np.pi * f * t) + >> wavio.write("sine24.wav", x, rate, sampwidth=3) + Create a file that contains the 16 bit integer values -10000 and 10000 + repeated 100 times. Don't automatically scale the values. Use a sample + rate 8000. + >> x = np.empty(200, dtype=np.int16) + >> x[::2] = -10000 + >> x[1::2] = 10000 + >> wavio.write("foo.wav", x, 8000, scale='none') + Check that the file contains what we expect. + >> w = wavio.read("foo.wav") + >> np.all(w.data[:, 0] == x) + True + In the following, the values -10000 and 10000 (from within the 16 bit + range [-2**15, 2**15-1]) are mapped to the corresponding values 88 and + 168 (in the range [0, 2**8-1]). + >> wavio.write("foo.wav", x, 8000, sampwidth=1, scale='dtype-limits') + >> w = wavio.read("foo.wav") + >> w.data[:4, 0] + array([ 88, 168, 88, 168], dtype=uint8) + """ + + if sampwidth is None: + if not _np.issubdtype(data.dtype, _np.integer) or data.itemsize > 4: + raise ValueError('when data.dtype is not an 8-, 16-, or 32-bit integer type, sampwidth must be specified.') + sampwidth = data.itemsize + else: + if sampwidth not in [1, 2, 3, 4]: + raise ValueError('sampwidth must be 1, 2, 3 or 4.') + + outdtype = _sampwidth_dtypes[sampwidth] + outmin, outmax = _sampwidth_ranges[sampwidth] + + if scale == "none": + data = data.clip(outmin, outmax-1).astype(outdtype) + elif scale == "dtype-limits": + if not _np.issubdtype(data.dtype, _np.integer): + raise ValueError("scale cannot be 'dtype-limits' with non-integer data.") + # Easy transforms that just changed the signedness of the data. + if sampwidth == 1 and data.dtype == _np.int8: + data = (data.astype(_np.int16) + 128).astype(_np.uint8) + elif sampwidth == 2 and data.dtype == _np.uint16: + data = (data.astype(_np.int32) - 32768).astype(_np.int16) + elif sampwidth == 4 and data.dtype == _np.uint32: + data = (data.astype(_np.int64) - 2**31).astype(_np.int32) + elif data.itemsize != sampwidth: + # Integer input, but rescaling is needed to adjust the + # input range to the output sample width. + ii = _np.iinfo(data.dtype) + vmin = ii.min + vmax = ii.max + data = _scale_to_sampwidth(data, sampwidth, vmin, vmax) + else: + if scale is None: + vmin = data.min() + vmax = data.max() + else: + # scale must be a tuple of the form (vmin, vmax) + vmin, vmax = scale + if vmin is None: + vmin = data.min() + if vmax is None: + vmax = data.max() + + data = _scale_to_sampwidth(data, sampwidth, vmin, vmax) + + # At this point, `data` has been converted to have one of the following: + # sampwidth dtype + # --------- ----- + # 1 uint8 + # 2 int16 + # 3 int32 + # 4 int32 + # The values in `data` are in the form in which they will be saved; + # no more scaling will take place. + + if data.ndim == 1: + data = data.reshape(-1, 1) + + wavdata = _array2wav(data, sampwidth) + + w = _wave.open(file, 'wb') + w.setnchannels(data.shape[1]) + w.setsampwidth(sampwidth) + w.setframerate(rate) + w.writeframes(wavdata) + w.close() diff --git a/dejavu/wavio.py b/dejavu/wavio.py deleted file mode 100644 index 7045057..0000000 --- a/dejavu/wavio.py +++ /dev/null @@ -1,122 +0,0 @@ -# wavio.py -# Author: Warren Weckesser -# License: BSD 3-Clause (http://opensource.org/licenses/BSD-3-Clause) -# Synopsis: A Python module for reading and writing 24 bit WAV files. -# Github: github.com/WarrenWeckesser/wavio - -import wave as _wave - -import numpy as _np - - -def _wav2array(nchannels, sampwidth, data): - """data must be the string containing the bytes from the wav file.""" - num_samples, remainder = divmod(len(data), sampwidth * nchannels) - if remainder > 0: - raise ValueError('The length of data is not a multiple of ' - 'sampwidth * num_channels.') - if sampwidth > 4: - raise ValueError("sampwidth must not be greater than 4.") - - if sampwidth == 3: - a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8) - raw_bytes = _np.fromstring(data, dtype=_np.uint8) - a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth) - a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255 - result = a.view('>> rate = 22050 # samples per second - >>> T = 3 # sample duration (seconds) - >>> f = 440.0 # sound frequency (Hz) - >>> t = np.linspace(0, T, T*rate, endpoint=False) - >>> x = (2**23 - 1) * np.sin(2 * np.pi * f * t) - >>> writewav24("sine24.wav", rate, x) - - """ - a32 = _np.asarray(data, dtype=_np.int32) - if a32.ndim == 1: - # Convert to a 2D array with a single column. - a32.shape = a32.shape + (1,) - # By shifting first 0 bits, then 8, then 16, the resulting output - # is 24 bit little-endian. - a8 = (a32.reshape(a32.shape + (1,)) >> _np.array([0, 8, 16])) & 255 - wavdata = a8.astype(_np.uint8).tostring() - - w = _wave.open(filename, 'wb') - w.setnchannels(a32.shape[1]) - w.setsampwidth(3) - w.setframerate(rate) - w.writeframes(wavdata) - w.close() diff --git a/example.py b/example_script.py similarity index 71% rename from example.py rename to example_script.py index 87aef09..8f90552 100755 --- a/example.py +++ b/example_script.py @@ -1,11 +1,8 @@ import json -import warnings from dejavu import Dejavu -from dejavu.recognize import FileRecognizer, MicrophoneRecognizer - -warnings.filterwarnings("ignore") - +from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer +from dejavu.logic.recognizer.file_recognizer import FileRecognizer # load config from a JSON file (or anything outputting a python dictionary) with open("dejavu.cnf.SAMPLE") as f: @@ -17,10 +14,10 @@ if __name__ == '__main__': djv = Dejavu(config) # Fingerprint all the mp3's in the directory we give it - djv.fingerprint_directory("mp3", [".mp3"]) + djv.fingerprint_directory("test", [".wav"]) # Recognize audio from a file - song = djv.recognize(FileRecognizer, "mp3/Sean-Fournier--Falling-For-You.mp3") + song = djv.recognize(FileRecognizer, "mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") print(f"From file we recognized: {song}\n") # Or recognize audio from your microphone for `secs` seconds @@ -29,7 +26,7 @@ if __name__ == '__main__': if song is None: print("Nothing recognized -- did you play the song out loud so your mic could hear it? :)") else: - print(f"From mic with %d seconds we recognized: {(secs, song)}\n") + print(f"From mic with {secs} seconds we recognized: {song}\n") # Or use a recognizer without the shortcut, in anyway you would like recognizer = FileRecognizer(djv) diff --git a/mp3/azan_test.wav b/mp3/azan_test.wav new file mode 100644 index 0000000..34043fd Binary files /dev/null and b/mp3/azan_test.wav differ diff --git a/requirements.txt b/requirements.txt index 19cc2b9..1295bd1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ scipy==1.3.1 matplotlib==3.1.1 mysql-connector-python==8.0.17 psycopg2==2.8.3 + diff --git a/run_tests.py b/run_tests.py index 7ad387d..6a5eda9 100644 --- a/run_tests.py +++ b/run_tests.py @@ -1,184 +1,166 @@ -from dejavu.testing import * -from dejavu import Dejavu -from optparse import OptionParser -import matplotlib.pyplot as plt +import argparse +import logging import time -import shutil +from os import makedirs +from os.path import exists, join +from shutil import rmtree -usage = "usage: %prog [options] TESTING_AUDIOFOLDER" -parser = OptionParser(usage=usage, version="%prog 1.1") -parser.add_option("--secs", - action="store", - dest="secs", - default=5, - type=int, - help='Number of seconds starting from zero to test') -parser.add_option("--results", - action="store", - dest="results_folder", - default="./dejavu_test_results", - help='Sets the path where the results are saved') -parser.add_option("--temp", - action="store", - dest="temp_folder", - default="./dejavu_temp_testing_files", - help='Sets the path where the temp files are saved') -parser.add_option("--log", - action="store_true", - dest="log", - default=True, - help='Enables logging') -parser.add_option("--silent", - action="store_false", - dest="silent", - default=False, - help='Disables printing') -parser.add_option("--log-file", - dest="log_file", - default="results-compare.log", - help='Set the path and filename of the log file') -parser.add_option("--padding", - action="store", - dest="padding", - default=10, - type=int, - help='Number of seconds to pad choice of place to test from') -parser.add_option("--seed", - action="store", - dest="seed", - default=None, - type=int, - help='Random seed') -options, args = parser.parse_args() -test_folder = args[0] +import matplotlib.pyplot as plt +import numpy as np -# set random seed if set by user -set_seed(options.seed) +from dejavu.tests.dejavu_test import (DejavuTest, autolabeldoubles, + generate_test_files, log_msg, set_seed) -# ensure results folder exists -try: - os.stat(options.results_folder) -except: - os.mkdir(options.results_folder) -# set logging -if options.log: - logging.basicConfig(filename=options.log_file, level=logging.DEBUG) +def main(seconds: int, results_folder: str, temp_folder: str, log: bool, silent: bool, + log_file: str, padding: int, seed: int, src: str): -# set test seconds -test_seconds = ['%dsec' % i for i in range(1, options.secs + 1, 1)] + # set random seed if set by user + set_seed(seed) -# generate testing files -for i in range(1, options.secs + 1, 1): - generate_test_files(test_folder, options.temp_folder, - i, padding=options.padding) + # ensure results folder exists + if not exists(results_folder): + makedirs(results_folder) -# scan files -log_msg("Running Dejavu fingerprinter on files in %s..." % test_folder, - log=options.log, silent=options.silent) + # set logging + if log: + logging.basicConfig(filename=log_file, level=logging.DEBUG) -tm = time.time() -djv = DejavuTest(options.temp_folder, test_seconds) -log_msg("finished obtaining results from dejavu in %s" % (time.time() - tm), - log=options.log, silent=options.silent) + # set test seconds + test_seconds = [f'{i}sec' for i in range(1, seconds + 1, 1)] -tests = 1 # djv -n_secs = len(test_seconds) + # generate testing files + for i in range(1, seconds + 1, 1): + generate_test_files(src, temp_folder, i, padding=padding) -# set result variables -> 4d variables -all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)] -all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)] -all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)] -all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)] + # scan files + log_msg(f"Running Dejavu fingerprinter on files in {src}...", log=log, silent=silent) -# group results by seconds -for line in range(0, djv.n_lines): - for col in range(0, djv.n_columns): - # for dejavu - all_query_duration[col][line][0] = djv.result_query_duration[line][col] - all_match_confidence[col][line][0] = djv.result_match_confidence[line][col] + tm = time.time() + djv = DejavuTest(temp_folder, test_seconds) + log_msg(f"finished obtaining results from dejavu in {(time.time() - tm)}", log=log, silent=silent) - djv_match_result = djv.result_match[line][col] + tests = 1 # djv + n_secs = len(test_seconds) - if djv_match_result == 'yes': - all_match_counter[col][0][0] += 1 - elif djv_match_result == 'no': - all_match_counter[col][1][0] += 1 - else: - all_match_counter[col][2][0] += 1 + # set result variables -> 4d variables + all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)] + all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)] + all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)] + all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)] - djv_match_acc = djv.result_matching_times[line][col] + # group results by seconds + for line in range(0, djv.n_lines): + for col in range(0, djv.n_columns): + # for dejavu + all_query_duration[col][line][0] = djv.result_query_duration[line][col] + all_match_confidence[col][line][0] = djv.result_match_confidence[line][col] - if djv_match_acc == 0 and djv_match_result == 'yes': - all_matching_times_counter[col][0][0] += 1 - elif djv_match_acc != 0: - all_matching_times_counter[col][1][0] += 1 + djv_match_result = djv.result_match[line][col] -# create plots -djv.create_plots('Confidence', all_match_confidence, options.results_folder) -djv.create_plots('Query duration', all_query_duration, options.results_folder) + if djv_match_result == 'yes': + all_match_counter[col][0][0] += 1 + elif djv_match_result == 'no': + all_match_counter[col][1][0] += 1 + else: + all_match_counter[col][2][0] += 1 -for sec in range(0, n_secs): - ind = np.arange(3) # - width = 0.25 # the width of the bars + djv_match_acc = djv.result_matching_times[line][col] - fig = plt.figure() - ax = fig.add_subplot(111) - ax.set_xlim([-1 * width, 2.75]) + if djv_match_acc == 0 and djv_match_result == 'yes': + all_matching_times_counter[col][0][0] += 1 + elif djv_match_acc != 0: + all_matching_times_counter[col][1][0] += 1 - means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]] - rects1 = ax.bar(ind, means_dvj, width, color='r') + # create plots + djv.create_plots('Confidence', all_match_confidence, results_folder) + djv.create_plots('Query duration', all_query_duration, results_folder) - # add some - ax.set_ylabel('Matching Percentage') - ax.set_title('%s Matching Percentage' % test_seconds[sec]) - ax.set_xticks(ind + width) + for sec in range(0, n_secs): + ind = np.arange(3) + width = 0.25 # the width of the bars - labels = ['yes','no','invalid'] - ax.set_xticklabels( labels ) + fig = plt.figure() + ax = fig.add_subplot(111) + ax.set_xlim([-1 * width, 2.75]) - box = ax.get_position() - ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) - #ax.legend((rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5)) - autolabeldoubles(rects1,ax) - plt.grid() + means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]] + rects1 = ax.bar(ind, means_dvj, width, color='r') - fig_name = os.path.join(options.results_folder, "matching_perc_%s.png" % test_seconds[sec]) - fig.savefig(fig_name) + # add some + ax.set_ylabel('Matching Percentage') + ax.set_title(f'{test_seconds[sec]} Matching Percentage') + ax.set_xticks(ind + width) -for sec in range(0, n_secs): - ind = np.arange(2) # - width = 0.25 # the width of the bars + labels = ['yes', 'no', 'invalid'] + ax.set_xticklabels(labels) - fig = plt.figure() - ax = fig.add_subplot(111) - ax.set_xlim([-1*width, 1.75]) + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) + autolabeldoubles(rects1, ax) + plt.grid() - div = all_match_counter[sec][0][0] - if div == 0 : - div = 1000000 + fig_name = join(results_folder, f"matching_perc_{test_seconds[sec]}.png") + fig.savefig(fig_name) - means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]] - rects1 = ax.bar(ind, means_dvj, width, color='r') + for sec in range(0, n_secs): + ind = np.arange(2) + width = 0.25 # the width of the bars - # add some - ax.set_ylabel('Matching Accuracy') - ax.set_title('%s Matching Times Accuracy' % test_seconds[sec]) - ax.set_xticks(ind + width) + fig = plt.figure() + ax = fig.add_subplot(111) + ax.set_xlim([-1 * width, 1.75]) - labels = ['yes','no'] - ax.set_xticklabels( labels ) + div = all_match_counter[sec][0][0] + if div == 0: + div = 1000000 - box = ax.get_position() - ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) + means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]] + rects1 = ax.bar(ind, means_dvj, width, color='r') - #ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5)) - autolabeldoubles(rects1,ax) + # add some + ax.set_ylabel('Matching Accuracy') + ax.set_title(f'{test_seconds[sec]} Matching Times Accuracy') + ax.set_xticks(ind + width) - plt.grid() + labels = ['yes', 'no'] + ax.set_xticklabels(labels) - fig_name = os.path.join(options.results_folder, "matching_acc_%s.png" % test_seconds[sec]) - fig.savefig(fig_name) + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) + autolabeldoubles(rects1, ax) -# remove temporary folder -shutil.rmtree(options.temp_folder) + plt.grid() + + fig_name = join(results_folder, f"matching_acc_{test_seconds[sec]}.png") + fig.savefig(fig_name) + + # remove temporary folder + rmtree(temp_folder) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description=f'Runs a few tests for dejavu to evaluate ' + f'its configuration performance. ' + f'Usage: %(prog).py [options] TESTING_AUDIOFOLDER' + ) + + parser.add_argument("-sec", "--seconds", action="store", default=5, type=int, + help='Number of seconds starting from zero to test.') + parser.add_argument("-res", "--results-folder", action="store", default="./dejavu_test_results", + help='Sets the path where the results are saved.') + parser.add_argument("-temp", "--temp-folder", action="store", default="./dejavu_temp_testing_files", + help='Sets the path where the temp files are saved.') + parser.add_argument("-l", "--log", action="store_true", default=False, help='Enables logging.') + parser.add_argument("-sl", "--silent", action="store_false", default=False, help='Disables printing.') + parser.add_argument("-lf", "--log-file", default="results-compare.log", + help='Set the path and filename of the log file.') + parser.add_argument("-pad", "--padding", action="store", default=10, type=int, + help='Number of seconds to pad choice of place to test from.') + parser.add_argument("-sd", "--seed", action="store", default=None, type=int, help='Random seed.') + parser.add_argument("src", type=str, help='Source folder for audios to use as tests.') + + args = parser.parse_args() + + main(args.seconds, args.results_folder, args.temp_folder, args.log, args.silent, args.log_file, args.padding, + args.seed, args.src) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..5f32207 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 120 + diff --git a/setup.py b/setup.py index 9484507..c94e0fb 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ from setuptools import setup, find_packages -# import os, sys def parse_requirements(requirements): @@ -14,19 +13,20 @@ def parse_requirements(requirements): reqs = list(filter((lambda x: x), nocomments)) return reqs + PACKAGE_NAME = "PyDejavu" PACKAGE_VERSION = "0.1.3" SUMMARY = 'Dejavu: Audio Fingerprinting in Python' DESCRIPTION = """ Audio fingerprinting and recognition algorithm implemented in Python -See the explanation here: +See the explanation here: `http://willdrevo.com/fingerprinting-and-audio-recognition-with-python/`__ -Dejavu can memorize recorded audio by listening to it once and fingerprinting -it. Then by playing a song and recording microphone input or on disk file, -Dejavu attempts to match the audio against the fingerprints held in the +Dejavu can memorize recorded audio by listening to it once and fingerprinting +it. Then by playing a song and recording microphone input or on disk file, +Dejavu attempts to match the audio against the fingerprints held in the database, returning the song or recording being played. __ http://willdrevo.com/fingerprinting-and-audio-recognition-with-python/ diff --git a/test/sean_secs.wav b/test/sean_secs.wav new file mode 100644 index 0000000..6f30d72 Binary files /dev/null and b/test/sean_secs.wav differ diff --git a/test/woodward_43s.wav b/test/woodward_43s.wav new file mode 100644 index 0000000..906f426 Binary files /dev/null and b/test/woodward_43s.wav differ