diff --git a/README.md b/README.md index 7ec3ccd..dc20967 100755 --- a/README.md +++ b/README.md @@ -24,6 +24,8 @@ Second, you'll need to create a MySQL database where Dejavu can store fingerprin Now you're ready to start fingerprinting your audio collection! +Note: The same steps apply to a PostgreSQL database if you want to use one. + ## Quickstart ```bash @@ -44,8 +46,8 @@ Start by creating a Dejavu object with your configurations settings (Dejavu take ... "database": { ... "host": "127.0.0.1", ... "user": "root", -... "passwd": , -... "db": , +... "password": , +... "database": , ... } ... } >>> djv = Dejavu(config) @@ -81,7 +83,7 @@ The following keys are mandatory: The following keys are optional: * `fingerprint_limit`: allows you to control how many seconds of each audio file to fingerprint. Leaving out this key, or alternatively using `-1` and `None` will cause Dejavu to fingerprint the entire audio file. Default value is `None`. -* `database_type`: as of now, only `mysql` (the default value) is supported. If you'd like to subclass `Database` and add another, please fork and send a pull request! +* `database_type`: `mysql` (the default value) and `postgres` are supported. If you'd like to add another subclass of `BaseDatabase` and implement a new type of database, please fork and send a pull request! An example configuration is as follows: @@ -91,8 +93,8 @@ An example configuration is as follows: ... "database": { ... "host": "127.0.0.1", ... "user": "root", -... "passwd": "Password123", -... "db": "dejavu_db", +... "password": "Password123", +... "database": "dejavu_db", ... }, ... "database_type" : "mysql", ... "fingerprint_limit" : 10 @@ -102,16 +104,16 @@ An example configuration is as follows: ## Tuning -Inside `fingerprint.py`, you may want to adjust following parameters (some values are given below). +Inside `config/settings.py`, you may want to adjust the following parameters (some values are given below). 
FINGERPRINT_REDUCTION = 30 PEAK_SORT = False DEFAULT_OVERLAP_RATIO = 0.4 - DEFAULT_FAN_VALUE = 10 - DEFAULT_AMP_MIN = 15 - PEAK_NEIGHBORHOOD_SIZE = 30 + DEFAULT_FAN_VALUE = 5 + DEFAULT_AMP_MIN = 10 + PEAK_NEIGHBORHOOD_SIZE = 10 -These parameters are described in the `fingerprint.py` in detail. Read that in-order to understand the impact of changing these values. +These parameters are described in detail within the file. Read it in order to understand the impact of changing these values. ## Recognizing @@ -123,13 +125,13 @@ Through the terminal: ```bash $ python dejavu.py --recognize file sometrack.wav -{'song_id': 1, 'song_name': 'Taylor Swift - Shake It Off', 'confidence': 3948, 'offset_seconds': 30.00018, 'match_time': 0.7159781455993652, 'offset': 646L} +{'total_time': 2.863781690597534, 'fingerprint_time': 2.4306554794311523, 'query_time': 0.4067542552947998, 'align_time': 0.007731199264526367, 'results': [{'song_id': 1, 'song_name': 'Taylor Swift - Shake It Off', 'input_total_hashes': 76168, 'fingerprinted_hashes_in_db': 4919, 'hashes_matched_in_input': 794, 'input_confidence': 0.01, 'fingerprinted_confidence': 0.16, 'offset': -924, 'offset_seconds': -30.00018, 'file_sha1': b'3DC269DF7B8DB9B30D2604DA80783155912593E8'}, {...}, ...]} ``` or in scripting, assuming you've already instantiated a Dejavu object: ```python ->>> from dejavu.recognize import FileRecognizer +>>> from dejavu.logic.recognizer.file_recognizer import FileRecognizer >>> song = djv.recognize(FileRecognizer, "va_us_top_40/wav/Mirrors - Justin Timberlake.wav") ``` @@ -138,7 +140,7 @@ or in scripting, assuming you've already instantiated a Dejavu object: With scripting: ```python ->>> from dejavu.recognize import MicrophoneRecognizer +>>> from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer >>> song = djv.recognize(MicrophoneRecognizer, seconds=10) # Defaults to 10 seconds. 
``` diff --git a/dejavu.cnf.SAMPLE b/dejavu.cnf.SAMPLE index cd677b0..e161192 100755 --- a/dejavu.cnf.SAMPLE +++ b/dejavu.cnf.SAMPLE @@ -2,7 +2,8 @@ "database": { "host": "127.0.0.1", "user": "root", - "passwd": "12345678", - "db": "dejavu" - } -} \ No newline at end of file + "password": "rootpass", + "database": "dejavu" + }, + "database_type": "mysql" +} diff --git a/dejavu.py b/dejavu.py index a7f74a1..fe8837c 100755 --- a/dejavu.py +++ b/dejavu.py @@ -1,30 +1,25 @@ -#!/usr/bin/python - -import os -import sys -import json -import warnings import argparse +import json +import sys +from argparse import RawTextHelpFormatter +from os.path import isdir from dejavu import Dejavu -from dejavu.recognize import FileRecognizer -from dejavu.recognize import MicrophoneRecognizer -from argparse import RawTextHelpFormatter - -warnings.filterwarnings("ignore") +from dejavu.logic.recognizer.file_recognizer import FileRecognizer +from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer DEFAULT_CONFIG_FILE = "dejavu.cnf.SAMPLE" def init(configpath): - """ + """ Load config from a JSON file """ try: with open(configpath) as f: config = json.load(f) except IOError as err: - print("Cannot open configuration: %s. Exiting" % (str(err))) + print(f"Cannot open configuration: {str(err)}. 
Exiting") sys.exit(1) # create a Dejavu instance @@ -46,7 +41,7 @@ if __name__ == '__main__': '--fingerprint /path/to/directory') parser.add_argument('-r', '--recognize', nargs=2, help='Recognize what is ' - 'playing through the microphone\n' + 'playing through the microphone or in a file.\n' 'Usage: \n' '--recognize mic number_of_seconds \n' '--recognize file path/to/file \n') @@ -59,7 +54,6 @@ if __name__ == '__main__': config_file = args.config if config_file is None: config_file = DEFAULT_CONFIG_FILE - # print "Using default config file: %s" % (config_file) djv = init(config_file) if args.fingerprint: @@ -67,28 +61,24 @@ if __name__ == '__main__': if len(args.fingerprint) == 2: directory = args.fingerprint[0] extension = args.fingerprint[1] - print("Fingerprinting all .%s files in the %s directory" - % (extension, directory)) + print(f"Fingerprinting all .{extension} files in the {directory} directory") djv.fingerprint_directory(directory, ["." + extension], 4) elif len(args.fingerprint) == 1: filepath = args.fingerprint[0] - if os.path.isdir(filepath): + if isdir(filepath): print("Please specify an extension if you'd like to fingerprint a directory!") sys.exit(1) djv.fingerprint_file(filepath) elif args.recognize: # Recognize audio source - song = None + songs = None source = args.recognize[0] opt_arg = args.recognize[1] if source in ('mic', 'microphone'): - song = djv.recognize(MicrophoneRecognizer, seconds=opt_arg) + songs = djv.recognize(MicrophoneRecognizer, seconds=opt_arg) elif source == 'file': - song = djv.recognize(FileRecognizer, opt_arg) - decoded_song = repr(song).decode('string_escape') - print(decoded_song) - - sys.exit(0) + songs = djv.recognize(FileRecognizer, opt_arg) + print(songs) diff --git a/dejavu/__init__.py b/dejavu/__init__.py index 7cc3f3b..fac72bc 100755 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -1,28 +1,29 @@ -from dejavu.database import get_database, Database -import dejavu.decoder as decoder -import fingerprint import 
multiprocessing import os -import traceback import sys +import traceback +from itertools import groupby +from time import time +from typing import Dict, List, Tuple + +import dejavu.logic.decoder as decoder +from dejavu.base_classes.base_database import get_database +from dejavu.config.settings import (DEFAULT_FS, DEFAULT_OVERLAP_RATIO, + DEFAULT_WINDOW_SIZE, FIELD_FILE_SHA1, + FIELD_TOTAL_HASHES, + FINGERPRINTED_CONFIDENCE, + FINGERPRINTED_HASHES, HASHES_MATCHED, + INPUT_CONFIDENCE, INPUT_HASHES, OFFSET, + OFFSET_SECS, SONG_ID, SONG_NAME, TOPN) +from dejavu.logic.fingerprint import fingerprint -class Dejavu(object): - - SONG_ID = "song_id" - SONG_NAME = 'song_name' - CONFIDENCE = 'confidence' - MATCH_TIME = 'match_time' - OFFSET = 'offset' - OFFSET_SECS = 'offset_seconds' - +class Dejavu: def __init__(self, config): - super(Dejavu, self).__init__() - self.config = config # initialize db - db_cls = get_database(config.get("database_type", None)) + db_cls = get_database(config.get("database_type", "mysql").lower()) self.db = db_cls(**config.get("database", {})) self.db.setup() @@ -32,17 +33,44 @@ class Dejavu(object): self.limit = self.config.get("fingerprint_limit", None) if self.limit == -1: # for JSON compatibility self.limit = None - self.get_fingerprinted_songs() + self.__load_fingerprinted_audio_hashes() - def get_fingerprinted_songs(self): + def __load_fingerprinted_audio_hashes(self) -> None: + """ + Keeps a dictionary with the hashes of the fingerprinted songs, in that way is possible to check + whether or not an audio file was already processed. 
+ """ # get songs previously indexed self.songs = self.db.get_songs() self.songhashes_set = set() # to know which ones we've computed before for song in self.songs: - song_hash = song[Database.FIELD_FILE_SHA1] + song_hash = song[FIELD_FILE_SHA1] self.songhashes_set.add(song_hash) - def fingerprint_directory(self, path, extensions, nprocesses=None): + def get_fingerprinted_songs(self) -> List[Dict[str, any]]: + """ + To pull all fingerprinted songs from the database. + + :return: a list of fingerprinted audios from the database. + """ + return self.db.get_songs() + + def delete_songs_by_id(self, song_ids: List[int]) -> None: + """ + Deletes all audios given their ids. + + :param song_ids: song ids to delete from the database. + """ + self.db.delete_songs_by_id(song_ids) + + def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None: + """ + Given a directory and a set of extensions it fingerprints all files that match each extension specified. + + :param path: path to the directory. + :param extensions: list of file extensions to consider. + :param nprocesses: amount of processes to fingerprint the files within the directory. + """ # Try to use the maximum amount of processes if not given. try: nprocesses = nprocesses or multiprocessing.cpu_count() @@ -55,54 +83,58 @@ class Dejavu(object): filenames_to_fingerprint = [] for filename, _ in decoder.find_files(path, extensions): - # don't refingerprint already fingerprinted files if decoder.unique_hash(filename) in self.songhashes_set: - print "%s already fingerprinted, continuing..." 
% filename + print(f"{filename} already fingerprinted, continuing...") continue filenames_to_fingerprint.append(filename) # Prepare _fingerprint_worker input - worker_input = zip(filenames_to_fingerprint, - [self.limit] * len(filenames_to_fingerprint)) + worker_input = list(zip(filenames_to_fingerprint, [self.limit] * len(filenames_to_fingerprint))) # Send off our tasks - iterator = pool.imap_unordered(_fingerprint_worker, - worker_input) + iterator = pool.imap_unordered(Dejavu._fingerprint_worker, worker_input) # Loop till we have all of them while True: try: - song_name, hashes, file_hash = iterator.next() + song_name, hashes, file_hash = next(iterator) except multiprocessing.TimeoutError: continue except StopIteration: break - except: + except Exception: print("Failed fingerprinting") # Print traceback because we can't reraise it here traceback.print_exc(file=sys.stdout) else: - sid = self.db.insert_song(song_name, file_hash) + sid = self.db.insert_song(song_name, file_hash, len(hashes)) self.db.insert_hashes(sid, hashes) self.db.set_song_fingerprinted(sid) - self.get_fingerprinted_songs() + self.__load_fingerprinted_audio_hashes() pool.close() pool.join() - def fingerprint_file(self, filepath, song_name=None): - songname = decoder.path_to_songname(filepath) - song_hash = decoder.unique_hash(filepath) - song_name = song_name or songname + def fingerprint_file(self, file_path: str, song_name: str = None) -> None: + """ + Given a path to a file the method generates hashes for it and stores them in the database + for later be queried. + + :param file_path: path to the file. + :param song_name: song name associated to the audio file. + """ + song_name_from_path = decoder.get_audio_name_from_path(file_path) + song_hash = decoder.unique_hash(file_path) + song_name = song_name or song_name_from_path # don't refingerprint already fingerprinted files if song_hash in self.songhashes_set: - print "%s already fingerprinted, continuing..." 
% song_name + print(f"{song_name} already fingerprinted, continuing...") else: - song_name, hashes, file_hash = _fingerprint_worker( - filepath, + song_name, hashes, file_hash = Dejavu._fingerprint_worker( + file_path, self.limit, song_name=song_name ) @@ -110,93 +142,118 @@ class Dejavu(object): self.db.insert_hashes(sid, hashes) self.db.set_song_fingerprinted(sid) - self.get_fingerprinted_songs() + self.__load_fingerprinted_audio_hashes() - def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS): - hashes = fingerprint.fingerprint(samples, Fs=Fs) - return self.db.return_matches(hashes) + def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]: + f""" + Generate the fingerprints for the given sample data (channel). - def align_matches(self, matches): + :param samples: list of ints which represents the channel info of the given audio file. + :param Fs: sampling rate which defaults to {DEFAULT_FS}. + :return: a list of tuples for hash and its corresponding offset, together with the generation time. """ - Finds hash matches that align in time with other matches and finds - consensus about which hashes are "true" signal from the audio. + t = time() + hashes = fingerprint(samples, Fs=Fs) + fingerprint_time = time() - t + return hashes, fingerprint_time - Returns a dictionary with match information. + def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]: """ - # align by diffs - diff_counter = {} - largest = 0 - largest_count = 0 - song_id = -1 - for tup in matches: - sid, diff = tup - if diff not in diff_counter: - diff_counter[diff] = {} - if sid not in diff_counter[diff]: - diff_counter[diff][sid] = 0 - diff_counter[diff][sid] += 1 + Finds the corresponding matches on the fingerprinted audios for the given hashes. 
- if diff_counter[diff][sid] > largest_count: - largest = diff - largest_count = diff_counter[diff][sid] - song_id = sid + :param hashes: list of tuples for hashes and their corresponding offsets + :return: a tuple containing the matches found against the db, a dictionary which counts the different + hashes matched for each song (with the song id as key), and the time that the query took. - # extract idenfication - song = self.db.get_song_by_id(song_id) - if song: - # TODO: Clarify what `get_song_by_id` should return. - songname = song.get(Dejavu.SONG_NAME, None) - else: - return None + """ + t = time() + matches, dedup_hashes = self.db.return_matches(hashes) + query_time = time() - t - # return match info - nseconds = round(float(largest) / fingerprint.DEFAULT_FS * - fingerprint.DEFAULT_WINDOW_SIZE * - fingerprint.DEFAULT_OVERLAP_RATIO, 5) - song = { - Dejavu.SONG_ID : song_id, - Dejavu.SONG_NAME : songname.encode("utf8"), - Dejavu.CONFIDENCE : largest_count, - Dejavu.OFFSET : int(largest), - Dejavu.OFFSET_SECS : nseconds, - Database.FIELD_FILE_SHA1 : song.get(Database.FIELD_FILE_SHA1, None).encode("utf8"),} - return song + return matches, dedup_hashes, query_time - def recognize(self, recognizer, *options, **kwoptions): + def align_matches(self, matches: List[Tuple[int, int]], dedup_hashes: Dict[str, int], queried_hashes: int, + topn: int = TOPN) -> List[Dict[str, any]]: + """ + Finds hash matches that align in time with other matches and finds + consensus about which hashes are "true" signal from the audio. + + :param matches: matches from the database + :param dedup_hashes: dictionary containing the hashes matched without duplicates for each song + (key is the song id). + :param queried_hashes: amount of hashes sent for matching against the db + :param topn: number of results being returned back. + :return: a list of dictionaries (based on topn) with match information. + """ + # count offset occurrences per song and keep only the maximum ones. 
+ sorted_matches = sorted(matches, key=lambda m: (m[0], m[1])) + counts = [(*key, len(list(group))) for key, group in groupby(sorted_matches, key=lambda m: (m[0], m[1]))] + songs_matches = sorted( + [max(list(group), key=lambda g: g[2]) for key, group in groupby(counts, key=lambda count: count[0])], + key=lambda count: count[2], reverse=True + ) + + songs_result = [] + for song_id, offset, _ in songs_matches[0:topn]: # consider topn elements in the result + song = self.db.get_song_by_id(song_id) + + song_name = song.get(SONG_NAME, None) + song_hashes = song.get(FIELD_TOTAL_HASHES, None) + nseconds = round(float(offset) / DEFAULT_FS * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO, 5) + hashes_matched = dedup_hashes[song_id] + + song = { + SONG_ID: song_id, + SONG_NAME: song_name.encode("utf8"), + INPUT_HASHES: queried_hashes, + FINGERPRINTED_HASHES: song_hashes, + HASHES_MATCHED: hashes_matched, + # Percentage regarding hashes matched vs hashes from the input. + INPUT_CONFIDENCE: round(hashes_matched / queried_hashes, 2), + # Percentage regarding hashes matched vs hashes fingerprinted in the db. + FINGERPRINTED_CONFIDENCE: round(hashes_matched / song_hashes, 2), + OFFSET: offset, + OFFSET_SECS: nseconds, + FIELD_FILE_SHA1: song.get(FIELD_FILE_SHA1, None).encode("utf8") + } + + songs_result.append(song) + + return songs_result + + def recognize(self, recognizer, *options, **kwoptions) -> Dict[str, any]: r = recognizer(self) return r.recognize(*options, **kwoptions) + @staticmethod + def _fingerprint_worker(arguments): + # Pool.imap sends arguments as tuples so we have to unpack + # them ourself. + try: + file_name, limit = arguments + except ValueError: + pass -def _fingerprint_worker(filename, limit=None, song_name=None): - # Pool.imap sends arguments as tuples so we have to unpack - # them ourself. 
- try: - filename, limit = filename - except ValueError: - pass + song_name, extension = os.path.splitext(os.path.basename(file_name)) - songname, extension = os.path.splitext(os.path.basename(filename)) - song_name = song_name or songname - channels, Fs, file_hash = decoder.read(filename, limit) - result = set() - channel_amount = len(channels) + fingerprints, file_hash = Dejavu.get_file_fingerprints(file_name, limit, print_output=True) - for channeln, channel in enumerate(channels): - # TODO: Remove prints or change them into optional logging. - print("Fingerprinting channel %d/%d for %s" % (channeln + 1, - channel_amount, - filename)) - hashes = fingerprint.fingerprint(channel, Fs=Fs) - print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount, - filename)) - result |= set(hashes) + return song_name, fingerprints, file_hash - return song_name, result, file_hash + @staticmethod + def get_file_fingerprints(file_name: str, limit: int, print_output: bool = False): + channels, fs, file_hash = decoder.read(file_name, limit) + fingerprints = set() + channel_amount = len(channels) + for channeln, channel in enumerate(channels, start=1): + if print_output: + print(f"Fingerprinting channel {channeln}/{channel_amount} for {file_name}") + hashes = fingerprint(channel, Fs=fs) -def chunkify(lst, n): - """ - Splits a list into roughly n equal parts. 
- http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts - """ - return [lst[i::n] for i in xrange(n)] + if print_output: + print(f"Finished channel {channeln}/{channel_amount} for {file_name}") + + fingerprints |= set(hashes) + + return fingerprints, file_hash diff --git a/dejavu/base_classes/__init__.py b/dejavu/base_classes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/base_classes/base_database.py b/dejavu/base_classes/base_database.py new file mode 100755 index 0000000..839a72a --- /dev/null +++ b/dejavu/base_classes/base_database.py @@ -0,0 +1,195 @@ +import abc +import importlib +from typing import Dict, List, Tuple + +from dejavu.config.settings import DATABASES + + +class BaseDatabase(object, metaclass=abc.ABCMeta): + # Name of your Database subclass, this is used in configuration + # to refer to your class + type = None + + def __init__(self): + super().__init__() + + def before_fork(self) -> None: + """ + Called before the database instance is given to the new process + """ + pass + + def after_fork(self) -> None: + """ + Called after the database instance has been given to the new process + + This will be called in the new process. + """ + pass + + def setup(self) -> None: + """ + Called on creation or shortly afterwards. + """ + pass + + @abc.abstractmethod + def empty(self) -> None: + """ + Called when the database should be cleared of all data. + """ + pass + + @abc.abstractmethod + def delete_unfingerprinted_songs(self) -> None: + """ + Called to remove any song entries that do not have any fingerprints + associated with them. + """ + pass + + @abc.abstractmethod + def get_num_songs(self) -> int: + """ + Returns the song's count stored. + + :return: the amount of songs in the database. + """ + pass + + @abc.abstractmethod + def get_num_fingerprints(self) -> int: + """ + Returns the fingerprints' count stored. + + :return: the number of fingerprints in the database. 
+ """ + pass + + @abc.abstractmethod + def set_song_fingerprinted(self, song_id: int): + """ + Sets a specific song as having all fingerprints in the database. + + :param song_id: song identifier. + """ + pass + + @abc.abstractmethod + def get_songs(self) -> List[Dict[str, str]]: + """ + Returns all fully fingerprinted songs in the database + + :return: a dictionary with the songs info. + """ + pass + + @abc.abstractmethod + def get_song_by_id(self, song_id: int) -> Dict[str, str]: + """ + Brings the song info from the database. + + :param song_id: song identifier. + :return: a song by its identifier. Result must be a Dictionary. + """ + pass + + @abc.abstractmethod + def insert(self, fingerprint: str, song_id: int, offset: int): + """ + Inserts a single fingerprint into the database. + + :param fingerprint: Part of a sha1 hash, in hexadecimal format + :param song_id: Song identifier this fingerprint is off + :param offset: The offset this fingerprint is from. + """ + pass + + @abc.abstractmethod + def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int: + """ + Inserts a song name into the database, returns the new + identifier of the song. + + :param song_name: The name of the song. + :param file_hash: Hash from the fingerprinted file. + :param total_hashes: amount of hashes to be inserted on fingerprint table. + :return: the inserted id. + """ + pass + + @abc.abstractmethod + def query(self, fingerprint: str = None) -> List[Tuple]: + """ + Returns all matching fingerprint entries associated with + the given hash as parameter, if None is passed it returns all entries. + + :param fingerprint: part of a sha1 hash, in hexadecimal format + :return: a list of fingerprint records stored in the db. + """ + pass + + @abc.abstractmethod + def get_iterable_kv_pairs(self) -> List[Tuple]: + """ + Returns all fingerprints in the database. + + :return: a list containing all fingerprints stored in the db. 
+ """ + pass + + @abc.abstractmethod + def insert_hashes(self, song_id: int, hashes: List[Tuple[str, int]], batch_size: int = 1000) -> None: + """ + Insert a multitude of fingerprints. + + :param song_id: Song identifier the fingerprints belong to + :param hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + :param batch_size: insert batches. + """ + + @abc.abstractmethod + def return_matches(self, hashes: List[Tuple[str, int]], batch_size: int = 1000) \ + -> Tuple[List[Tuple[int, int]], Dict[int, int]]: + """ + Searches the database for pairs of (hash, offset) values. + + :param hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + :param batch_size: number of query's batches. + :return: a list of (sid, offset_difference) tuples and a + dictionary with the amount of hashes matched (not considering + duplicated hashes) in each song. + - song id: Song identifier + - offset_difference: (database_offset - sampled_offset) + """ + pass + + @abc.abstractmethod + def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None: + """ + Given a list of song ids it deletes all songs specified and their corresponding fingerprints. + + :param song_ids: song ids to be deleted from the database. + :param batch_size: number of query's batches. + """ + pass + + +def get_database(database_type: str = "mysql") -> BaseDatabase: + """ + Given a database type it returns a database instance for that type. + + :param database_type: type of the database. + :return: an instance of BaseDatabase depending on given database_type. 
+ """ + try: + path, db_class_name = DATABASES[database_type] + db_module = importlib.import_module(path) + db_class = getattr(db_module, db_class_name) + return db_class + except (ImportError, KeyError): + raise TypeError("Unsupported database type supplied.") diff --git a/dejavu/base_classes/base_recognizer.py b/dejavu/base_classes/base_recognizer.py new file mode 100644 index 0000000..c0f4749 --- /dev/null +++ b/dejavu/base_classes/base_recognizer.py @@ -0,0 +1,33 @@ +import abc +from time import time +from typing import Dict, List, Tuple + +import numpy as np + +from dejavu.config.settings import DEFAULT_FS + + +class BaseRecognizer(object, metaclass=abc.ABCMeta): + def __init__(self, dejavu): + self.dejavu = dejavu + self.Fs = DEFAULT_FS + + def _recognize(self, *data) -> Tuple[List[Dict[str, any]], int, int, int]: + fingerprint_times = [] + hashes = set() # to remove possible duplicated fingerprints we built a set. + for channel in data: + fingerprints, fingerprint_time = self.dejavu.generate_fingerprints(channel, Fs=self.Fs) + fingerprint_times.append(fingerprint_time) + hashes |= set(fingerprints) + + matches, dedup_hashes, query_time = self.dejavu.find_matches(hashes) + + t = time() + final_results = self.dejavu.align_matches(matches, dedup_hashes, len(hashes)) + align_time = time() - t + + return final_results, np.sum(fingerprint_times), query_time, align_time + + @abc.abstractmethod + def recognize(self) -> Dict[str, any]: + pass # base class does nothing diff --git a/dejavu/base_classes/common_database.py b/dejavu/base_classes/common_database.py new file mode 100644 index 0000000..e884285 --- /dev/null +++ b/dejavu/base_classes/common_database.py @@ -0,0 +1,232 @@ +import abc +from typing import Dict, List, Tuple + +from dejavu.base_classes.base_database import BaseDatabase + + +class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): + # Since several methods across different databases are actually just the same + # I've built this class with the 
idea to reuse that logic instead of copy pasting + # over and over the same code. + + def __init__(self): + super().__init__() + + def before_fork(self) -> None: + """ + Called before the database instance is given to the new process + """ + pass + + def after_fork(self) -> None: + """ + Called after the database instance has been given to the new process + + This will be called in the new process. + """ + pass + + def setup(self) -> None: + """ + Called on creation or shortly afterwards. + """ + with self.cursor() as cur: + cur.execute(self.CREATE_SONGS_TABLE) + cur.execute(self.CREATE_FINGERPRINTS_TABLE) + cur.execute(self.DELETE_UNFINGERPRINTED) + + def empty(self) -> None: + """ + Called when the database should be cleared of all data. + """ + with self.cursor() as cur: + cur.execute(self.DROP_FINGERPRINTS) + cur.execute(self.DROP_SONGS) + + self.setup() + + def delete_unfingerprinted_songs(self) -> None: + """ + Called to remove any song entries that do not have any fingerprints + associated with them. + """ + with self.cursor() as cur: + cur.execute(self.DELETE_UNFINGERPRINTED) + + def get_num_songs(self) -> int: + """ + Returns the song's count stored. + + :return: the amount of songs in the database. + """ + with self.cursor(buffered=True) as cur: + cur.execute(self.SELECT_UNIQUE_SONG_IDS) + count = cur.fetchone()[0] if cur.rowcount != 0 else 0 + + return count + + def get_num_fingerprints(self) -> int: + """ + Returns the fingerprints' count stored. + + :return: the number of fingerprints in the database. + """ + with self.cursor(buffered=True) as cur: + cur.execute(self.SELECT_NUM_FINGERPRINTS) + count = cur.fetchone()[0] if cur.rowcount != 0 else 0 + + return count + + def set_song_fingerprinted(self, song_id): + """ + Sets a specific song as having all fingerprints in the database. + + :param song_id: song identifier. 
+ """ + with self.cursor() as cur: + cur.execute(self.UPDATE_SONG_FINGERPRINTED, (song_id,)) + + def get_songs(self) -> List[Dict[str, str]]: + """ + Returns all fully fingerprinted songs in the database + + :return: a dictionary with the songs info. + """ + with self.cursor(dictionary=True) as cur: + cur.execute(self.SELECT_SONGS) + return list(cur) + + def get_song_by_id(self, song_id: int) -> Dict[str, str]: + """ + Brings the song info from the database. + + :param song_id: song identifier. + :return: a song by its identifier. Result must be a Dictionary. + """ + with self.cursor(dictionary=True) as cur: + cur.execute(self.SELECT_SONG, (song_id,)) + return cur.fetchone() + + def insert(self, fingerprint: str, song_id: int, offset: int): + """ + Inserts a single fingerprint into the database. + + :param fingerprint: Part of a sha1 hash, in hexadecimal format + :param song_id: Song identifier this fingerprint is off + :param offset: The offset this fingerprint is from. + """ + with self.cursor() as cur: + cur.execute(self.INSERT_FINGERPRINT, (fingerprint, song_id, offset)) + + @abc.abstractmethod + def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int: + """ + Inserts a song name into the database, returns the new + identifier of the song. + + :param song_name: The name of the song. + :param file_hash: Hash from the fingerprinted file. + :param total_hashes: amount of hashes to be inserted on fingerprint table. + :return: the inserted id. + """ + pass + + def query(self, fingerprint: str = None) -> List[Tuple]: + """ + Returns all matching fingerprint entries associated with + the given hash as parameter, if None is passed it returns all entries. + + :param fingerprint: part of a sha1 hash, in hexadecimal format + :return: a list of fingerprint records stored in the db. 
+ """ + with self.cursor() as cur: + if fingerprint: + cur.execute(self.SELECT, (fingerprint,)) + else: # select all if no key + cur.execute(self.SELECT_ALL) + return list(cur) + + def get_iterable_kv_pairs(self) -> List[Tuple]: + """ + Returns all fingerprints in the database. + + :return: a list containing all fingerprints stored in the db. + """ + return self.query(None) + + def insert_hashes(self, song_id: int, hashes: List[Tuple[str, int]], batch_size: int = 1000) -> None: + """ + Insert a multitude of fingerprints. + + :param song_id: Song identifier the fingerprints belong to + :param hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + :param batch_size: insert batches. + """ + values = [(song_id, hsh, int(offset)) for hsh, offset in hashes] + + with self.cursor() as cur: + for index in range(0, len(hashes), batch_size): + cur.executemany(self.INSERT_FINGERPRINT, values[index: index + batch_size]) + + def return_matches(self, hashes: List[Tuple[str, int]], + batch_size: int = 1000) -> Tuple[List[Tuple[int, int]], Dict[int, int]]: + """ + Searches the database for pairs of (hash, offset) values. + + :param hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + :param batch_size: number of query's batches. + :return: a list of (sid, offset_difference) tuples and a + dictionary with the amount of hashes matched (not considering + duplicated hashes) in each song. 
+ - song id: Song identifier + - offset_difference: (database_offset - sampled_offset) + """ + # Create a dictionary of hash => offset pairs for later lookups + mapper = {} + for hsh, offset in hashes: + if hsh.upper() in mapper.keys(): + mapper[hsh.upper()].append(offset) + else: + mapper[hsh.upper()] = [offset] + + values = list(mapper.keys()) + + # in order to count each hash only once per db offset we use the dic below + dedup_hashes = {} + + results = [] + with self.cursor() as cur: + for index in range(0, len(values), batch_size): + # Create our IN part of the query + query = self.SELECT_MULTIPLE % ', '.join([self.IN_MATCH] * len(values[index: index + batch_size])) + + cur.execute(query, values[index: index + batch_size]) + + for hsh, sid, offset in cur: + if sid not in dedup_hashes.keys(): + dedup_hashes[sid] = 1 + else: + dedup_hashes[sid] += 1 + # we now evaluate all offset for each hash matched + for song_sampled_offset in mapper[hsh]: + results.append((sid, offset - song_sampled_offset)) + + return results, dedup_hashes + + def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None: + """ + Given a list of song ids it deletes all songs specified and their corresponding fingerprints. + + :param song_ids: song ids to be deleted from the database. + :param batch_size: number of query's batches. 
+ """ + with self.cursor() as cur: + for index in range(0, len(song_ids), batch_size): + # Create our IN part of the query + query = self.DELETE_SONGS % ', '.join(['%s'] * len(song_ids[index: index + batch_size])) + + cur.execute(query, song_ids[index: index + batch_size]) diff --git a/dejavu/config/__init__.py b/dejavu/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/config/settings.py b/dejavu/config/settings.py new file mode 100644 index 0000000..0e20569 --- /dev/null +++ b/dejavu/config/settings.py @@ -0,0 +1,102 @@ +# Dejavu + +# DEJAVU JSON RESPONSE +SONG_ID = "song_id" +SONG_NAME = 'song_name' +RESULTS = 'results' + +HASHES_MATCHED = 'hashes_matched_in_input' + +# Hashes fingerprinted in the db. +FINGERPRINTED_HASHES = 'fingerprinted_hashes_in_db' +# Percentage regarding hashes matched vs hashes fingerprinted in the db. +FINGERPRINTED_CONFIDENCE = 'fingerprinted_confidence' + +# Hashes generated from the input. +INPUT_HASHES = 'input_total_hashes' +# Percentage regarding hashes matched vs hashes from the input. +INPUT_CONFIDENCE = 'input_confidence' + +TOTAL_TIME = 'total_time' +FINGERPRINT_TIME = 'fingerprint_time' +QUERY_TIME = 'query_time' +ALIGN_TIME = 'align_time' +OFFSET = 'offset' +OFFSET_SECS = 'offset_seconds' + +# DATABASE CLASS INSTANCES: +DATABASES = { + 'mysql': ("dejavu.database_handler.mysql_database", "MySQLDatabase"), + 'postgres': ("dejavu.database_handler.postgres_database", "PostgreSQLDatabase") +} + +# TABLE SONGS +SONGS_TABLENAME = "songs" + +# SONGS FIELDS +FIELD_SONG_ID = 'song_id' +FIELD_SONGNAME = 'song_name' +FIELD_FINGERPRINTED = "fingerprinted" +FIELD_FILE_SHA1 = 'file_sha1' +FIELD_TOTAL_HASHES = 'total_hashes' + +# TABLE FINGERPRINTS +FINGERPRINTS_TABLENAME = "fingerprints" + +# FINGERPRINTS FIELDS +FIELD_HASH = 'hash' +FIELD_OFFSET = 'offset' + +# FINGERPRINTS CONFIG: +# This is used as connectivity parameter for scipy.generate_binary_structure function. 
This parameter +# changes the morphology mask when looking for maximum peaks on the spectrogram matrix. +# Possible values are: [1, 2] +# Where 1 sets a diamond morphology which implies that diagonal elements are not considered as neighbors (this +# is the value used in the original dejavu code). +# And 2 sets a square mask, i.e. all elements are considered neighbors. +CONNECTIVITY_MASK = 2 + +# Sampling rate, related to the Nyquist conditions, which affects +# the range of frequencies we can detect. +DEFAULT_FS = 44100 + +# Size of the FFT window, affects frequency granularity +DEFAULT_WINDOW_SIZE = 4096 + +# Ratio by which each sequential window overlaps the last and the +# next window. Higher overlap will allow a higher granularity of offset +# matching, but potentially more fingerprints. +DEFAULT_OVERLAP_RATIO = 0.5 + +# Degree to which a fingerprint can be paired with its neighbors. Higher values will +# cause more fingerprints, but potentially better accuracy. +DEFAULT_FAN_VALUE = 5 # 15 was the original value. + +# Minimum amplitude in spectrogram in order to be considered a peak. +# This can be raised to reduce the number of fingerprints, but can negatively +# affect accuracy. +DEFAULT_AMP_MIN = 10 + +# Number of cells around an amplitude peak in the spectrogram in order +# for Dejavu to consider it a spectral peak. Higher values mean fewer +# fingerprints and faster matching, but can potentially affect accuracy. +PEAK_NEIGHBORHOOD_SIZE = 10 # 20 was the original value. + +# Thresholds on how close or far fingerprints can be in time in order +# to be paired as a fingerprint. If your max is too low, higher values of +# DEFAULT_FAN_VALUE may not perform as expected. +MIN_HASH_TIME_DELTA = 0 +MAX_HASH_TIME_DELTA = 200 + +# If True, will sort peaks temporally for fingerprinting; +# not sorting will cut down the number of fingerprints, but potentially +# affect performance. 
+PEAK_SORT = True + +# Number of bits to grab from the front of the SHA1 hash in the +# fingerprint calculation. The more you grab, the more memory storage, +# with potentially lesser collisions of matches. +FINGERPRINT_REDUCTION = 20 + +# Number of results being returned for file recognition +TOPN = 2 diff --git a/dejavu/database.py b/dejavu/database.py deleted file mode 100755 index e5732ff..0000000 --- a/dejavu/database.py +++ /dev/null @@ -1,176 +0,0 @@ -from __future__ import absolute_import -import abc - - -class Database(object): - __metaclass__ = abc.ABCMeta - - FIELD_FILE_SHA1 = 'file_sha1' - FIELD_SONG_ID = 'song_id' - FIELD_SONGNAME = 'song_name' - FIELD_OFFSET = 'offset' - FIELD_HASH = 'hash' - - # Name of your Database subclass, this is used in configuration - # to refer to your class - type = None - - def __init__(self): - super(Database, self).__init__() - - def before_fork(self): - """ - Called before the database instance is given to the new process - """ - pass - - def after_fork(self): - """ - Called after the database instance has been given to the new process - - This will be called in the new process. - """ - pass - - def setup(self): - """ - Called on creation or shortly afterwards. - """ - pass - - @abc.abstractmethod - def empty(self): - """ - Called when the database should be cleared of all data. - """ - pass - - @abc.abstractmethod - def delete_unfingerprinted_songs(self): - """ - Called to remove any song entries that do not have any fingerprints - associated with them. - """ - pass - - @abc.abstractmethod - def get_num_songs(self): - """ - Returns the amount of songs in the database. - """ - pass - - @abc.abstractmethod - def get_num_fingerprints(self): - """ - Returns the number of fingerprints in the database. - """ - pass - - @abc.abstractmethod - def set_song_fingerprinted(self, sid): - """ - Sets a specific song as having all fingerprints in the database. 
- - sid: Song identifier - """ - pass - - @abc.abstractmethod - def get_songs(self): - """ - Returns all fully fingerprinted songs in the database. - """ - pass - - @abc.abstractmethod - def get_song_by_id(self, sid): - """ - Return a song by its identifier - - sid: Song identifier - """ - pass - - @abc.abstractmethod - def insert(self, hash, sid, offset): - """ - Inserts a single fingerprint into the database. - - hash: Part of a sha1 hash, in hexadecimal format - sid: Song identifier this fingerprint is off - offset: The offset this hash is from - """ - pass - - @abc.abstractmethod - def insert_song(self, song_name): - """ - Inserts a song name into the database, returns the new - identifier of the song. - - song_name: The name of the song. - """ - pass - - @abc.abstractmethod - def query(self, hash): - """ - Returns all matching fingerprint entries associated with - the given hash as parameter. - - hash: Part of a sha1 hash, in hexadecimal format - """ - pass - - @abc.abstractmethod - def get_iterable_kv_pairs(self): - """ - Returns all fingerprints in the database. - """ - pass - - @abc.abstractmethod - def insert_hashes(self, sid, hashes): - """ - Insert a multitude of fingerprints. - - sid: Song identifier the fingerprints belong to - hashes: A sequence of tuples in the format (hash, offset) - - hash: Part of a sha1 hash, in hexadecimal format - - offset: Offset this hash was created from/at. - """ - pass - - @abc.abstractmethod - def return_matches(self, hashes): - """ - Searches the database for pairs of (hash, offset) values. - - hashes: A sequence of tuples in the format (hash, offset) - - hash: Part of a sha1 hash, in hexadecimal format - - offset: Offset this hash was created from/at. - - Returns a sequence of (sid, offset_difference) tuples. 
- - sid: Song identifier - offset_difference: (offset - database_offset) - """ - pass - - -def get_database(database_type=None): - # Default to using the mysql database - database_type = database_type or "mysql" - # Lower all the input. - database_type = database_type.lower() - - for db_cls in Database.__subclasses__(): - if db_cls.type == database_type: - return db_cls - - raise TypeError("Unsupported database type supplied.") - - -# Import our default database handler -import dejavu.database_sql diff --git a/dejavu/database_handler/__init__.py b/dejavu/database_handler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/database_handler/mysql_database.py b/dejavu/database_handler/mysql_database.py new file mode 100755 index 0000000..1a8c506 --- /dev/null +++ b/dejavu/database_handler/mysql_database.py @@ -0,0 +1,203 @@ +import queue + +import mysql.connector +from mysql.connector.errors import DatabaseError + +from dejavu.base_classes.common_database import CommonDatabase +from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, + FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, + FIELD_SONGNAME, FIELD_TOTAL_HASHES, + FINGERPRINTS_TABLENAME, SONGS_TABLENAME) + + +class MySQLDatabase(CommonDatabase): + type = "mysql" + + # CREATES + CREATE_SONGS_TABLE = f""" + CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` ( + `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT + , `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL + , `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0 + , `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL + , `{FIELD_TOTAL_HASHES}` INT NOT NULL DEFAULT 0 + , `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + , `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + , CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`) + , CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`) + ) ENGINE=INNODB; + """ + + CREATE_FINGERPRINTS_TABLE = f""" 
+ CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` ( + `{FIELD_HASH}` BINARY(10) NOT NULL + , `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL + , `{FIELD_OFFSET}` INT UNSIGNED NOT NULL + , `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + , `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + , INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`) + , CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}` + UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`) + , CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`) + REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE + ) ENGINE=INNODB; + """ + + # INSERTS (IGNORES DUPLICATES) + INSERT_FINGERPRINT = f""" + INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` ( + `{FIELD_SONG_ID}` + , `{FIELD_HASH}` + , `{FIELD_OFFSET}`) + VALUES (%s, UNHEX(%s), %s); + """ + + INSERT_SONG = f""" + INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`,`{FIELD_TOTAL_HASHES}`) + VALUES (%s, UNHEX(%s), %s); + """ + + # SELECTS + SELECT = f""" + SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` + FROM `{FINGERPRINTS_TABLENAME}` + WHERE `{FIELD_HASH}` = UNHEX(%s); + """ + + SELECT_MULTIPLE = f""" + SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` + FROM `{FINGERPRINTS_TABLENAME}` + WHERE `{FIELD_HASH}` IN (%s); + """ + + SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;" + + SELECT_SONG = f""" + SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`, `{FIELD_TOTAL_HASHES}` + FROM `{SONGS_TABLENAME}` + WHERE `{FIELD_SONG_ID}` = %s; + """ + + SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;" + + SELECT_UNIQUE_SONG_IDS = f""" + SELECT COUNT(`{FIELD_SONG_ID}`) AS n + FROM `{SONGS_TABLENAME}` + WHERE `{FIELD_FINGERPRINTED}` = 1; + """ + + SELECT_SONGS = f""" + SELECT + 
`{FIELD_SONG_ID}` + , `{FIELD_SONGNAME}` + , HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}` + , `{FIELD_TOTAL_HASHES}` + , `date_created` + FROM `{SONGS_TABLENAME}` + WHERE `{FIELD_FINGERPRINTED}` = 1; + """ + + # DROPS + DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;" + DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;" + + # UPDATE + UPDATE_SONG_FINGERPRINTED = f""" + UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s; + """ + + # DELETES + DELETE_UNFINGERPRINTED = f""" + DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0; + """ + + DELETE_SONGS = f""" + DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_SONG_ID}` IN (%s); + """ + + # IN + IN_MATCH = f"UNHEX(%s)" + + def __init__(self, **options): + super().__init__() + self.cursor = cursor_factory(**options) + self._options = options + + def after_fork(self) -> None: + # Clear the cursor cache, we don't want any stale connections from + # the previous process. + Cursor.clear_cache() + + def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int: + """ + Inserts a song name into the database, returns the new + identifier of the song. + + :param song_name: The name of the song. + :param file_hash: Hash from the fingerprinted file. + :param total_hashes: amount of hashes to be inserted on fingerprint table. + :return: the inserted id. + """ + with self.cursor() as cur: + cur.execute(self.INSERT_SONG, (song_name, file_hash, total_hashes)) + return cur.lastrowid + + def __getstate__(self): + return self._options, + + def __setstate__(self, state): + self._options, = state + self.cursor = cursor_factory(**self._options) + + +def cursor_factory(**factory_options): + def cursor(**options): + options.update(factory_options) + return Cursor(**options) + return cursor + + +class Cursor(object): + """ + Establishes a connection to the database and returns an open cursor. 
+ # Use as context manager + with Cursor() as cur: + cur.execute(query) + ... + """ + def __init__(self, dictionary=False, **options): + super().__init__() + + self._cache = queue.Queue(maxsize=5) + + try: + conn = self._cache.get_nowait() + # Ping the connection before using it from the cache. + conn.ping(True) + except queue.Empty: + conn = mysql.connector.connect(**options) + + self.conn = conn + self.dictionary = dictionary + + @classmethod + def clear_cache(cls): + cls._cache = queue.Queue(maxsize=5) + + def __enter__(self): + self.cursor = self.conn.cursor(dictionary=self.dictionary) + return self.cursor + + def __exit__(self, extype, exvalue, traceback): + # if we had a MySQL related error we try to rollback the cursor. + if extype is DatabaseError: + self.cursor.rollback() + + self.cursor.close() + self.conn.commit() + + # Put it back on the queue + try: + self._cache.put_nowait(self.conn) + except queue.Full: + self.conn.close() diff --git a/dejavu/database_handler/postgres_database.py b/dejavu/database_handler/postgres_database.py new file mode 100755 index 0000000..4ac7131 --- /dev/null +++ b/dejavu/database_handler/postgres_database.py @@ -0,0 +1,219 @@ +import queue + +import psycopg2 +from psycopg2.extras import DictCursor + +from dejavu.base_classes.common_database import CommonDatabase +from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, + FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, + FIELD_SONGNAME, FIELD_TOTAL_HASHES, + FINGERPRINTS_TABLENAME, SONGS_TABLENAME) + + +class PostgreSQLDatabase(CommonDatabase): + type = "postgres" + + # CREATES + CREATE_SONGS_TABLE = f""" + CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" ( + "{FIELD_SONG_ID}" SERIAL + , "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL + , "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0 + , "{FIELD_FILE_SHA1}" BYTEA + , "{FIELD_TOTAL_HASHES}" INT NOT NULL DEFAULT 0 + , "date_created" TIMESTAMP NOT NULL DEFAULT now() + , "date_modified" TIMESTAMP NOT NULL DEFAULT now() + , 
CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}") + , CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}") + ); + """ + + CREATE_FINGERPRINTS_TABLE = f""" + CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" ( + "{FIELD_HASH}" BYTEA NOT NULL + , "{FIELD_SONG_ID}" INT NOT NULL + , "{FIELD_OFFSET}" INT NOT NULL + , "date_created" TIMESTAMP NOT NULL DEFAULT now() + , "date_modified" TIMESTAMP NOT NULL DEFAULT now() + , CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}") + , CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}") + REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" + USING hash ("{FIELD_HASH}"); + """ + + CREATE_FINGERPRINTS_TABLE_INDEX = f""" + CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" + USING hash ("{FIELD_HASH}"); + """ + + # INSERTS (IGNORES DUPLICATES) + INSERT_FINGERPRINT = f""" + INSERT INTO "{FINGERPRINTS_TABLENAME}" ( + "{FIELD_SONG_ID}" + , "{FIELD_HASH}" + , "{FIELD_OFFSET}") + VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING; + """ + + INSERT_SONG = f""" + INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}","{FIELD_TOTAL_HASHES}") + VALUES (%s, decode(%s, 'hex'), %s) + RETURNING "{FIELD_SONG_ID}"; + """ + + # SELECTS + SELECT = f""" + SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" + FROM "{FINGERPRINTS_TABLENAME}" + WHERE "{FIELD_HASH}" = decode(%s, 'hex'); + """ + + SELECT_MULTIPLE = f""" + SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}" + FROM "{FINGERPRINTS_TABLENAME}" + WHERE "{FIELD_HASH}" IN (%s); + """ + + SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";' + + SELECT_SONG = f""" + SELECT + "{FIELD_SONGNAME}" + , 
upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" + , "{FIELD_TOTAL_HASHES}" + FROM "{SONGS_TABLENAME}" + WHERE "{FIELD_SONG_ID}" = %s; + """ + + SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";' + + SELECT_UNIQUE_SONG_IDS = f""" + SELECT COUNT("{FIELD_SONG_ID}") AS n + FROM "{SONGS_TABLENAME}" + WHERE "{FIELD_FINGERPRINTED}" = 1; + """ + + SELECT_SONGS = f""" + SELECT + "{FIELD_SONG_ID}" + , "{FIELD_SONGNAME}" + , upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" + , "{FIELD_TOTAL_HASHES}" + , "date_created" + FROM "{SONGS_TABLENAME}" + WHERE "{FIELD_FINGERPRINTED}" = 1; + """ + + # DROPS + DROP_FINGERPRINTS = F'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";' + DROP_SONGS = F'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";' + + # UPDATE + UPDATE_SONG_FINGERPRINTED = f""" + UPDATE "{SONGS_TABLENAME}" SET + "{FIELD_FINGERPRINTED}" = 1 + , "date_modified" = now() + WHERE "{FIELD_SONG_ID}" = %s; + """ + + # DELETES + DELETE_UNFINGERPRINTED = f""" + DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0; + """ + + DELETE_SONGS = f""" + DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_SONG_ID}" IN (%s); + """ + + # IN + IN_MATCH = f"decode(%s, 'hex')" + + def __init__(self, **options): + super().__init__() + self.cursor = cursor_factory(**options) + self._options = options + + def after_fork(self) -> None: + # Clear the cursor cache, we don't want any stale connections from + # the previous process. + Cursor.clear_cache() + + def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int: + """ + Inserts a song name into the database, returns the new + identifier of the song. + + :param song_name: The name of the song. + :param file_hash: Hash from the fingerprinted file. + :param total_hashes: amount of hashes to be inserted on fingerprint table. + :return: the inserted id. 
+ """ + with self.cursor() as cur: + cur.execute(self.INSERT_SONG, (song_name, file_hash, total_hashes)) + return cur.fetchone()[0] + + def __getstate__(self): + return self._options, + + def __setstate__(self, state): + self._options, = state + self.cursor = cursor_factory(**self._options) + + +def cursor_factory(**factory_options): + def cursor(**options): + options.update(factory_options) + return Cursor(**options) + return cursor + + +class Cursor(object): + """ + Establishes a connection to the database and returns an open cursor. + # Use as context manager + with Cursor() as cur: + cur.execute(query) + ... + """ + def __init__(self, dictionary=False, **options): + super().__init__() + + self._cache = queue.Queue(maxsize=5) + + try: + conn = self._cache.get_nowait() + # Ping the connection before using it from the cache. + conn.ping(True) + except queue.Empty: + conn = psycopg2.connect(**options) + + self.conn = conn + self.dictionary = dictionary + + @classmethod + def clear_cache(cls): + cls._cache = queue.Queue(maxsize=5) + + def __enter__(self): + if self.dictionary: + self.cursor = self.conn.cursor(cursor_factory=DictCursor) + else: + self.cursor = self.conn.cursor() + return self.cursor + + def __exit__(self, extype, exvalue, traceback): + # if we had a PostgreSQL related error we try to rollback the cursor. 
+ if extype is psycopg2.DatabaseError: + self.cursor.rollback() + + self.cursor.close() + self.conn.commit() + + # Put it back on the queue + try: + self._cache.put_nowait(self.conn) + except queue.Full: + self.conn.close() diff --git a/dejavu/database_sql.py b/dejavu/database_sql.py deleted file mode 100755 index 0fe2e68..0000000 --- a/dejavu/database_sql.py +++ /dev/null @@ -1,373 +0,0 @@ -from __future__ import absolute_import -from itertools import izip_longest -import Queue - -import MySQLdb as mysql -from MySQLdb.cursors import DictCursor - -from dejavu.database import Database - - -class SQLDatabase(Database): - """ - Queries: - - 1) Find duplicates (shouldn't be any, though): - - select `hash`, `song_id`, `offset`, count(*) cnt - from fingerprints - group by `hash`, `song_id`, `offset` - having cnt > 1 - order by cnt asc; - - 2) Get number of hashes by song: - - select song_id, song_name, count(song_id) as num - from fingerprints - natural join songs - group by song_id - order by count(song_id) desc; - - 3) get hashes with highest number of collisions - - select - hash, - count(distinct song_id) as n - from fingerprints - group by `hash` - order by n DESC; - - => 26 different songs with same fingerprint (392 times): - - select songs.song_name, fingerprints.offset - from fingerprints natural join songs - where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73"; - """ - - type = "mysql" - - # tables - FINGERPRINTS_TABLENAME = "fingerprints" - SONGS_TABLENAME = "songs" - - # fields - FIELD_FINGERPRINTED = "fingerprinted" - - # creates - CREATE_FINGERPRINTS_TABLE = """ - CREATE TABLE IF NOT EXISTS `%s` ( - `%s` binary(10) not null, - `%s` mediumint unsigned not null, - `%s` int unsigned not null, - INDEX (%s), - UNIQUE KEY `unique_constraint` (%s, %s, %s), - FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE - ) ENGINE=INNODB;""" % ( - FINGERPRINTS_TABLENAME, Database.FIELD_HASH, - Database.FIELD_SONG_ID, Database.FIELD_OFFSET, 
Database.FIELD_HASH, - Database.FIELD_SONG_ID, Database.FIELD_OFFSET, Database.FIELD_HASH, - Database.FIELD_SONG_ID, SONGS_TABLENAME, Database.FIELD_SONG_ID - ) - - CREATE_SONGS_TABLE = """ - CREATE TABLE IF NOT EXISTS `%s` ( - `%s` mediumint unsigned not null auto_increment, - `%s` varchar(250) not null, - `%s` tinyint default 0, - `%s` binary(20) not null, - PRIMARY KEY (`%s`), - UNIQUE KEY `%s` (`%s`) - ) ENGINE=INNODB;""" % ( - SONGS_TABLENAME, Database.FIELD_SONG_ID, Database.FIELD_SONGNAME, FIELD_FINGERPRINTED, - Database.FIELD_FILE_SHA1, - Database.FIELD_SONG_ID, Database.FIELD_SONG_ID, Database.FIELD_SONG_ID, - ) - - # inserts (ignores duplicates) - INSERT_FINGERPRINT = """ - INSERT IGNORE INTO %s (%s, %s, %s) values - (UNHEX(%%s), %%s, %%s); - """ % (FINGERPRINTS_TABLENAME, Database.FIELD_HASH, Database.FIELD_SONG_ID, Database.FIELD_OFFSET) - - INSERT_SONG = "INSERT INTO %s (%s, %s) values (%%s, UNHEX(%%s));" % ( - SONGS_TABLENAME, Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1) - - # selects - SELECT = """ - SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s); - """ % (Database.FIELD_SONG_ID, Database.FIELD_OFFSET, FINGERPRINTS_TABLENAME, Database.FIELD_HASH) - - SELECT_MULTIPLE = """ - SELECT HEX(%s), %s, %s FROM %s WHERE %s IN (%%s); - """ % (Database.FIELD_HASH, Database.FIELD_SONG_ID, Database.FIELD_OFFSET, - FINGERPRINTS_TABLENAME, Database.FIELD_HASH) - - SELECT_ALL = """ - SELECT %s, %s FROM %s; - """ % (Database.FIELD_SONG_ID, Database.FIELD_OFFSET, FINGERPRINTS_TABLENAME) - - SELECT_SONG = """ - SELECT %s, HEX(%s) as %s FROM %s WHERE %s = %%s; - """ % (Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1, Database.FIELD_FILE_SHA1, SONGS_TABLENAME, Database.FIELD_SONG_ID) - - SELECT_NUM_FINGERPRINTS = """ - SELECT COUNT(*) as n FROM %s - """ % (FINGERPRINTS_TABLENAME) - - SELECT_UNIQUE_SONG_IDS = """ - SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1; - """ % (Database.FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED) - - SELECT_SONGS = """ - SELECT 
%s, %s, HEX(%s) as %s FROM %s WHERE %s = 1; - """ % (Database.FIELD_SONG_ID, Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1, Database.FIELD_FILE_SHA1, - SONGS_TABLENAME, FIELD_FINGERPRINTED) - - # drops - DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME - DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME - - # update - UPDATE_SONG_FINGERPRINTED = """ - UPDATE %s SET %s = 1 WHERE %s = %%s - """ % (SONGS_TABLENAME, FIELD_FINGERPRINTED, Database.FIELD_SONG_ID) - - # delete - DELETE_UNFINGERPRINTED = """ - DELETE FROM %s WHERE %s = 0; - """ % (SONGS_TABLENAME, FIELD_FINGERPRINTED) - - def __init__(self, **options): - super(SQLDatabase, self).__init__() - self.cursor = cursor_factory(**options) - self._options = options - - def after_fork(self): - # Clear the cursor cache, we don't want any stale connections from - # the previous process. - Cursor.clear_cache() - - def setup(self): - """ - Creates any non-existing tables required for dejavu to function. - - This also removes all songs that have been added but have no - fingerprints associated with them. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.CREATE_SONGS_TABLE) - cur.execute(self.CREATE_FINGERPRINTS_TABLE) - cur.execute(self.DELETE_UNFINGERPRINTED) - - def empty(self): - """ - Drops tables created by dejavu and then creates them again - by calling `SQLDatabase.setup`. - - .. warning: - This will result in a loss of data - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.DROP_FINGERPRINTS) - cur.execute(self.DROP_SONGS) - - self.setup() - - def delete_unfingerprinted_songs(self): - """ - Removes all songs that have no fingerprints associated with them. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.DELETE_UNFINGERPRINTED) - - def get_num_songs(self): - """ - Returns number of songs the database has fingerprinted. 
- """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.SELECT_UNIQUE_SONG_IDS) - - for count, in cur: - return count - return 0 - - def get_num_fingerprints(self): - """ - Returns number of fingerprints the database has fingerprinted. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.SELECT_NUM_FINGERPRINTS) - - for count, in cur: - return count - return 0 - - def set_song_fingerprinted(self, sid): - """ - Set the fingerprinted flag to TRUE (1) once a song has been completely - fingerprinted in the database. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,)) - - def get_songs(self): - """ - Return songs that have the fingerprinted flag set TRUE (1). - """ - with self.cursor(cursor_type=DictCursor, charset="utf8") as cur: - cur.execute(self.SELECT_SONGS) - for row in cur: - yield row - - def get_song_by_id(self, sid): - """ - Returns song by its ID. - """ - with self.cursor(cursor_type=DictCursor, charset="utf8") as cur: - cur.execute(self.SELECT_SONG, (sid,)) - return cur.fetchone() - - def insert(self, hash, sid, offset): - """ - Insert a (sha1, song_id, offset) row into database. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset)) - - def insert_song(self, songname, file_hash): - """ - Inserts song in the database and returns the ID of the inserted record. - """ - with self.cursor(charset="utf8") as cur: - cur.execute(self.INSERT_SONG, (songname, file_hash)) - return cur.lastrowid - - def query(self, hash): - """ - Return all tuples associated with hash. - - If hash is None, returns all entries in the - database (be careful with that one!). - """ - # select all if no key - query = self.SELECT_ALL if hash is None else self.SELECT - - with self.cursor(charset="utf8") as cur: - cur.execute(query) - for sid, offset in cur: - yield (sid, offset) - - def get_iterable_kv_pairs(self): - """ - Returns all tuples in database. 
- """ - return self.query(None) - - def insert_hashes(self, sid, hashes): - """ - Insert series of hash => song_id, offset - values into the database. - """ - values = [] - for hash, offset in hashes: - values.append((hash, sid, offset)) - - with self.cursor(charset="utf8") as cur: - for split_values in grouper(values, 1000): - cur.executemany(self.INSERT_FINGERPRINT, split_values) - - def return_matches(self, hashes): - """ - Return the (song_id, offset_diff) tuples associated with - a list of (sha1, sample_offset) values. - """ - # Create a dictionary of hash => offset pairs for later lookups - mapper = {} - for hash, offset in hashes: - mapper[hash.upper()] = offset - - # Get an iteratable of all the hashes we need - values = mapper.keys() - - with self.cursor(charset="utf8") as cur: - for split_values in grouper(values, 1000): - # Create our IN part of the query - query = self.SELECT_MULTIPLE - query = query % ', '.join(['UNHEX(%s)'] * len(split_values)) - - cur.execute(query, split_values) - - for hash, sid, offset in cur: - # (sid, db_offset - song_sampled_offset) - yield (sid, offset - mapper[hash]) - - def __getstate__(self): - return (self._options,) - - def __setstate__(self, state): - self._options, = state - self.cursor = cursor_factory(**self._options) - - -def grouper(iterable, n, fillvalue=None): - args = [iter(iterable)] * n - return (filter(None, values) for values - in izip_longest(fillvalue=fillvalue, *args)) - - -def cursor_factory(**factory_options): - def cursor(**options): - options.update(factory_options) - return Cursor(**options) - return cursor - - -class Cursor(object): - """ - Establishes a connection to the database and returns an open cursor. 
- - - ```python - # Use as context manager - with Cursor() as cur: - cur.execute(query) - ``` - """ - - def __init__(self, cursor_type=mysql.cursors.Cursor, **options): - super(Cursor, self).__init__() - - self._cache = Queue.Queue(maxsize=5) - try: - conn = self._cache.get_nowait() - except Queue.Empty: - conn = mysql.connect(**options) - else: - # Ping the connection before using it from the cache. - conn.ping(True) - - self.conn = conn - self.conn.autocommit(False) - self.cursor_type = cursor_type - - @classmethod - def clear_cache(cls): - cls._cache = Queue.Queue(maxsize=5) - - def __enter__(self): - self.cursor = self.conn.cursor(self.cursor_type) - return self.cursor - - def __exit__(self, extype, exvalue, traceback): - # if we had a MySQL related error we try to rollback the cursor. - if extype is mysql.MySQLError: - self.cursor.rollback() - - self.cursor.close() - self.conn.commit() - - # Put it back on the queue - try: - self._cache.put_nowait(self.conn) - except Queue.Full: - self.conn.close() diff --git a/dejavu/fingerprint.py b/dejavu/fingerprint.py deleted file mode 100755 index f56118a..0000000 --- a/dejavu/fingerprint.py +++ /dev/null @@ -1,157 +0,0 @@ -import numpy as np -import matplotlib.mlab as mlab -import matplotlib.pyplot as plt -from scipy.ndimage.filters import maximum_filter -from scipy.ndimage.morphology import (generate_binary_structure, - iterate_structure, binary_erosion) -import hashlib -from operator import itemgetter - -IDX_FREQ_I = 0 -IDX_TIME_J = 1 - -###################################################################### -# Sampling rate, related to the Nyquist conditions, which affects -# the range frequencies we can detect. 
-DEFAULT_FS = 44100 - -###################################################################### -# Size of the FFT window, affects frequency granularity -DEFAULT_WINDOW_SIZE = 4096 - -###################################################################### -# Ratio by which each sequential window overlaps the last and the -# next window. Higher overlap will allow a higher granularity of offset -# matching, but potentially more fingerprints. -DEFAULT_OVERLAP_RATIO = 0.5 - -###################################################################### -# Degree to which a fingerprint can be paired with its neighbors -- -# higher will cause more fingerprints, but potentially better accuracy. -DEFAULT_FAN_VALUE = 15 - -###################################################################### -# Minimum amplitude in spectrogram in order to be considered a peak. -# This can be raised to reduce number of fingerprints, but can negatively -# affect accuracy. -DEFAULT_AMP_MIN = 10 - -###################################################################### -# Number of cells around an amplitude peak in the spectrogram in order -# for Dejavu to consider it a spectral peak. Higher values mean less -# fingerprints and faster matching, but can potentially affect accuracy. -PEAK_NEIGHBORHOOD_SIZE = 20 - -###################################################################### -# Thresholds on how close or far fingerprints can be in time in order -# to be paired as a fingerprint. If your max is too low, higher values of -# DEFAULT_FAN_VALUE may not perform as expected. -MIN_HASH_TIME_DELTA = 0 -MAX_HASH_TIME_DELTA = 200 - -###################################################################### -# If True, will sort peaks temporally for fingerprinting; -# not sorting will cut down number of fingerprints, but potentially -# affect performance. 
-PEAK_SORT = True - -###################################################################### -# Number of bits to grab from the front of the SHA1 hash in the -# fingerprint calculation. The more you grab, the more memory storage, -# with potentially lesser collisions of matches. -FINGERPRINT_REDUCTION = 20 - -def fingerprint(channel_samples, Fs=DEFAULT_FS, - wsize=DEFAULT_WINDOW_SIZE, - wratio=DEFAULT_OVERLAP_RATIO, - fan_value=DEFAULT_FAN_VALUE, - amp_min=DEFAULT_AMP_MIN): - """ - FFT the channel, log transform output, find local maxima, then return - locally sensitive hashes. - """ - # FFT the signal and extract frequency components - arr2D = mlab.specgram( - channel_samples, - NFFT=wsize, - Fs=Fs, - window=mlab.window_hanning, - noverlap=int(wsize * wratio))[0] - - # apply log transform since specgram() returns linear array - arr2D = 10 * np.log10(arr2D) - arr2D[arr2D == -np.inf] = 0 # replace infs with zeros - - # find local maxima - local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min) - - # return hashes - return generate_hashes(local_maxima, fan_value=fan_value) - - -def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN): - # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html#scipy.ndimage.iterate_structure - struct = generate_binary_structure(2, 1) - neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) - - # find local maxima using our filter shape - local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D - background = (arr2D == 0) - eroded_background = binary_erosion(background, structure=neighborhood, - border_value=1) - - # Boolean mask of arr2D with True at peaks (Fixed deprecated boolean operator by changing '-' to '^') - detected_peaks = local_max ^ eroded_background - - # extract peaks - amps = arr2D[detected_peaks] - j, i = np.where(detected_peaks) - - # filter peaks - amps = amps.flatten() - peaks = zip(i, j, amps) - peaks_filtered = filter(lambda x: x[2]>amp_min, peaks) # 
freq, time, amp - # get indices for frequency and time - frequency_idx = [] - time_idx = [] - for x in peaks_filtered: - frequency_idx.append(x[1]) - time_idx.append(x[0]) - - if plot: - # scatter of the peaks - fig, ax = plt.subplots() - ax.imshow(arr2D) - ax.scatter(time_idx, frequency_idx) - ax.set_xlabel('Time') - ax.set_ylabel('Frequency') - ax.set_title("Spectrogram") - plt.gca().invert_yaxis() - plt.show() - - return zip(frequency_idx, time_idx) - - -def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): - """ - Hash list structure: - sha1_hash[0:20] time_offset - [(e05b341a9b77a51fd26, 32), ... ] - """ - if PEAK_SORT: - peaks.sort(key=itemgetter(1)) - - for i in range(len(peaks)): - for j in range(1, fan_value): - if (i + j) < len(peaks): - - freq1 = peaks[i][IDX_FREQ_I] - freq2 = peaks[i + j][IDX_FREQ_I] - t1 = peaks[i][IDX_TIME_J] - t2 = peaks[i + j][IDX_TIME_J] - t_delta = t2 - t1 - - if t_delta >= MIN_HASH_TIME_DELTA and t_delta <= MAX_HASH_TIME_DELTA: - h = hashlib.sha1( - "%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))) - yield (h.hexdigest()[0:FINGERPRINT_REDUCTION], t1) diff --git a/dejavu/logic/__init__.py b/dejavu/logic/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/decoder.py b/dejavu/logic/decoder.py similarity index 50% rename from dejavu/decoder.py rename to dejavu/logic/decoder.py index 04aa39f..ccafa26 100755 --- a/dejavu/decoder.py +++ b/dejavu/logic/decoder.py @@ -1,40 +1,57 @@ -import os import fnmatch +import os +from hashlib import sha1 +from typing import List, Tuple + import numpy as np from pydub import AudioSegment from pydub.utils import audioop -import wavio -from hashlib import sha1 -def unique_hash(filepath, blocksize=2**20): +from dejavu.third_party import wavio + + +def unique_hash(file_path: str, block_size: int = 2**20) -> str: """ Small function to generate a hash to uniquely generate a file. Inspired by MD5 version here: http://stackoverflow.com/a/1131255/712997 - Works with large files. 
+ Works with large files. + + :param file_path: path to file. + :param block_size: read block size. + :return: a hash in an hexagesimal string form. """ s = sha1() - with open(filepath , "rb") as f: + with open(file_path, "rb") as f: while True: - buf = f.read(blocksize) + buf = f.read(block_size) if not buf: break s.update(buf) return s.hexdigest().upper() -def find_files(path, extensions): +def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]: + """ + Get all files that meet the specified extensions. + + :param path: path to a directory with audio files. + :param extensions: file extensions to look for. + :return: a list of tuples with file name and its extension. + """ # Allow both with ".mp3" and without "mp3" to be used for extensions extensions = [e.replace(".", "") for e in extensions] + results = [] for dirpath, dirnames, files in os.walk(path): for extension in extensions: - for f in fnmatch.filter(files, "*.%s" % extension): + for f in fnmatch.filter(files, f"*.{extension}"): p = os.path.join(dirpath, f) - yield (p, extension) + results.append((p, extension)) + return results -def read(filename, limit=None): +def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]: """ Reads any file supported by pydub (ffmpeg) and returns the data contained within. If file reading fails due to input being a 24-bit wav file, @@ -44,24 +61,26 @@ def read(filename, limit=None): of the file by specifying the `limit` parameter. This is the amount of seconds from the start of the file. - returns: (channels, samplerate) + :param file_name: file to be read. + :param limit: number of seconds to limit. + :return: tuple list of (channels, sample_rate, content_file_hash). 
""" # pydub does not support 24-bit wav files, use wavio when this occurs try: - audiofile = AudioSegment.from_file(filename) + audiofile = AudioSegment.from_file(file_name) if limit: audiofile = audiofile[:limit * 1000] - data = np.fromstring(audiofile._data, np.int16) + data = np.fromstring(audiofile.raw_data, np.int16) channels = [] - for chn in xrange(audiofile.channels): + for chn in range(audiofile.channels): channels.append(data[chn::audiofile.channels]) - fs = audiofile.frame_rate + audiofile.frame_rate except audioop.error: - fs, _, audiofile = wavio.readwav(filename) + _, _, audiofile = wavio.readwav(file_name) if limit: audiofile = audiofile[:limit * 1000] @@ -73,12 +92,14 @@ def read(filename, limit=None): for chn in audiofile: channels.append(chn) - return channels, audiofile.frame_rate, unique_hash(filename) + return channels, audiofile.frame_rate, unique_hash(file_name) -def path_to_songname(path): +def get_audio_name_from_path(file_path: str) -> str: """ - Extracts song name from a filepath. Used to identify which songs - have already been fingerprinted on disk. + Extracts song name from a file path. + + :param file_path: path to an audio file. 
+ :return: file name """ - return os.path.splitext(os.path.basename(path))[0] + return os.path.splitext(os.path.basename(file_path))[0] diff --git a/dejavu/logic/fingerprint.py b/dejavu/logic/fingerprint.py new file mode 100755 index 0000000..c3089aa --- /dev/null +++ b/dejavu/logic/fingerprint.py @@ -0,0 +1,156 @@ +import hashlib +from operator import itemgetter +from typing import List, Tuple + +import matplotlib.mlab as mlab +import matplotlib.pyplot as plt +import numpy as np +from scipy.ndimage.filters import maximum_filter +from scipy.ndimage.morphology import (binary_erosion, + generate_binary_structure, + iterate_structure) + +from dejavu.config.settings import (CONNECTIVITY_MASK, DEFAULT_AMP_MIN, + DEFAULT_FAN_VALUE, DEFAULT_FS, + DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE, + FINGERPRINT_REDUCTION, MAX_HASH_TIME_DELTA, + MIN_HASH_TIME_DELTA, + PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT) + + +def fingerprint(channel_samples: List[int], + Fs: int = DEFAULT_FS, + wsize: int = DEFAULT_WINDOW_SIZE, + wratio: float = DEFAULT_OVERLAP_RATIO, + fan_value: int = DEFAULT_FAN_VALUE, + amp_min: int = DEFAULT_AMP_MIN) -> List[Tuple[str, int]]: + """ + FFT the channel, log transform output, find local maxima, then return locally sensitive hashes. + + :param channel_samples: channel samples to fingerprint. + :param Fs: audio sampling rate. + :param wsize: FFT windows size. + :param wratio: ratio by which each sequential window overlaps the last and the next window. + :param fan_value: degree to which a fingerprint can be paired with its neighbors. + :param amp_min: minimum amplitude in spectrogram in order to be considered a peak. + :return: a list of hashes with their corresponding offsets. + """ + # FFT the signal and extract frequency components + arr2D = mlab.specgram( + channel_samples, + NFFT=wsize, + Fs=Fs, + window=mlab.window_hanning, + noverlap=int(wsize * wratio))[0] + + # Apply log transform since specgram function returns linear array. 
0s are excluded to avoid np warning. + arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0)) + + local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min) + + # return hashes + return generate_hashes(local_maxima, fan_value=fan_value) + + +def get_2D_peaks(arr2D: np.array, plot: bool = False, amp_min: int = DEFAULT_AMP_MIN)\ + -> List[Tuple[List[int], List[int]]]: + """ + Extract maximum peaks from the spectogram matrix (arr2D). + + :param arr2D: matrix representing the spectogram. + :param plot: for plotting the results. + :param amp_min: minimum amplitude in spectrogram in order to be considered a peak. + :return: a list composed by a list of frequencies and times. + """ + # Original code from the repo is using a morphology mask that does not consider diagonal elements + # as neighbors (basically a diamond figure) and then applies a dilation over it, so what I'm proposing + # is to change from the current diamond figure to a just a normal square one: + # F T F T T T + # T T T ==> T T T + # F T F T T T + # In my local tests time performance of the square mask was ~3 times faster + # respect to the diamond one, without hurting accuracy of the predictions. + # I've made now the mask shape configurable in order to allow both ways of find maximum peaks. 
+ # That being said, we generate the mask by using the following function + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.generate_binary_structure.html + struct = generate_binary_structure(2, CONNECTIVITY_MASK) + + # And then we apply dilation using the following function + # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html + # Take into account that if PEAK_NEIGHBORHOOD_SIZE is 2 you can avoid the use of the scipy functions and just + # change it by the following code: + # neighborhood = np.ones((PEAK_NEIGHBORHOOD_SIZE * 2 + 1, PEAK_NEIGHBORHOOD_SIZE * 2 + 1), dtype=bool) + neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) + + # find local maxima using our filter mask + local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D + + # Applying erosion, the dejavu documentation does not talk about this step. + background = (arr2D == 0) + eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) + + # Boolean mask of arr2D with True at peaks (applying XOR on both matrices). + detected_peaks = local_max != eroded_background + + # extract peaks + amps = arr2D[detected_peaks] + freqs, times = np.where(detected_peaks) + + # filter peaks + amps = amps.flatten() + + # get indices for frequency and time + filter_idxs = np.where(amps > amp_min) + + freqs_filter = freqs[filter_idxs] + times_filter = times[filter_idxs] + + if plot: + # scatter of the peaks + fig, ax = plt.subplots() + ax.imshow(arr2D) + ax.scatter(times_filter, freqs_filter) + ax.set_xlabel('Time') + ax.set_ylabel('Frequency') + ax.set_title("Spectrogram") + plt.gca().invert_yaxis() + plt.show() + + return list(zip(freqs_filter, times_filter)) + + +def generate_hashes(peaks: List[Tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE) -> List[Tuple[str, int]]: + """ + Hash list structure: + sha1_hash[0:FINGERPRINT_REDUCTION] time_offset + [(e05b341a9b77a51fd26, 32), ... 
] + + :param peaks: list of peak frequencies and times. + :param fan_value: degree to which a fingerprint can be paired with its neighbors. + :return: a list of hashes with their corresponding offsets. + """ + # frequencies are in the first position of the tuples + idx_freq = 0 + # times are in the second position of the tuples + idx_time = 1 + + if PEAK_SORT: + peaks.sort(key=itemgetter(1)) + + hashes = [] + for i in range(len(peaks)): + for j in range(1, fan_value): + if (i + j) < len(peaks): + + freq1 = peaks[i][idx_freq] + freq2 = peaks[i + j][idx_freq] + t1 = peaks[i][idx_time] + t2 = peaks[i + j][idx_time] + t_delta = t2 - t1 + + if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA: + h = hashlib.sha1(f"{str(freq1)}|{str(freq2)}|{str(t_delta)}".encode('utf-8')) + + hashes.append((h.hexdigest()[0:FINGERPRINT_REDUCTION], t1)) + + return hashes diff --git a/dejavu/logic/recognizer/__init__.py b/dejavu/logic/recognizer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/logic/recognizer/file_recognizer.py b/dejavu/logic/recognizer/file_recognizer.py new file mode 100644 index 0000000..ded9650 --- /dev/null +++ b/dejavu/logic/recognizer/file_recognizer.py @@ -0,0 +1,32 @@ +from time import time +from typing import Dict + +import dejavu.logic.decoder as decoder +from dejavu.base_classes.base_recognizer import BaseRecognizer +from dejavu.config.settings import (ALIGN_TIME, FINGERPRINT_TIME, QUERY_TIME, + RESULTS, TOTAL_TIME) + + +class FileRecognizer(BaseRecognizer): + def __init__(self, dejavu): + super().__init__(dejavu) + + def recognize_file(self, filename: str) -> Dict[str, any]: + channels, self.Fs, _ = decoder.read(filename, self.dejavu.limit) + + t = time() + matches, fingerprint_time, query_time, align_time = self._recognize(*channels) + t = time() - t + + results = { + TOTAL_TIME: t, + FINGERPRINT_TIME: fingerprint_time, + QUERY_TIME: query_time, + ALIGN_TIME: align_time, + RESULTS: matches + } + + return results + + def 
recognize(self, filename: str) -> Dict[str, any]: + return self.recognize_file(filename) diff --git a/dejavu/recognize.py b/dejavu/logic/recognizer/microphone_recognizer.py similarity index 64% rename from dejavu/recognize.py rename to dejavu/logic/recognizer/microphone_recognizer.py index 269a82a..3bfef0d 100755 --- a/dejavu/recognize.py +++ b/dejavu/logic/recognizer/microphone_recognizer.py @@ -1,55 +1,17 @@ -# encoding: utf-8 -import dejavu.fingerprint as fingerprint -import dejavu.decoder as decoder import numpy as np import pyaudio -import time - -class BaseRecognizer(object): - - def __init__(self, dejavu): - self.dejavu = dejavu - self.Fs = fingerprint.DEFAULT_FS - - def _recognize(self, *data): - matches = [] - for d in data: - matches.extend(self.dejavu.find_matches(d, Fs=self.Fs)) - return self.dejavu.align_matches(matches) - - def recognize(self): - pass # base class does nothing - - -class FileRecognizer(BaseRecognizer): - def __init__(self, dejavu): - super(FileRecognizer, self).__init__(dejavu) - - def recognize_file(self, filename): - frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit) - - t = time.time() - match = self._recognize(*frames) - t = time.time() - t - - if match: - match['match_time'] = t - - return match - - def recognize(self, filename): - return self.recognize_file(filename) +from dejavu.base_classes.base_recognizer import BaseRecognizer class MicrophoneRecognizer(BaseRecognizer): - default_chunksize = 8192 - default_format = pyaudio.paInt16 - default_channels = 2 - default_samplerate = 44100 + default_chunksize = 8192 + default_format = pyaudio.paInt16 + default_channels = 2 + default_samplerate = 44100 def __init__(self, dejavu): - super(MicrophoneRecognizer, self).__init__(dejavu) + super().__init__(dejavu) self.audio = pyaudio.PyAudio() self.stream = None self.data = [] diff --git a/dejavu/tests/__init__.py b/dejavu/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/testing.py 
b/dejavu/tests/dejavu_test.py similarity index 53% rename from dejavu/testing.py rename to dejavu/tests/dejavu_test.py index d2a3b48..ca9e52c 100644 --- a/dejavu/testing.py +++ b/dejavu/tests/dejavu_test.py @@ -1,46 +1,221 @@ -from __future__ import division -from pydub import AudioSegment -from dejavu.decoder import path_to_songname -from dejavu import Dejavu -from dejavu.fingerprint import * -import traceback import fnmatch -import os, re, ast -import subprocess -import random +import json import logging +import random +import re +import subprocess +import traceback +from os import listdir, makedirs, walk +from os.path import basename, exists, isfile, join, splitext + +import matplotlib.pyplot as plt +import numpy as np +from pydub import AudioSegment + +from dejavu.config.settings import (DEFAULT_FS, DEFAULT_OVERLAP_RATIO, + DEFAULT_WINDOW_SIZE, HASHES_MATCHED, + OFFSET, RESULTS, SONG_NAME, TOTAL_TIME) +from dejavu.logic.decoder import get_audio_name_from_path + + +class DejavuTest: + def __init__(self, folder, seconds): + super().__init__() + + self.test_folder = folder + self.test_seconds = seconds + self.test_songs = [] + + print("test_seconds", self.test_seconds) + + self.test_files = [ + f for f in listdir(self.test_folder) + if isfile(join(self.test_folder, f)) + and any([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds]) + ] + + print("test_files", self.test_files) + + self.n_columns = len(self.test_seconds) + self.n_lines = int(len(self.test_files) / self.n_columns) + + print("columns:", self.n_columns) + print("length of test files:", len(self.test_files)) + print("lines:", self.n_lines) + + # variable match results (yes, no, invalid) + self.result_match = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] + + print("result_match matrix:", self.result_match) + + # variable match precision (if matched in the corrected time) + self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] + + # 
variable matching time (query time) + self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] + + # variable confidence + self.result_match_confidence = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] + + self.begin() + + def get_column_id(self, secs): + for i, sec in enumerate(self.test_seconds): + if secs == sec: + return i + + def get_line_id(self, song): + for i, s in enumerate(self.test_songs): + if song == s: + return i + self.test_songs.append(song) + return len(self.test_songs) - 1 + + def create_plots(self, name, results, results_folder): + for sec in range(0, len(self.test_seconds)): + ind = np.arange(self.n_lines) + width = 0.25 # the width of the bars + + fig = plt.figure() + ax = fig.add_subplot(111) + ax.set_xlim([-1 * width, 2 * width]) + + means_dvj = [x[0] for x in results[sec]] + rects1 = ax.bar(ind, means_dvj, width, color='r') + + # add some + ax.set_ylabel(name) + ax.set_title(f"{self.test_seconds[sec]} {name} Results") + ax.set_xticks(ind + width) + + labels = [0 for x in range(0, self.n_lines)] + for x in range(0, self.n_lines): + labels[x] = f"song {x+1}" + ax.set_xticklabels(labels) + + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) + + if name == 'Confidence': + autolabel(rects1, ax) + else: + autolabeldoubles(rects1, ax) + + plt.grid() + + fig_name = join(results_folder, f"{name}_{self.test_seconds[sec]}.png") + fig.savefig(fig_name) + + def begin(self): + for f in self.test_files: + log_msg('--------------------------------------------------') + log_msg(f'file: {f}') + + # get column + col = self.get_column_id([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds][0]) + + # format: XXXX_offset_length.mp3, we also take into account underscores within XXXX + splits = get_audio_name_from_path(f).split("_") + song = "_".join(splits[0:len(get_audio_name_from_path(f).split("_")) - 2]) + line = self.get_line_id(song) + result = 
subprocess.check_output([ + "python", + "dejavu.py", + '-r', + 'file', + join(self.test_folder, f)]) + + if result.strip() == "None": + log_msg('No match') + self.result_match[line][col] = 'no' + self.result_matching_times[line][col] = 0 + self.result_query_duration[line][col] = 0 + self.result_match_confidence[line][col] = 0 + + else: + result = result.strip() + # we parse the output song back to a json + result = json.loads(result.decode('utf-8').replace("'", '"').replace(': b"', ':"')) + + # which song did we predict? We consider only the first match. + match = result[RESULTS][0] + song_result = match[SONG_NAME] + log_msg(f'song: {song}') + log_msg(f'song_result: {song_result}') + + if song_result != song: + log_msg('invalid match') + self.result_match[line][col] = 'invalid' + self.result_matching_times[line][col] = 0 + self.result_query_duration[line][col] = 0 + self.result_match_confidence[line][col] = 0 + else: + log_msg('correct match') + print(self.result_match) + self.result_match[line][col] = 'yes' + self.result_query_duration[line][col] = round(result[TOTAL_TIME], 3) + self.result_match_confidence[line][col] = match[HASHES_MATCHED] + + # using replace in f for getting rid of underscores in name + song_start_time = re.findall("_[^_]+", f.replace(song, "")) + song_start_time = song_start_time[0].lstrip("_ ") + + result_start_time = round((match[OFFSET] * DEFAULT_WINDOW_SIZE * + DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0) + + self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time) + if abs(self.result_matching_times[line][col]) == 1: + self.result_matching_times[line][col] = 0 + + log_msg(f'query duration: {round(result[TOTAL_TIME], 3)}') + log_msg(f'confidence: {match[HASHES_MATCHED]}') + log_msg(f'song start_time: {song_start_time}') + log_msg(f'result start time: {result_start_time}') + + if self.result_matching_times[line][col] == 0: + log_msg('accurate match') + else: + log_msg('inaccurate match') + 
log_msg('--------------------------------------------------\n') + def set_seed(seed=None): """ - `seed` as None means that the sampling will be random. + `seed` as None means that the sampling will be random. - Setting your own seed means that you can produce the - same experiment over and over. + Setting your own seed means that you can produce the + same experiment over and over. """ - if seed != None: + if seed is not None: random.seed(seed) + def get_files_recursive(src, fmt): """ - `src` is the source directory. + `src` is the source directory. `fmt` is the extension, ie ".mp3" or "mp3", etc. """ - for root, dirnames, filenames in os.walk(src): + files = [] + for root, dirnames, filenames in walk(src): for filename in fnmatch.filter(filenames, '*' + fmt): - yield os.path.join(root, filename) + files.append(join(root, filename)) + + return files + def get_length_audio(audiopath, extension): """ - Returns length of audio in seconds. - Returns None if format isn't supported or in case of error. + Returns length of audio in seconds. + Returns None if format isn't supported or in case of error. """ try: audio = AudioSegment.from_file(audiopath, extension.replace(".", "")) - except: - print "Error in get_length_audio(): %s" % traceback.format_exc() + except Exception: + print(f"Error in get_length_audio(): {traceback.format_exc()}") return None return int(len(audio) / 1000.0) + def get_starttime(length, nseconds, padding): """ `length` is total audio length in seconds @@ -52,225 +227,60 @@ def get_starttime(length, nseconds, padding): return 0 return random.randint(padding, maximum) + def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10): """ Generates a test file for each file recursively in `src` directory - of given format using `nseconds` sampled from the audio file. + of given format using `nseconds` sampled from the audio file. Results are written to `dest` directory. 
`padding` is the number of off-limit seconds and the beginning and - end of a track that won't be sampled in testing. Often you want to - avoid silence, etc. + end of a track that won't be sampled in testing. Often you want to + avoid silence, etc. """ # create directories if necessary - for directory in [src, dest]: - try: - os.stat(directory) - except: - os.mkdir(directory) + if not exists(dest): + makedirs(dest) # find files recursively of a given file format for fmt in fmts: - testsources = get_files_recursive(src, fmt) + testsources = get_files_recursive(src, fmt) for audiosource in testsources: - print "audiosource:", audiosource - - filename, extension = os.path.splitext(os.path.basename(audiosource)) - length = get_length_audio(audiosource, extension) + print("audiosource:", audiosource) + + filename, extension = splitext(basename(audiosource)) + length = get_length_audio(audiosource, extension) starttime = get_starttime(length, nseconds, padding) - test_file_name = "%s_%s_%ssec.%s" % ( - os.path.join(dest, filename), starttime, - nseconds, extension.replace(".", "")) - + test_file_name = f"{join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}" + subprocess.check_output([ "ffmpeg", "-y", - "-ss", "%d" % starttime, - '-t' , "%d" % nseconds, + "-ss", f"{starttime}", + '-t', f"{nseconds}", "-i", audiosource, test_file_name]) + def log_msg(msg, log=True, silent=False): if log: logging.debug(msg) if not silent: - print msg + print(msg) + def autolabel(rects, ax): # attach some text labels for rect in rects: height = rect.get_height() - ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, - '%d' % int(height), ha='center', va='bottom') + ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{int(height)}', ha='center', va='bottom') + def autolabeldoubles(rects, ax): # attach some text labels for rect in rects: height = rect.get_height() - ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, - '%s' % 
round(float(height), 3), ha='center', va='bottom') - -class DejavuTest(object): - def __init__(self, folder, seconds): - super(DejavuTest, self).__init__() - - self.test_folder = folder - self.test_seconds = seconds - self.test_songs = [] - - print "test_seconds", self.test_seconds - - self.test_files = [ - f for f in os.listdir(self.test_folder) - if os.path.isfile(os.path.join(self.test_folder, f)) - and re.findall("[0-9]*sec", f)[0] in self.test_seconds] - - print "test_files", self.test_files - - self.n_columns = len(self.test_seconds) - self.n_lines = int(len(self.test_files) / self.n_columns) - - print "columns:", self.n_columns - print "length of test files:", len(self.test_files) - print "lines:", self.n_lines - - # variable match results (yes, no, invalid) - self.result_match = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)] - - print "result_match matrix:", self.result_match - - # variable match precision (if matched in the corrected time) - self.result_matching_times = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)] - - # variable mahing time (query time) - self.result_query_duration = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)] - - # variable confidence - self.result_match_confidence = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)] - - self.begin() - - def get_column_id (self, secs): - for i, sec in enumerate(self.test_seconds): - if secs == sec: - return i - - def get_line_id (self, song): - for i, s in enumerate(self.test_songs): - if song == s: - return i - self.test_songs.append(song) - return len(self.test_songs) - 1 - - def create_plots(self, name, results, results_folder): - for sec in range(0, len(self.test_seconds)): - ind = np.arange(self.n_lines) # - width = 0.25 # the width of the bars - - fig = plt.figure() - ax = fig.add_subplot(111) - ax.set_xlim([-1 * width, 2 * width]) - - means_dvj = [x[0] for x in results[sec]] - rects1 = ax.bar(ind, means_dvj, width, 
color='r') - - # add some - ax.set_ylabel(name) - ax.set_title("%s %s Results" % (self.test_seconds[sec], name)) - ax.set_xticks(ind + width) - - labels = [0 for x in range(0, self.n_lines)] - for x in range(0, self.n_lines): - labels[x] = "song %s" % (x+1) - ax.set_xticklabels(labels) - - box = ax.get_position() - ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) - - #ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5)) - - if name == 'Confidence': - autolabel(rects1, ax) - else: - autolabeldoubles(rects1, ax) - - plt.grid() - - fig_name = os.path.join(results_folder, "%s_%s.png" % (name, self.test_seconds[sec])) - fig.savefig(fig_name) - - def begin(self): - for f in self.test_files: - log_msg('--------------------------------------------------') - log_msg('file: %s' % f) - - # get column - col = self.get_column_id(re.findall("[0-9]*sec", f)[0]) - # format: XXXX_offset_length.mp3 - song = path_to_songname(f).split("_")[0] - line = self.get_line_id(song) - result = subprocess.check_output([ - "python", - "dejavu.py", - '-r', - 'file', - self.test_folder + "/" + f]) - - if result.strip() == "None": - log_msg('No match') - self.result_match[line][col] = 'no' - self.result_matching_times[line][col] = 0 - self.result_query_duration[line][col] = 0 - self.result_match_confidence[line][col] = 0 - - else: - result = result.strip() - result = result.replace(" \'", ' "') - result = result.replace("{\'", '{"') - result = result.replace("\':", '":') - result = result.replace("\',", '",') - - # which song did we predict? 
- result = ast.literal_eval(result) - song_result = result["song_name"] - log_msg('song: %s' % song) - log_msg('song_result: %s' % song_result) - - if song_result != song: - log_msg('invalid match') - self.result_match[line][col] = 'invalid' - self.result_matching_times[line][col] = 0 - self.result_query_duration[line][col] = 0 - self.result_match_confidence[line][col] = 0 - else: - log_msg('correct match') - print self.result_match - self.result_match[line][col] = 'yes' - self.result_query_duration[line][col] = round(result[Dejavu.MATCH_TIME],3) - self.result_match_confidence[line][col] = result[Dejavu.CONFIDENCE] - - song_start_time = re.findall("\_[^\_]+",f) - song_start_time = song_start_time[0].lstrip("_ ") - - result_start_time = round((result[Dejavu.OFFSET] * DEFAULT_WINDOW_SIZE * - DEFAULT_OVERLAP_RATIO) / (DEFAULT_FS), 0) - - self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time) - if (abs(self.result_matching_times[line][col]) == 1): - self.result_matching_times[line][col] = 0 - - log_msg('query duration: %s' % round(result[Dejavu.MATCH_TIME],3)) - log_msg('confidence: %s' % result[Dejavu.CONFIDENCE]) - log_msg('song start_time: %s' % song_start_time) - log_msg('result start time: %s' % result_start_time) - if (self.result_matching_times[line][col] == 0): - log_msg('accurate match') - else: - log_msg('inaccurate match') - log_msg('--------------------------------------------------\n') - - - - + ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{round(float(height), 3)}', + ha='center', va='bottom') diff --git a/dejavu/third_party/__init__.py b/dejavu/third_party/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/third_party/wavio.py b/dejavu/third_party/wavio.py new file mode 100644 index 0000000..8c706b8 --- /dev/null +++ b/dejavu/third_party/wavio.py @@ -0,0 +1,357 @@ +# wavio.py +# Author: Warren Weckesser +# License: BSD 2-Clause (http://opensource.org/licenses/BSD-2-Clause) +# 
Synopsis: A Python module for reading and writing 24 bit WAV files. +# Github: github.com/WarrenWeckesser/wavio + +""" +The wavio module defines the functions: +read(file) + Read a WAV file and return a `wavio.Wav` object, with attributes + `data`, `rate` and `sampwidth`. +write(filename, data, rate, scale=None, sampwidth=None) + Write a numpy array to a WAV file. +----- +Author: Warren Weckesser +License: BSD 2-Clause: +Copyright (c) 2015, Warren Weckesser +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
+""" + + +import wave as _wave + +import numpy as _np + +__version__ = "0.0.5.dev1" + + +def _wav2array(nchannels, sampwidth, data): + """data must be the string containing the bytes from the wav file.""" + num_samples, remainder = divmod(len(data), sampwidth * nchannels) + if remainder > 0: + raise ValueError('The length of data is not a multiple of ' + 'sampwidth * num_channels.') + if sampwidth > 4: + raise ValueError("sampwidth must not be greater than 4.") + + if sampwidth == 3: + a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8) + raw_bytes = _np.frombuffer(data, dtype=_np.uint8) + a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth) + a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255 + result = a.view('> _np.array([0, 8, 16])) & 255 + wavdata = a8.astype(_np.uint8).tostring() + else: + # Make sure the array is little-endian, and then convert using + # tostring() + a = a.astype('<' + a.dtype.str[1:], copy=False) + wavdata = a.tostring() + return wavdata + + +class Wav(object): + """ + Object returned by `wavio.read`. Attributes are: + data : numpy array + The array of data read from the WAV file. + rate : float + The sample rate of the WAV file. + sampwidth : int + The sample width (i.e. number of bytes per sample) of the WAV file. + For example, `sampwidth == 3` is a 24 bit WAV file. + """ + + def __init__(self, data, rate, sampwidth): + self.data = data + self.rate = rate + self.sampwidth = sampwidth + + def __repr__(self): + s = (f"Wav(data.shape={self.data.shape}, data.dtype={self.data.dtype}, " + f"rate={self.rate}, sampwidth={self.sampwidth})") + return s + + +def read(file): + """ + Read a WAV file. + Parameters + ---------- + file : string or file object + Either the name of a file or an open file pointer. + Returns + ------- + wav : wavio.Wav() instance + The return value is an instance of the class `wavio.Wav`, + with the following attributes: + data : numpy array + The array containing the data. 
The shape of the array + is (num_samples, num_channels). num_channels is the + number of audio channels (1 for mono, 2 for stereo). + rate : float + The sampling frequency (i.e. frame rate) + sampwidth : int + The sample width, in bytes. E.g. for a 24 bit WAV file, + sampwidth is 3. + Notes + ----- + This function uses the `wave` module of the Python standard library + to read the WAV file, so it has the same limitations as that library. + In particular, the function does not read compressed WAV files, and + it does not read files with floating point data. + The array returned by `wavio.read` is always two-dimensional. If the + WAV data is mono, the array will have shape (num_samples, 1). + `wavio.read()` does not scale or normalize the data. The data in the + array `wav.data` is the data that was in the file. When the file + contains 24 bit samples, the resulting numpy array is 32 bit integers, + with values that have been sign-extended. + """ + wav = _wave.open(file) + rate = wav.getframerate() + nchannels = wav.getnchannels() + sampwidth = wav.getsampwidth() + nframes = wav.getnframes() + data = wav.readframes(nframes) + wav.close() + array = _wav2array(nchannels, sampwidth, data) + w = Wav(data=array, rate=rate, sampwidth=sampwidth) + return w + + +_sampwidth_dtypes = {1: _np.uint8, + 2: _np.int16, + 3: _np.int32, + 4: _np.int32} +_sampwidth_ranges = {1: (0, 256), + 2: (-2**15, 2**15), + 3: (-2**23, 2**23), + 4: (-2**31, 2**31)} + + +def _scale_to_sampwidth(data, sampwidth, vmin, vmax): + # Scale and translate the values to fit the range of the data type + # associated with the given sampwidth.
+ + data = data.clip(vmin, vmax) + + dt = _sampwidth_dtypes[sampwidth] + if vmax == vmin: + data = _np.zeros(data.shape, dtype=dt) + else: + outmin, outmax = _sampwidth_ranges[sampwidth] + if outmin != vmin or outmax != vmax: + vmin = float(vmin) + vmax = float(vmax) + data = (float(outmax - outmin) * (data - vmin) / + (vmax - vmin)).astype(_np.int64) + outmin + data[data == outmax] = outmax - 1 + data = data.astype(dt) + + return data + + +def write(file, data, rate, scale=None, sampwidth=None): + """ + Write the numpy array `data` to a WAV file. + The Python standard library "wave" is used to write the data + to the file, so this function has the same limitations as that + module. In particular, the Python library does not support + floating point data. When given a floating point array, this + function converts the values to integers. See below for the + conversion rules. + Parameters + ---------- + file : string, or file object open for writing in binary mode + Either the name of a file or an open file pointer. + data : numpy array, 1- or 2-dimensional, integer or floating point + If it is 2-d, the rows are the frames (i.e. samples) and the + columns are the channels. + rate : float + The sampling frequency (i.e. frame rate) of the data. + sampwidth : int, optional + The sample width, in bytes, of the output file. + If `sampwidth` is not given, it is inferred (if possible) from + the data type of `data`, as follows:: + data.dtype sampwidth + ---------- --------- + uint8, int8 1 + uint16, int16 2 + uint32, int32 4 + For any other data types, or to write a 24 bit file, `sampwidth` + must be given. + scale : tuple or str, optional + By default, the data written to the file is scaled up or down to + occupy the full range of the output data type. So, for example, + the unsigned 8 bit data [0, 1, 2, 15] would be written to the file + as [0, 17, 34, 255].
More generally, the default behavior is + (roughly):: + vmin = data.min() + vmax = data.max() + outmin = <minimum of the output sample range> + outmax = <maximum of the output sample range> + outdata = (outmax - outmin)*(data - vmin)/(vmax - vmin) + outmin + The `scale` argument allows the scaling of the output data to be + changed. `scale` can be a tuple of the form `(vmin, vmax)`, in which + case the given values override the use of `data.min()` and + `data.max()` for `vmin` and `vmax` shown above. (If either value + is `None`, the value shown above is used.) Data outside the + range (vmin, vmax) is clipped. If `vmin == vmax`, the output is + all zeros. + If `scale` is the string "none", then `vmin` and `vmax` are set to + `outmin` and `outmax`, respectively. This means the data is written + to the file with no scaling. (Note: `scale="none"` is not the same + as `scale=None`. The latter means "use the default behavior", + which is to scale by the data minimum and maximum.) + If `scale` is the string "dtype-limits", then `vmin` and `vmax` + are set to the minimum and maximum integers of `data.dtype`. + The string "dtype-limits" is not allowed when the `data` is a + floating point array. + If using `scale` results in values that exceed the limits of the + output sample width, the data is clipped. For example, the + following code:: + >> x = np.array([-100, 0, 100, 200, 300, 325]) + >> wavio.write('foo.wav', x, 8000, scale='none', sampwidth=1) + will write the values [0, 0, 100, 200, 255, 255] to the file. + Example + ------- + Create a 3 second 440 Hz sine wave, and save it in a 24-bit WAV file. + >> import numpy as np + >> import wavio + >> rate = 22050 # samples per second + >> T = 3 # sample duration (seconds) + >> f = 440.0 # sound frequency (Hz) + >> t = np.linspace(0, T, T*rate, endpoint=False) + >> x = np.sin(2*np.pi * f * t) + >> wavio.write("sine24.wav", x, rate, sampwidth=3) + Create a file that contains the 16 bit integer values -10000 and 10000 + repeated 100 times. Don't automatically scale the values.
Use a sample + rate 8000. + >> x = np.empty(200, dtype=np.int16) + >> x[::2] = -10000 + >> x[1::2] = 10000 + >> wavio.write("foo.wav", x, 8000, scale='none') + Check that the file contains what we expect. + >> w = wavio.read("foo.wav") + >> np.all(w.data[:, 0] == x) + True + In the following, the values -10000 and 10000 (from within the 16 bit + range [-2**15, 2**15-1]) are mapped to the corresponding values 88 and + 168 (in the range [0, 2**8-1]). + >> wavio.write("foo.wav", x, 8000, sampwidth=1, scale='dtype-limits') + >> w = wavio.read("foo.wav") + >> w.data[:4, 0] + array([ 88, 168, 88, 168], dtype=uint8) + """ + + if sampwidth is None: + if not _np.issubdtype(data.dtype, _np.integer) or data.itemsize > 4: + raise ValueError('when data.dtype is not an 8-, 16-, or 32-bit integer type, sampwidth must be specified.') + sampwidth = data.itemsize + else: + if sampwidth not in [1, 2, 3, 4]: + raise ValueError('sampwidth must be 1, 2, 3 or 4.') + + outdtype = _sampwidth_dtypes[sampwidth] + outmin, outmax = _sampwidth_ranges[sampwidth] + + if scale == "none": + data = data.clip(outmin, outmax-1).astype(outdtype) + elif scale == "dtype-limits": + if not _np.issubdtype(data.dtype, _np.integer): + raise ValueError("scale cannot be 'dtype-limits' with non-integer data.") + # Easy transforms that just changed the signedness of the data. + if sampwidth == 1 and data.dtype == _np.int8: + data = (data.astype(_np.int16) + 128).astype(_np.uint8) + elif sampwidth == 2 and data.dtype == _np.uint16: + data = (data.astype(_np.int32) - 32768).astype(_np.int16) + elif sampwidth == 4 and data.dtype == _np.uint32: + data = (data.astype(_np.int64) - 2**31).astype(_np.int32) + elif data.itemsize != sampwidth: + # Integer input, but rescaling is needed to adjust the + # input range to the output sample width. 
+ ii = _np.iinfo(data.dtype) + vmin = ii.min + vmax = ii.max + data = _scale_to_sampwidth(data, sampwidth, vmin, vmax) + else: + if scale is None: + vmin = data.min() + vmax = data.max() + else: + # scale must be a tuple of the form (vmin, vmax) + vmin, vmax = scale + if vmin is None: + vmin = data.min() + if vmax is None: + vmax = data.max() + + data = _scale_to_sampwidth(data, sampwidth, vmin, vmax) + + # At this point, `data` has been converted to have one of the following: + # sampwidth dtype + # --------- ----- + # 1 uint8 + # 2 int16 + # 3 int32 + # 4 int32 + # The values in `data` are in the form in which they will be saved; + # no more scaling will take place. + + if data.ndim == 1: + data = data.reshape(-1, 1) + + wavdata = _array2wav(data, sampwidth) + + w = _wave.open(file, 'wb') + w.setnchannels(data.shape[1]) + w.setsampwidth(sampwidth) + w.setframerate(rate) + w.writeframes(wavdata) + w.close() diff --git a/dejavu/wavio.py b/dejavu/wavio.py deleted file mode 100644 index e8d1fc3..0000000 --- a/dejavu/wavio.py +++ /dev/null @@ -1,121 +0,0 @@ -# wavio.py -# Author: Warren Weckesser -# License: BSD 3-Clause (http://opensource.org/licenses/BSD-3-Clause) -# Synopsis: A Python module for reading and writing 24 bit WAV files. 
-# Github: github.com/WarrenWeckesser/wavio - -import wave as _wave -import numpy as _np - - -def _wav2array(nchannels, sampwidth, data): - """data must be the string containing the bytes from the wav file.""" - num_samples, remainder = divmod(len(data), sampwidth * nchannels) - if remainder > 0: - raise ValueError('The length of data is not a multiple of ' - 'sampwidth * num_channels.') - if sampwidth > 4: - raise ValueError("sampwidth must not be greater than 4.") - - if sampwidth == 3: - a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8) - raw_bytes = _np.fromstring(data, dtype=_np.uint8) - a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth) - a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255 - result = a.view('>> rate = 22050 # samples per second - >>> T = 3 # sample duration (seconds) - >>> f = 440.0 # sound frequency (Hz) - >>> t = np.linspace(0, T, T*rate, endpoint=False) - >>> x = (2**23 - 1) * np.sin(2 * np.pi * f * t) - >>> writewav24("sine24.wav", rate, x) - - """ - a32 = _np.asarray(data, dtype=_np.int32) - if a32.ndim == 1: - # Convert to a 2D array with a single column. - a32.shape = a32.shape + (1,) - # By shifting first 0 bits, then 8, then 16, the resulting output - # is 24 bit little-endian. 
- a8 = (a32.reshape(a32.shape + (1,)) >> _np.array([0, 8, 16])) & 255 - wavdata = a8.astype(_np.uint8).tostring() - - w = _wave.open(filename, 'wb') - w.setnchannels(a32.shape[1]) - w.setsampwidth(3) - w.setframerate(rate) - w.writeframes(wavdata) - w.close() diff --git a/example.py b/example.py deleted file mode 100755 index 1c99e69..0000000 --- a/example.py +++ /dev/null @@ -1,35 +0,0 @@ -import warnings -import json -warnings.filterwarnings("ignore") - -from dejavu import Dejavu -from dejavu.recognize import FileRecognizer, MicrophoneRecognizer - -# load config from a JSON file (or anything outputting a python dictionary) -with open("dejavu.cnf.SAMPLE") as f: - config = json.load(f) - -if __name__ == '__main__': - - # create a Dejavu instance - djv = Dejavu(config) - - # Fingerprint all the mp3's in the directory we give it - djv.fingerprint_directory("mp3", [".mp3"]) - - # Recognize audio from a file - song = djv.recognize(FileRecognizer, "mp3/Sean-Fournier--Falling-For-You.mp3") - print "From file we recognized: %s\n" % song - - # Or recognize audio from your microphone for `secs` seconds - secs = 5 - song = djv.recognize(MicrophoneRecognizer, seconds=secs) - if song is None: - print "Nothing recognized -- did you play the song out loud so your mic could hear it? 
:)" - else: - print "From mic with %d seconds we recognized: %s\n" % (secs, song) - - # Or use a recognizer without the shortcut, in anyway you would like - recognizer = FileRecognizer(djv) - song = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") - print "No shortcut, we recognized: %s\n" % song \ No newline at end of file diff --git a/example_script.py b/example_script.py new file mode 100755 index 0000000..19cdf15 --- /dev/null +++ b/example_script.py @@ -0,0 +1,34 @@ +import json + +from dejavu import Dejavu +from dejavu.logic.recognizer.file_recognizer import FileRecognizer +from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer + +# load config from a JSON file (or anything outputting a python dictionary) +with open("dejavu.cnf.SAMPLE") as f: + config = json.load(f) + +if __name__ == '__main__': + + # create a Dejavu instance + djv = Dejavu(config) + + # Fingerprint all the mp3's in the directory we give it + djv.fingerprint_directory("test", [".wav"]) + + # Recognize audio from a file + results = djv.recognize(FileRecognizer, "mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") + print(f"From file we recognized: {results}\n") + + # Or recognize audio from your microphone for `secs` seconds + secs = 5 + results = djv.recognize(MicrophoneRecognizer, seconds=secs) + if results is None: + print("Nothing recognized -- did you play the song out loud so your mic could hear it? 
:)") + else: + print(f"From mic with {secs} seconds we recognized: {results}\n") + + # Or use a recognizer without the shortcut, in any way you would like + recognizer = FileRecognizer(djv) + results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") + print(f"No shortcut, we recognized: {results}\n") diff --git a/mp3/azan_test.wav b/mp3/azan_test.wav new file mode 100644 index 0000000..34043fd Binary files /dev/null and b/mp3/azan_test.wav differ diff --git a/requirements.txt b/requirements.txt index 9478f73..19cc2b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,7 @@ -# requirements file - -### BEGIN ### -pydub>=0.9.4 -PyAudio>=0.2.7 -numpy>=1.8.2 -scipy>=0.12.1 -matplotlib>=1.3.1 -### END ### +pydub==0.23.1 +PyAudio==0.2.11 +numpy==1.17.2 +scipy==1.3.1 +matplotlib==3.1.1 +mysql-connector-python==8.0.17 +psycopg2==2.8.3 diff --git a/run_tests.py b/run_tests.py index b0dfde9..6a5eda9 100644 --- a/run_tests.py +++ b/run_tests.py @@ -1,184 +1,166 @@ -from dejavu.testing import * -from dejavu import Dejavu -from optparse import OptionParser -import matplotlib.pyplot as plt +import argparse +import logging import time -import shutil +from os import makedirs +from os.path import exists, join +from shutil import rmtree -usage = "usage: %prog [options] TESTING_AUDIOFOLDER" -parser = OptionParser(usage=usage, version="%prog 1.1") -parser.add_option("--secs", - action="store", - dest="secs", - default=5, - type=int, - help='Number of seconds starting from zero to test') -parser.add_option("--results", - action="store", - dest="results_folder", - default="./dejavu_test_results", - help='Sets the path where the results are saved') -parser.add_option("--temp", - action="store", - dest="temp_folder", - default="./dejavu_temp_testing_files", - help='Sets the path where the temp files are saved') -parser.add_option("--log", - action="store_true", - dest="log", - default=True, - help='Enables logging')
-parser.add_option("--silent", - action="store_false", - dest="silent", - default=False, - help='Disables printing') -parser.add_option("--log-file", - dest="log_file", - default="results-compare.log", - help='Set the path and filename of the log file') -parser.add_option("--padding", - action="store", - dest="padding", - default=10, - type=int, - help='Number of seconds to pad choice of place to test from') -parser.add_option("--seed", - action="store", - dest="seed", - default=None, - type=int, - help='Random seed') -options, args = parser.parse_args() -test_folder = args[0] +import matplotlib.pyplot as plt +import numpy as np -# set random seed if set by user -set_seed(options.seed) +from dejavu.tests.dejavu_test import (DejavuTest, autolabeldoubles, + generate_test_files, log_msg, set_seed) -# ensure results folder exists -try: - os.stat(options.results_folder) -except: - os.mkdir(options.results_folder) -# set logging -if options.log: - logging.basicConfig(filename=options.log_file, level=logging.DEBUG) +def main(seconds: int, results_folder: str, temp_folder: str, log: bool, silent: bool, + log_file: str, padding: int, seed: int, src: str): -# set test seconds -test_seconds = ['%dsec' % i for i in range(1, options.secs + 1, 1)] + # set random seed if set by user + set_seed(seed) -# generate testing files -for i in range(1, options.secs + 1, 1): - generate_test_files(test_folder, options.temp_folder, - i, padding=options.padding) + # ensure results folder exists + if not exists(results_folder): + makedirs(results_folder) -# scan files -log_msg("Running Dejavu fingerprinter on files in %s..." 
% test_folder, - log=options.log, silent=options.silent) + # set logging + if log: + logging.basicConfig(filename=log_file, level=logging.DEBUG) -tm = time.time() -djv = DejavuTest(options.temp_folder, test_seconds) -log_msg("finished obtaining results from dejavu in %s" % (time.time() - tm), - log=options.log, silent=options.silent) + # set test seconds + test_seconds = [f'{i}sec' for i in range(1, seconds + 1, 1)] -tests = 1 # djv -n_secs = len(test_seconds) + # generate testing files + for i in range(1, seconds + 1, 1): + generate_test_files(src, temp_folder, i, padding=padding) -# set result variables -> 4d variables -all_match_counter = [[[0 for x in xrange(tests)] for x in xrange(3)] for x in xrange(n_secs)] -all_matching_times_counter = [[[0 for x in xrange(tests)] for x in xrange(2)] for x in xrange(n_secs)] -all_query_duration = [[[0 for x in xrange(tests)] for x in xrange(djv.n_lines)] for x in xrange(n_secs)] -all_match_confidence = [[[0 for x in xrange(tests)] for x in xrange(djv.n_lines)] for x in xrange(n_secs)] + # scan files + log_msg(f"Running Dejavu fingerprinter on files in {src}...", log=log, silent=silent) -# group results by seconds -for line in range(0, djv.n_lines): - for col in range(0, djv.n_columns): - # for dejavu - all_query_duration[col][line][0] = djv.result_query_duration[line][col] - all_match_confidence[col][line][0] = djv.result_match_confidence[line][col] + tm = time.time() + djv = DejavuTest(temp_folder, test_seconds) + log_msg(f"finished obtaining results from dejavu in {(time.time() - tm)}", log=log, silent=silent) - djv_match_result = djv.result_match[line][col] + tests = 1 # djv + n_secs = len(test_seconds) - if djv_match_result == 'yes': - all_match_counter[col][0][0] += 1 - elif djv_match_result == 'no': - all_match_counter[col][1][0] += 1 - else: - all_match_counter[col][2][0] += 1 + # set result variables -> 4d variables + all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)] + 
all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)] + all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)] + all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)] - djv_match_acc = djv.result_matching_times[line][col] + # group results by seconds + for line in range(0, djv.n_lines): + for col in range(0, djv.n_columns): + # for dejavu + all_query_duration[col][line][0] = djv.result_query_duration[line][col] + all_match_confidence[col][line][0] = djv.result_match_confidence[line][col] - if djv_match_acc == 0 and djv_match_result == 'yes': - all_matching_times_counter[col][0][0] += 1 - elif djv_match_acc != 0: - all_matching_times_counter[col][1][0] += 1 + djv_match_result = djv.result_match[line][col] -# create plots -djv.create_plots('Confidence', all_match_confidence, options.results_folder) -djv.create_plots('Query duration', all_query_duration, options.results_folder) + if djv_match_result == 'yes': + all_match_counter[col][0][0] += 1 + elif djv_match_result == 'no': + all_match_counter[col][1][0] += 1 + else: + all_match_counter[col][2][0] += 1 -for sec in range(0, n_secs): - ind = np.arange(3) # - width = 0.25 # the width of the bars + djv_match_acc = djv.result_matching_times[line][col] - fig = plt.figure() - ax = fig.add_subplot(111) - ax.set_xlim([-1 * width, 2.75]) + if djv_match_acc == 0 and djv_match_result == 'yes': + all_matching_times_counter[col][0][0] += 1 + elif djv_match_acc != 0: + all_matching_times_counter[col][1][0] += 1 - means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]] - rects1 = ax.bar(ind, means_dvj, width, color='r') + # create plots + djv.create_plots('Confidence', all_match_confidence, results_folder) + djv.create_plots('Query duration', all_query_duration, results_folder) - # add some - ax.set_ylabel('Matching Percentage') - ax.set_title('%s Matching 
Percentage' % test_seconds[sec]) - ax.set_xticks(ind + width) + for sec in range(0, n_secs): + ind = np.arange(3) + width = 0.25 # the width of the bars - labels = ['yes','no','invalid'] - ax.set_xticklabels( labels ) + fig = plt.figure() + ax = fig.add_subplot(111) + ax.set_xlim([-1 * width, 2.75]) - box = ax.get_position() - ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) - #ax.legend((rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5)) - autolabeldoubles(rects1,ax) - plt.grid() + means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]] + rects1 = ax.bar(ind, means_dvj, width, color='r') - fig_name = os.path.join(options.results_folder, "matching_perc_%s.png" % test_seconds[sec]) - fig.savefig(fig_name) + # add some + ax.set_ylabel('Matching Percentage') + ax.set_title(f'{test_seconds[sec]} Matching Percentage') + ax.set_xticks(ind + width) -for sec in range(0, n_secs): - ind = np.arange(2) # - width = 0.25 # the width of the bars + labels = ['yes', 'no', 'invalid'] + ax.set_xticklabels(labels) - fig = plt.figure() - ax = fig.add_subplot(111) - ax.set_xlim([-1*width, 1.75]) + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) + autolabeldoubles(rects1, ax) + plt.grid() - div = all_match_counter[sec][0][0] - if div == 0 : - div = 1000000 + fig_name = join(results_folder, f"matching_perc_{test_seconds[sec]}.png") + fig.savefig(fig_name) - means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]] - rects1 = ax.bar(ind, means_dvj, width, color='r') + for sec in range(0, n_secs): + ind = np.arange(2) + width = 0.25 # the width of the bars - # add some - ax.set_ylabel('Matching Accuracy') - ax.set_title('%s Matching Times Accuracy' % test_seconds[sec]) - ax.set_xticks(ind + width) + fig = plt.figure() + ax = fig.add_subplot(111) + ax.set_xlim([-1 * width, 1.75]) - labels = ['yes','no'] - ax.set_xticklabels( labels ) + div = all_match_counter[sec][0][0] 
+ if div == 0: + div = 1000000 - box = ax.get_position() - ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) + means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]] + rects1 = ax.bar(ind, means_dvj, width, color='r') - #ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5)) - autolabeldoubles(rects1,ax) + # add some + ax.set_ylabel('Matching Accuracy') + ax.set_title(f'{test_seconds[sec]} Matching Times Accuracy') + ax.set_xticks(ind + width) - plt.grid() + labels = ['yes', 'no'] + ax.set_xticklabels(labels) - fig_name = os.path.join(options.results_folder, "matching_acc_%s.png" % test_seconds[sec]) - fig.savefig(fig_name) + box = ax.get_position() + ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) + autolabeldoubles(rects1, ax) -# remove temporary folder -shutil.rmtree(options.temp_folder) + plt.grid() + + fig_name = join(results_folder, f"matching_acc_{test_seconds[sec]}.png") + fig.savefig(fig_name) + + # remove temporary folder + rmtree(temp_folder) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description=f'Runs a few tests for dejavu to evaluate ' + f'its configuration performance. 
' + f'Usage: %(prog).py [options] TESTING_AUDIOFOLDER' + ) + + parser.add_argument("-sec", "--seconds", action="store", default=5, type=int, + help='Number of seconds starting from zero to test.') + parser.add_argument("-res", "--results-folder", action="store", default="./dejavu_test_results", + help='Sets the path where the results are saved.') + parser.add_argument("-temp", "--temp-folder", action="store", default="./dejavu_temp_testing_files", + help='Sets the path where the temp files are saved.') + parser.add_argument("-l", "--log", action="store_true", default=False, help='Enables logging.') + parser.add_argument("-sl", "--silent", action="store_false", default=False, help='Disables printing.') + parser.add_argument("-lf", "--log-file", default="results-compare.log", + help='Set the path and filename of the log file.') + parser.add_argument("-pad", "--padding", action="store", default=10, type=int, + help='Number of seconds to pad choice of place to test from.') + parser.add_argument("-sd", "--seed", action="store", default=None, type=int, help='Random seed.') + parser.add_argument("src", type=str, help='Source folder for audios to use as tests.') + + args = parser.parse_args() + + main(args.seconds, args.results_folder, args.temp_folder, args.log, args.silent, args.log_file, args.padding, + args.seed, args.src) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..5f32207 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 120 + diff --git a/setup.py b/setup.py index 8312d1d..d198e23 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ -from setuptools import setup, find_packages -# import os, sys +from setuptools import find_packages, setup def parse_requirements(requirements): @@ -7,26 +6,27 @@ def parse_requirements(requirements): with open(requirements) as f: lines = [l for l in f] # remove spaces - stripped = map((lambda x: x.strip()), lines) + stripped = list(map((lambda x: x.strip()), lines)) # remove comments 
- nocomments = filter((lambda x: not x.startswith('#')), stripped) + nocomments = list(filter((lambda x: not x.startswith('#')), stripped)) # remove empty lines - reqs = filter((lambda x: x), nocomments) + reqs = list(filter((lambda x: x), nocomments)) return reqs + PACKAGE_NAME = "PyDejavu" PACKAGE_VERSION = "0.1.3" SUMMARY = 'Dejavu: Audio Fingerprinting in Python' DESCRIPTION = """ Audio fingerprinting and recognition algorithm implemented in Python -See the explanation here: +See the explanation here: `http://willdrevo.com/fingerprinting-and-audio-recognition-with-python/`__ -Dejavu can memorize recorded audio by listening to it once and fingerprinting -it. Then by playing a song and recording microphone input or on disk file, -Dejavu attempts to match the audio against the fingerprints held in the +Dejavu can memorize recorded audio by listening to it once and fingerprinting +it. Then by playing a song and recording microphone input or on disk file, +Dejavu attempts to match the audio against the fingerprints held in the database, returning the song or recording being played. __ http://willdrevo.com/fingerprinting-and-audio-recognition-with-python/ diff --git a/test/sean_secs.wav b/test/sean_secs.wav new file mode 100644 index 0000000..6f30d72 Binary files /dev/null and b/test/sean_secs.wav differ diff --git a/test/woodward_43s.wav b/test/woodward_43s.wav new file mode 100644 index 0000000..906f426 Binary files /dev/null and b/test/woodward_43s.wav differ