diff --git a/dejavu.py b/dejavu.py index 84ea42a..fe8837c 100755 --- a/dejavu.py +++ b/dejavu.py @@ -1,12 +1,12 @@ import argparse import json -from os.path import isdir import sys from argparse import RawTextHelpFormatter +from os.path import isdir from dejavu import Dejavu -from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer from dejavu.logic.recognizer.file_recognizer import FileRecognizer +from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer DEFAULT_CONFIG_FILE = "dejavu.cnf.SAMPLE" @@ -41,7 +41,7 @@ if __name__ == '__main__': '--fingerprint /path/to/directory') parser.add_argument('-r', '--recognize', nargs=2, help='Recognize what is ' - 'playing through the microphone\n' + 'playing through the microphone or in a file.\n' 'Usage: \n' '--recognize mic number_of_seconds \n' '--recognize file path/to/file \n') diff --git a/dejavu/__init__.py b/dejavu/__init__.py index ade6232..563a415 100755 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -2,13 +2,19 @@ import multiprocessing import os import sys import traceback +from itertools import groupby +from time import time +from typing import Dict, List, Tuple import dejavu.logic.decoder as decoder from dejavu.base_classes.base_database import get_database -from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS, - DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE, - FIELD_FILE_SHA1, OFFSET, OFFSET_SECS, - SONG_ID, SONG_NAME, TOPN) +from dejavu.config.settings import (DEFAULT_FS, DEFAULT_OVERLAP_RATIO, + DEFAULT_WINDOW_SIZE, FIELD_FILE_SHA1, + FIELD_TOTAL_HASHES, + FINGERPRINTED_CONFIDENCE, + FINGERPRINTED_HASHES, HASHES_MATCHED, + INPUT_CONFIDENCE, INPUT_HASHES, OFFSET, + OFFSET_SECS, SONG_ID, SONG_NAME, TOPN) from dejavu.logic.fingerprint import fingerprint @@ -27,9 +33,13 @@ class Dejavu: self.limit = self.config.get("fingerprint_limit", None) if self.limit == -1: # for JSON compatibility self.limit = None - self.get_fingerprinted_songs() + self.__load_fingerprinted_audio_hashes() - def get_fingerprinted_songs(self): + def __load_fingerprinted_audio_hashes(self) -> None: + """ + Keeps a dictionary with the hashes of the fingerprinted songs, in that way is possible to check + whether or not an audio file was already processed. + """ # get songs previously indexed self.songs = self.db.get_songs() self.songhashes_set = set() # to know which ones we've computed before @@ -37,7 +47,27 @@ class Dejavu: song_hash = song[FIELD_FILE_SHA1] self.songhashes_set.add(song_hash) - def fingerprint_directory(self, path, extensions, nprocesses=None): + def get_fingerprinted_songs(self) -> List[Dict[str, any]]: + """ + To pull all fingerprinted songs from the database. + :return: a list of fingerprinted audios from the database. + """ + return self.db.get_songs() + + def delete_songs_by_ids(self, song_ids: List[int]) -> None: + """ + Deletes all audios given their ids. + :param song_ids: song ids to delete from the database. + """ + self.db.delete_songs_by_ids(song_ids) + + def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None: + """ + Given a directory and a set of extensions it fingerprints all files that match each extension specified. + :param path: path to the directory. + :param extensions: list of file extensions to consider. + :param nprocesses: amount of processes to fingerprint the files within the directory. + """ # Try to use the maximum amount of processes if not given. try: nprocesses = nprocesses or multiprocessing.cpu_count() @@ -61,7 +91,7 @@ class Dejavu: worker_input = list(zip(filenames_to_fingerprint, [self.limit] * len(filenames_to_fingerprint))) # Send off our tasks - iterator = pool.imap_unordered(_fingerprint_worker, worker_input) + iterator = pool.imap_unordered(Dejavu._fingerprint_worker, worker_input) # Loop till we have all of them while True: @@ -76,25 +106,31 @@ class Dejavu: # Print traceback because we can't reraise it here traceback.print_exc(file=sys.stdout) else: - sid = self.db.insert_song(song_name, file_hash) + sid = self.db.insert_song(song_name, file_hash, len(hashes)) self.db.insert_hashes(sid, hashes) self.db.set_song_fingerprinted(sid) - self.get_fingerprinted_songs() + self.__load_fingerprinted_audio_hashes() pool.close() pool.join() - def fingerprint_file(self, filepath, song_name=None): - songname = decoder.path_to_songname(filepath) - song_hash = decoder.unique_hash(filepath) - song_name = song_name or songname + def fingerprint_file(self, file_path: str, song_name: str = None) -> None: + """ + Given a path to a file the method generates hashes for it and stores them in the database + for later being queried. + :param file_path: path to the file. + :param song_name: song name associated to the audio file. + """ + song_name_from_path = decoder.get_audio_name_from_path(file_path) + song_hash = decoder.unique_hash(file_path) + song_name = song_name or song_name_from_path # don't refingerprint already fingerprinted files if song_hash in self.songhashes_set: print(f"{song_name} already fingerprinted, continuing...") else: - song_name, hashes, file_hash = _fingerprint_worker( - filepath, + song_name, hashes, file_hash = Dejavu._fingerprint_worker( + file_path, self.limit, song_name=song_name ) @@ -102,117 +138,115 @@ class Dejavu: self.db.insert_hashes(sid, hashes) self.db.set_song_fingerprinted(sid) - self.get_fingerprinted_songs() + self.__load_fingerprinted_audio_hashes() - def find_matches(self, samples, Fs=DEFAULT_FS): + def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]: + f""" + Generate the fingerprints for the given sample data (channel). + :param samples: list of ints which represents the channel info of the given audio file. + :param Fs: sampling rate which defaults to {DEFAULT_FS}. + :return: a list of tuples for hash and its corresponding offset, together with the generation time. + """ + t = time() hashes = fingerprint(samples, Fs=Fs) - return self.db.return_matches(hashes) + fingerprint_time = time() - t + return hashes, fingerprint_time - def align_matches(self, matches, topn=TOPN): + def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]: """ - Finds hash matches that align in time with other matches and finds - consensus about which hashes are "true" signal from the audio. + Finds the corresponding matches on the fingerprinted audios for the given hashes. + :param hashes: list of tuples for hashes and their corresponding offsets + :return: a tuple containing the matches found against the db, a dictionary which counts the different + hashes matched for each song (with the song id as key), and the time that the query took. - Returns a list of dictionaries (based on topn) with match information. """ - # align by diffs - diff_counter = {} - largest_count = 0 + t = time() + matches, dedup_hashes = self.db.return_matches(hashes) + query_time = time() - t - # TODO: review logic to get topn results. - for tup in matches: - sid, diff = tup - if diff not in diff_counter: - diff_counter[diff] = {} - if sid not in diff_counter[diff]: - diff_counter[diff][sid] = 0 - diff_counter[diff][sid] += 1 + return matches, dedup_hashes, query_time - if diff_counter[diff][sid] > largest_count: - largest_count = diff_counter[diff][sid] + def align_matches(self, matches: List[Tuple[int, int]], dedup_hashes: Dict[str, int], queried_hashes: int, + topn: int = TOPN) -> List[Dict[str, any]]: + """ + Finds hash matches that align in time with other matches and finds + consensus about which hashes are "true" signal from the audio. + :param matches: matches from the database + :param dedup_hashes: dictionary containing the hashes matched without duplicates for each song + (key is the song id). + :param queried_hashes: amount of hashes sent for matching against the db + :param topn: number of results being returned back. + :return: a list of dictionaries (based on topn) with match information. + """ + # count offset occurrences per song and keep only the maximum ones. + sorted_matches = sorted(matches, key=lambda m: (m[0], m[1])) + counts = [(*key, len(list(group))) for key, group in groupby(sorted_matches, key=lambda m: (m[0], m[1]))] + songs_matches = sorted( + [max(list(group), key=lambda g: g[2]) for key, group in groupby(counts, key=lambda count: count[0])], + key=lambda count: count[2], reverse=True + ) - # create dic where key are songs ids - songs_num_matches = {} - for dc in diff_counter: - for sid in diff_counter[dc]: - match_val = diff_counter[dc][sid] - if (sid not in songs_num_matches) or (match_val > songs_num_matches[sid]['value']): - songs_num_matches[sid] = { - 'sid': sid, - 'value': match_val, - 'largest': dc - } - - # use dicc of songs to create an ordered (descending) list using the match value property assigned to each song - songs_num_matches_list = [] - for s in songs_num_matches: - songs_num_matches_list.append({ - 'sid': s, - 'object': songs_num_matches[s] - }) - - songs_num_matches_list_ordered = sorted(songs_num_matches_list, key=lambda x: x['object']['value'], - reverse=True) - - # iterate the ordered list and fill results songs_result = [] - for s in songs_num_matches_list_ordered: - - # get expected variable by the original code - song_id = s['object']['sid'] - largest = s['object']['largest'] - largest_count = s['object']['value'] - - # extract identification + for song_id, offset, _ in songs_matches[0:topn]: # consider topn elements in the result song = self.db.get_song_by_id(song_id) - if song: - # TODO: Clarify what `get_song_by_id` should return. - songname = song.get(SONG_NAME, None) - # return match info - nseconds = round(float(largest) / DEFAULT_FS * - DEFAULT_WINDOW_SIZE * - DEFAULT_OVERLAP_RATIO, 5) - song = { - SONG_ID: song_id, - SONG_NAME: songname.encode("utf8"), - CONFIDENCE: largest_count, - OFFSET: int(largest), - OFFSET_SECS: nseconds, - FIELD_FILE_SHA1: song.get(FIELD_FILE_SHA1, None).encode("utf8") - } + song_name = song.get(SONG_NAME, None) + song_hashes = song.get(FIELD_TOTAL_HASHES, None) + nseconds = round(float(offset) / DEFAULT_FS * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO, 5) + hashes_matched = dedup_hashes[song_id] - songs_result.append(song) + song = { + SONG_ID: song_id, + SONG_NAME: song_name.encode("utf8"), + INPUT_HASHES: queried_hashes, + FINGERPRINTED_HASHES: song_hashes, + HASHES_MATCHED: hashes_matched, + # Percentage regarding hashes matched vs hashes from the input. + INPUT_CONFIDENCE: round(hashes_matched / queried_hashes, 2), + # Percentage regarding hashes matched vs hashes fingerprinted in the db. + FINGERPRINTED_CONFIDENCE: round(hashes_matched / song_hashes, 2), + OFFSET: offset, + OFFSET_SECS: nseconds, + FIELD_FILE_SHA1: song.get(FIELD_FILE_SHA1, None).encode("utf8") + } + + songs_result.append(song) - # only consider up to topn elements in the result - if len(songs_result) > topn: - break return songs_result - def recognize(self, recognizer, *options, **kwoptions): + def recognize(self, recognizer, *options, **kwoptions) -> Dict[str, any]: r = recognizer(self) return r.recognize(*options, **kwoptions) + @staticmethod + def _fingerprint_worker(arguments): + # Pool.imap sends arguments as tuples so we have to unpack + # them ourself. + try: + file_name, limit = arguments + except ValueError: + pass -def _fingerprint_worker(filename, limit=None, song_name=None): - # Pool.imap sends arguments as tuples so we have to unpack - # them ourself. - try: - filename, limit = filename - except ValueError: - pass + song_name, extension = os.path.splitext(os.path.basename(file_name)) - songname, extension = os.path.splitext(os.path.basename(filename)) - song_name = song_name or songname - channels, fs, file_hash = decoder.read(filename, limit) - result = set() - channel_amount = len(channels) + fingerprints, file_hash = Dejavu.get_file_fingerprints(file_name, limit, print_output=True) - for channeln, channel in enumerate(channels): - # TODO: Remove prints or change them into optional logging. - print(f"Fingerprinting channel {channeln + 1}/{channel_amount} for {filename}") - hashes = fingerprint(channel, Fs=fs) - print(f"Finished channel {channeln + 1}/{channel_amount} for {filename}") - result |= set(hashes) + return song_name, fingerprints, file_hash - return song_name, result, file_hash + @staticmethod + def get_file_fingerprints(file_name: str, limit: int, print_output: bool = False): + channels, fs, file_hash = decoder.read(file_name, limit) + fingerprints = set() + channel_amount = len(channels) + for channeln, channel in enumerate(channels, start=1): + if print_output: + print(f"Fingerprinting channel {channeln}/{channel_amount} for {file_name}") + + hashes = fingerprint(channel, Fs=fs) + + if print_output: + print(f"Finished channel {channeln}/{channel_amount} for {file_name}") + + fingerprints |= set(hashes) + + return fingerprints, file_hash diff --git a/dejavu/base_classes/base_database.py b/dejavu/base_classes/base_database.py index 77b118b..4728566 100755 --- a/dejavu/base_classes/base_database.py +++ b/dejavu/base_classes/base_database.py @@ -1,6 +1,6 @@ import abc import importlib -from typing import Dict +from typing import Dict, List, Tuple from dejavu.config.settings import DATABASES @@ -13,13 +13,13 @@ class BaseDatabase(object, metaclass=abc.ABCMeta): def __init__(self): super().__init__() - def before_fork(self): + def before_fork(self) -> None: """ Called before the database instance is given to the new process """ pass - def after_fork(self): + def after_fork(self) -> None: """ Called after the database instance has been given to the new process @@ -27,21 +27,21 @@ class BaseDatabase(object, metaclass=abc.ABCMeta): """ pass - def setup(self): + def setup(self) -> None: """ Called on creation or shortly afterwards. """ pass @abc.abstractmethod - def empty(self): + def empty(self) -> None: """ Called when the database should be cleared of all data. """ pass @abc.abstractmethod - def delete_unfingerprinted_songs(self): + def delete_unfingerprinted_songs(self) -> None: """ Called to remove any song entries that do not have any fingerprints associated with them. @@ -49,110 +49,141 @@ class BaseDatabase(object, metaclass=abc.ABCMeta): pass @abc.abstractmethod - def get_num_songs(self): + def get_num_songs(self) -> int: """ - Returns the amount of songs in the database. + Returns the song's count stored. + + :return: the amount of songs in the database. """ pass @abc.abstractmethod - def get_num_fingerprints(self): + def get_num_fingerprints(self) -> int: """ - Returns the number of fingerprints in the database. + Returns the fingerprints' count stored. + + :return: the number of fingerprints in the database. """ pass @abc.abstractmethod - def set_song_fingerprinted(self, sid): + def set_song_fingerprinted(self, song_id: int): """ Sets a specific song as having all fingerprints in the database. - sid: Song identifier + :param song_id: song identifier. """ pass @abc.abstractmethod - def get_songs(self) -> Dict[str, str]: + def get_songs(self) -> List[Dict[str, str]]: """ - Returns all fully fingerprinted songs in the database. Result must be a Dictionary. + Returns all fully fingerprinted songs in the database + + :return: a dictionary with the songs info. """ pass @abc.abstractmethod - def get_song_by_id(self, sid) -> Dict[str, str]: + def get_song_by_id(self, song_id: int) -> Dict[str, str]: """ - Return a song by its identifier. Result must be a Dictionary. - sid: Song identifier + Brings the song info from the database. + + :param song_id: song identifier. + :return: a song by its identifier. Result must be a Dictionary. """ pass @abc.abstractmethod - def insert(self, hash, sid, offset): + def insert(self, fingerprint: str, song_id: int, offset: int): """ Inserts a single fingerprint into the database. - hash: Part of a sha1 hash, in hexadecimal format - sid: Song identifier this fingerprint is off - offset: The offset this hash is from + :param fingerprint: Part of a sha1 hash, in hexadecimal format + :param song_id: Song identifier this fingerprint is off + :param offset: The offset this fingerprint is from. """ pass @abc.abstractmethod - def insert_song(self, song_name): + def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int: """ Inserts a song name into the database, returns the new identifier of the song. - song_name: The name of the song. + :param song_name: The name of the song. + :param file_hash: Hash from the fingerprinted file. + :param total_hashes: amount of hashes to be inserted on fingerprint table. + :return: the inserted id. """ pass @abc.abstractmethod - def query(self, hash): + def query(self, fingerprint: str = None) -> List[Tuple]: """ Returns all matching fingerprint entries associated with - the given hash as parameter. + the given hash as parameter, if None is passed it returns all entries. - hash: Part of a sha1 hash, in hexadecimal format + :param fingerprint: part of a sha1 hash, in hexadecimal format + :return: a list of fingerprint records stored in the db. """ pass @abc.abstractmethod - def get_iterable_kv_pairs(self): + def get_iterable_kv_pairs(self) -> List[Tuple]: """ Returns all fingerprints in the database. + + :return: a list containing all fingerprints stored in the db. """ pass @abc.abstractmethod - def insert_hashes(self, sid, hashes, batch=1000): + def insert_hashes(self, song_id: int, hashes: List[Tuple[str, int]], batch_size: int = 1000) -> None: """ Insert a multitude of fingerprints. - sid: Song identifier the fingerprints belong to - hashes: A sequence of tuples in the format (hash, offset) - - hash: Part of a sha1 hash, in hexadecimal format - - offset: Offset this hash was created from/at. + :param song_id: Song identifier the fingerprints belong to + :param hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + :param batch_size: insert batches. """ @abc.abstractmethod - def return_matches(self, hashes): + def return_matches(self, hashes: List[Tuple[str, int]], batch_size: int = 1000) \ + -> Tuple[List[Tuple[int, int]], Dict[int, int]]: """ Searches the database for pairs of (hash, offset) values. - hashes: A sequence of tuples in the format (hash, offset) - - hash: Part of a sha1 hash, in hexadecimal format - - offset: Offset this hash was created from/at. + :param hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + :param batch_size: number of query's batches. + :return: a list of (sid, offset_difference) tuples and a + dictionary with the amount of hashes matched (not considering + duplicated hashes) in each song. + - song id: Song identifier + - offset_difference: (database_offset - sampled_offset) + """ + pass - Returns a sequence of (sid, offset_difference) tuples. - - sid: Song identifier - offset_difference: (offset - database_offset) + @abc.abstractmethod + def delete_songs_by_ids(self, song_ids: List[int], batch_size: int = 1000) -> None: + """ + Given a list of song ids it deletes all songs specified and their corresponding fingerprints. + :param song_ids: song ids to be deleted from the database. + :param batch_size: number of query's batches. """ pass -def get_database(database_type="mysql"): +def get_database(database_type: str = "mysql") -> BaseDatabase: + """ + Given a database type it returns a database instance for that type. + :param database_type: type of the database. + :return: an instance of BaseDatabase depending on given database_type. + """ try: path, db_class_name = DATABASES[database_type] db_module = importlib.import_module(path) diff --git a/dejavu/base_classes/base_recognizer.py b/dejavu/base_classes/base_recognizer.py index cd07c01..c0f4749 100644 --- a/dejavu/base_classes/base_recognizer.py +++ b/dejavu/base_classes/base_recognizer.py @@ -1,4 +1,8 @@ import abc +from time import time +from typing import Dict, List, Tuple + +import numpy as np from dejavu.config.settings import DEFAULT_FS @@ -8,12 +12,22 @@ class BaseRecognizer(object, metaclass=abc.ABCMeta): self.dejavu = dejavu self.Fs = DEFAULT_FS - def _recognize(self, *data): - matches = [] - for d in data: - matches.extend(self.dejavu.find_matches(d, Fs=self.Fs)) - return self.dejavu.align_matches(matches) + def _recognize(self, *data) -> Tuple[List[Dict[str, any]], int, int, int]: + fingerprint_times = [] + hashes = set() # to remove possible duplicated fingerprints we built a set. + for channel in data: + fingerprints, fingerprint_time = self.dejavu.generate_fingerprints(channel, Fs=self.Fs) + fingerprint_times.append(fingerprint_time) + hashes |= set(fingerprints) + + matches, dedup_hashes, query_time = self.dejavu.find_matches(hashes) + + t = time() + final_results = self.dejavu.align_matches(matches, dedup_hashes, len(hashes)) + align_time = time() - t + + return final_results, np.sum(fingerprint_times), query_time, align_time @abc.abstractmethod - def recognize(self): + def recognize(self) -> Dict[str, any]: pass # base class does nothing diff --git a/dejavu/base_classes/common_database.py b/dejavu/base_classes/common_database.py index 8b6f2c1..ec8d560 100644 --- a/dejavu/base_classes/common_database.py +++ b/dejavu/base_classes/common_database.py @@ -1,4 +1,5 @@ import abc +from typing import Dict, List, Tuple from dejavu.base_classes.base_database import BaseDatabase @@ -11,13 +12,13 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): def __init__(self): super().__init__() - def before_fork(self): + def before_fork(self) -> None: """ Called before the database instance is given to the new process """ pass - def after_fork(self): + def after_fork(self) -> None: """ Called after the database instance has been given to the new process @@ -25,7 +26,7 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): """ pass - def setup(self): + def setup(self) -> None: """ Called on creation or shortly afterwards. """ @@ -34,7 +35,7 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): cur.execute(self.CREATE_FINGERPRINTS_TABLE) cur.execute(self.DELETE_UNFINGERPRINTED) - def empty(self): + def empty(self) -> None: """ Called when the database should be cleared of all data. """ @@ -44,7 +45,7 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): self.setup() - def delete_unfingerprinted_songs(self): + def delete_unfingerprinted_songs(self) -> None: """ Called to remove any song entries that do not have any fingerprints associated with them. @@ -52,9 +53,11 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): with self.cursor() as cur: cur.execute(self.DELETE_UNFINGERPRINTED) - def get_num_songs(self): + def get_num_songs(self) -> int: """ - Returns the amount of songs in the database. + Returns the song's count stored. + + :return: the amount of songs in the database. """ with self.cursor() as cur: cur.execute(self.SELECT_UNIQUE_SONG_IDS) @@ -62,9 +65,11 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): return count - def get_num_fingerprints(self): + def get_num_fingerprints(self) -> int: """ - Returns the number of fingerprints in the database. + Returns the fingerprints' count stored. + + :return: the number of fingerprints in the database. """ with self.cursor() as cur: cur.execute(self.SELECT_NUM_FINGERPRINTS) @@ -72,122 +77,155 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): return count - def set_song_fingerprinted(self, sid): + def set_song_fingerprinted(self, song_id): """ Sets a specific song as having all fingerprints in the database. - sid: Song identifier + :param song_id: song identifier. """ with self.cursor() as cur: - cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,)) + cur.execute(self.UPDATE_SONG_FINGERPRINTED, (song_id,)) - def get_songs(self): + def get_songs(self) -> List[Dict[str, str]]: """ - Returns all fully fingerprinted songs in the database. Result must be a Dictionary. + Returns all fully fingerprinted songs in the database + + :return: a dictionary with the songs info. """ with self.cursor(dictionary=True) as cur: cur.execute(self.SELECT_SONGS) - for row in cur: - yield row + return list(cur) - def get_song_by_id(self, sid): + def get_song_by_id(self, song_id: int) -> Dict[str, str]: """ - Return a song by its identifier. Result must be a Dictionary. - sid: Song identifier + Brings the song info from the database. + + :param song_id: song identifier. + :return: a song by its identifier. Result must be a Dictionary. """ with self.cursor(dictionary=True) as cur: - cur.execute(self.SELECT_SONG, (sid,)) + cur.execute(self.SELECT_SONG, (song_id,)) return cur.fetchone() - def insert(self, fingerprint, sid, offset): + def insert(self, fingerprint: str, song_id: int, offset: int): """ Inserts a single fingerprint into the database. - fingerprint: Part of a sha1 hash, in hexadecimal format - sid: Song identifier this fingerprint is off - offset: The offset this fingerprint is from + :param fingerprint: Part of a sha1 hash, in hexadecimal format + :param song_id: Song identifier this fingerprint is off + :param offset: The offset this fingerprint is from. """ with self.cursor() as cur: - cur.execute(self.INSERT_FINGERPRINT, (fingerprint, sid, offset)) + cur.execute(self.INSERT_FINGERPRINT, (fingerprint, song_id, offset)) @abc.abstractmethod - def insert_song(self, song_name): + def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int: """ Inserts a song name into the database, returns the new identifier of the song. - song_name: The name of the song. + :param song_name: The name of the song. + :param file_hash: Hash from the fingerprinted file. + :param total_hashes: amount of hashes to be inserted on fingerprint table. + :return: the inserted id. """ pass - def query(self, fingerprint): + def query(self, fingerprint: str = None) -> List[Tuple]: """ Returns all matching fingerprint entries associated with - the given fingerprint as parameter. + the given hash as parameter, if None is passed it returns all entries. - fingerprint: Part of a sha1 hash, in hexadecimal format + :param fingerprint: part of a sha1 hash, in hexadecimal format + :return: a list of fingerprint records stored in the db. """ - if fingerprint: - with self.cursor() as cur: + with self.cursor() as cur: + if fingerprint: cur.execute(self.SELECT, (fingerprint,)) - for sid, offset in cur: - yield (sid, offset) - else: # select all if no key - with self.cursor() as cur: + else: # select all if no key cur.execute(self.SELECT_ALL) - for sid, offset in cur: - yield (sid, offset) + return list(cur) - def get_iterable_kv_pairs(self): + def get_iterable_kv_pairs(self) -> List[Tuple]: """ Returns all fingerprints in the database. + + :return: a list containing all fingerprints stored in the db. """ return self.query(None) - def insert_hashes(self, sid, hashes, batch=1000): + def insert_hashes(self, song_id: int, hashes: List[Tuple[str, int]], batch_size: int = 1000) -> None: """ Insert a multitude of fingerprints. - sid: Song identifier the fingerprints belong to - hashes: A sequence of tuples in the format (hash, offset) - - hash: Part of a sha1 hash, in hexadecimal format - - offset: Offset this hash was created from/at. + :param song_id: Song identifier the fingerprints belong to + :param hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + :param batch_size: insert batches. """ - values = [(sid, hsh, int(offset)) for hsh, offset in hashes] + values = [(song_id, hsh, int(offset)) for hsh, offset in hashes] with self.cursor() as cur: - for index in range(0, len(hashes), batch): - cur.executemany(self.INSERT_FINGERPRINT, values[index: index + batch]) + for index in range(0, len(hashes), batch_size): + cur.executemany(self.INSERT_FINGERPRINT, values[index: index + batch_size]) - def return_matches(self, hashes, batch=1000): + def return_matches(self, hashes: List[Tuple[str, int]], + batch_size: int = 1000) -> Tuple[List[Tuple[int, int]], Dict[int, int]]: """ Searches the database for pairs of (hash, offset) values. - hashes: A sequence of tuples in the format (hash, offset) - - hash: Part of a sha1 hash, in hexadecimal format - - offset: Offset this hash was created from/at. - - Returns a sequence of (sid, offset_difference) tuples. - - sid: Song identifier - offset_difference: (offset - database_offset) + :param hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + :param batch_size: number of query's batches. + :return: a list of (sid, offset_difference) tuples and a + dictionary with the amount of hashes matched (not considering + duplicated hashes) in each song. + - song id: Song identifier + - offset_difference: (database_offset - sampled_offset) """ # Create a dictionary of hash => offset pairs for later lookups mapper = {} for hsh, offset in hashes: - mapper[hsh.upper()] = offset + if hsh.upper() in mapper.keys(): + mapper[hsh.upper()].append(offset) + else: + mapper[hsh.upper()] = [offset] - # Get an iterable of all the hashes we need values = list(mapper.keys()) - with self.cursor() as cur: - for index in range(0, len(values), batch): - # Create our IN part of the query - query = self.SELECT_MULTIPLE - query = query % ', '.join([self.IN_MATCH] * len(values[index: index + batch])) + # in order to count each hash only once per db offset we use the dic below + dedup_hashes = {} - cur.execute(query, values[index: index + batch]) + results = [] + with self.cursor() as cur: + for index in range(0, len(values), batch_size): + # Create our IN part of the query + query = self.SELECT_MULTIPLE % ', '.join([self.IN_MATCH] * len(values[index: index + batch_size])) + + cur.execute(query, values[index: index + batch_size]) for hsh, sid, offset in cur: - # (sid, db_offset - song_sampled_offset) - yield (sid, offset - mapper[hsh]) + if sid not in dedup_hashes.keys(): + dedup_hashes[sid] = 1 + else: + dedup_hashes[sid] += 1 + # we now evaluate all offset for each hash matched + for song_sampled_offset in mapper[hsh]: + results.append((sid, offset - song_sampled_offset)) + + return results, dedup_hashes + + def delete_songs_by_ids(self, song_ids: List[int], batch_size: int = 1000) -> None: + """ + Given a list of song ids it deletes all songs specified and their corresponding fingerprints. + :param song_ids: song ids to be deleted from the database. + :param batch_size: number of query's batches. + """ + with self.cursor() as cur: + for index in range(0, len(song_ids), batch_size): + # Create our IN part of the query + query = self.DELETE_SONGS % ', '.join(['%s'] * len(song_ids[index: index + batch_size])) + + cur.execute(query, song_ids[index: index + batch_size]) diff --git a/dejavu/config/settings.py b/dejavu/config/settings.py index bc711b7..58f294d 100644 --- a/dejavu/config/settings.py +++ b/dejavu/config/settings.py @@ -1,8 +1,26 @@ # Dejavu + +# DEJAVU JSON RESPONSE SONG_ID = "song_id" SONG_NAME = 'song_name' -CONFIDENCE = 'confidence' -MATCH_TIME = 'match_time' +RESULTS = 'results' + +HASHES_MATCHED = 'hashes_matched_in_input' + +# Hashes fingerprinted in the db. +FINGERPRINTED_HASHES = 'fingerprinted_hashes_in_db' +# Percentage regarding hashes matched vs hashes fingerprinted in the db. +FINGERPRINTED_CONFIDENCE = 'fingerprinted_confidence' + +# Hashes generated from the input. +INPUT_HASHES = 'input_total_hashes' +# Percentage regarding hashes matched vs hashes from the input. +INPUT_CONFIDENCE = 'input_confidence' + +TOTAL_TIME = 'total_time' +FINGERPRINT_TIME = 'fingerprint_time' +QUERY_TIME = 'query_time' +ALIGN_TIME = 'align_time' OFFSET = 'offset' OFFSET_SECS = 'offset_seconds' @@ -20,6 +38,7 @@ FIELD_SONG_ID = 'song_id' FIELD_SONGNAME = 'song_name' FIELD_FINGERPRINTED = "fingerprinted" FIELD_FILE_SHA1 = 'file_sha1' +FIELD_TOTAL_HASHES = 'total_hashes' # TABLE FINGERPRINTS FINGERPRINTS_TABLENAME = "fingerprints" @@ -43,7 +62,7 @@ DEFAULT_OVERLAP_RATIO = 0.5 # Degree to which a fingerprint can be paired with its neighbors -- # higher will cause more fingerprints, but potentially better accuracy. -DEFAULT_FAN_VALUE = 15 +DEFAULT_FAN_VALUE = 5 # 15 was the original value. # Minimum amplitude in spectrogram in order to be considered a peak. # This can be raised to reduce number of fingerprints, but can negatively @@ -53,7 +72,7 @@ DEFAULT_AMP_MIN = 10 # Number of cells around an amplitude peak in the spectrogram in order # for Dejavu to consider it a spectral peak. Higher values mean less # fingerprints and faster matching, but can potentially affect accuracy. -PEAK_NEIGHBORHOOD_SIZE = 20 +PEAK_NEIGHBORHOOD_SIZE = 10 # 20 was the original value. # Thresholds on how close or far fingerprints can be in time in order # to be paired as a fingerprint. If your max is too low, higher values of diff --git a/dejavu/database_handler/mysql_database.py b/dejavu/database_handler/mysql_database.py index e1e4257..1a8c506 100755 --- a/dejavu/database_handler/mysql_database.py +++ b/dejavu/database_handler/mysql_database.py @@ -6,8 +6,8 @@ from mysql.connector.errors import DatabaseError from dejavu.base_classes.common_database import CommonDatabase from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, - FIELD_SONGNAME, FINGERPRINTS_TABLENAME, - SONGS_TABLENAME) + FIELD_SONGNAME, FIELD_TOTAL_HASHES, + FINGERPRINTS_TABLENAME, SONGS_TABLENAME) class MySQLDatabase(CommonDatabase): @@ -20,6 +20,7 @@ class MySQLDatabase(CommonDatabase): , `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL , `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0 , `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL + , `{FIELD_TOTAL_HASHES}` INT NOT NULL DEFAULT 0 , `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP , `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP , CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`) @@ -52,8 +53,8 @@ class MySQLDatabase(CommonDatabase): """ INSERT_SONG = f""" - INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`) - VALUES (%s, UNHEX(%s)); + INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`,`{FIELD_TOTAL_HASHES}`) + VALUES (%s, UNHEX(%s), %s); """ # SELECTS @@ -72,7 +73,7 @@ class MySQLDatabase(CommonDatabase): SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;" SELECT_SONG = f""" - SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}` + SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`, `{FIELD_TOTAL_HASHES}` FROM `{SONGS_TABLENAME}` WHERE `{FIELD_SONG_ID}` = %s; """ @@ -90,6 +91,8 @@ class MySQLDatabase(CommonDatabase): `{FIELD_SONG_ID}` , `{FIELD_SONGNAME}` , HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}` + , `{FIELD_TOTAL_HASHES}` + , `date_created` FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 1; """ @@ -98,16 +101,20 @@ class MySQLDatabase(CommonDatabase): DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;" DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;" - # update + # UPDATE UPDATE_SONG_FINGERPRINTED = f""" UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s; """ - # DELETE + # DELETES DELETE_UNFINGERPRINTED = f""" DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0; """ + DELETE_SONGS = f""" + DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_SONG_ID}` IN (%s); + """ + # IN IN_MATCH = f"UNHEX(%s)" @@ -116,17 +123,23 @@ class MySQLDatabase(CommonDatabase): self.cursor = cursor_factory(**options) self._options = options - def after_fork(self): + def after_fork(self) -> None: # Clear the cursor cache, we don't want any stale connections from # the previous process. Cursor.clear_cache() - def insert_song(self, song_name, file_hash): + def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int: """ - Inserts song in the database and returns the ID of the inserted record. + Inserts a song name into the database, returns the new + identifier of the song. + + :param song_name: The name of the song. + :param file_hash: Hash from the fingerprinted file. + :param total_hashes: amount of hashes to be inserted on fingerprint table. + :return: the inserted id. """ with self.cursor() as cur: - cur.execute(self.INSERT_SONG, (song_name, file_hash)) + cur.execute(self.INSERT_SONG, (song_name, file_hash, total_hashes)) return cur.lastrowid def __getstate__(self): diff --git a/dejavu/database_handler/postgres_database.py b/dejavu/database_handler/postgres_database.py index 341641b..4ac7131 100755 --- a/dejavu/database_handler/postgres_database.py +++ b/dejavu/database_handler/postgres_database.py @@ -6,8 +6,8 @@ from psycopg2.extras import DictCursor from dejavu.base_classes.common_database import CommonDatabase from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED, FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID, - FIELD_SONGNAME, FINGERPRINTS_TABLENAME, - SONGS_TABLENAME) + FIELD_SONGNAME, FIELD_TOTAL_HASHES, + FINGERPRINTS_TABLENAME, SONGS_TABLENAME) class PostgreSQLDatabase(CommonDatabase): @@ -20,6 +20,7 @@ class PostgreSQLDatabase(CommonDatabase): , "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL , "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0 , "{FIELD_FILE_SHA1}" BYTEA + , "{FIELD_TOTAL_HASHES}" INT NOT NULL DEFAULT 0 , "date_created" TIMESTAMP NOT NULL DEFAULT now() , "date_modified" TIMESTAMP NOT NULL DEFAULT now() , CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}") @@ -58,8 +59,8 @@ class PostgreSQLDatabase(CommonDatabase): """ INSERT_SONG = f""" - INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}") - VALUES (%s, decode(%s, 'hex')) + INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}","{FIELD_TOTAL_HASHES}") + VALUES (%s, decode(%s, 'hex'), %s) RETURNING "{FIELD_SONG_ID}"; """ @@ -79,7 +80,10 @@ class PostgreSQLDatabase(CommonDatabase): SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";' SELECT_SONG = f""" - SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" + SELECT + "{FIELD_SONGNAME}" + , upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" + , "{FIELD_TOTAL_HASHES}" FROM "{SONGS_TABLENAME}" WHERE "{FIELD_SONG_ID}" = %s; """ @@ -97,6 +101,8 @@ class PostgreSQLDatabase(CommonDatabase): "{FIELD_SONG_ID}" , "{FIELD_SONGNAME}" , upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}" + , "{FIELD_TOTAL_HASHES}" + , "date_created" FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 1; """ @@ -113,11 +119,15 @@ class PostgreSQLDatabase(CommonDatabase): WHERE "{FIELD_SONG_ID}" = %s; """ - # DELETE + # DELETES DELETE_UNFINGERPRINTED = f""" DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0; """ + DELETE_SONGS = f""" + DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_SONG_ID}" IN (%s); + """ + # IN IN_MATCH = f"decode(%s, 'hex')" @@ -126,17 +136,23 @@ class PostgreSQLDatabase(CommonDatabase): self.cursor = cursor_factory(**options) self._options = options - def after_fork(self): + def after_fork(self) -> None: # Clear the cursor cache, we don't want any stale connections from # the previous process. Cursor.clear_cache() - def insert_song(self, song_name, file_hash): + def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int: """ - Inserts song in the database and returns the ID of the inserted record. + Inserts a song name into the database, returns the new + identifier of the song. + + :param song_name: The name of the song. + :param file_hash: Hash from the fingerprinted file. + :param total_hashes: amount of hashes to be inserted on fingerprint table. + :return: the inserted id. """ with self.cursor() as cur: - cur.execute(self.INSERT_SONG, (song_name, file_hash)) + cur.execute(self.INSERT_SONG, (song_name, file_hash, total_hashes)) return cur.fetchone()[0] def __getstate__(self): diff --git a/dejavu/logic/decoder.py b/dejavu/logic/decoder.py index 245d8b1..615cba0 100755 --- a/dejavu/logic/decoder.py +++ b/dejavu/logic/decoder.py @@ -1,6 +1,7 @@ import fnmatch import os from hashlib import sha1 +from typing import List, Tuple import numpy as np from pydub import AudioSegment @@ -9,35 +10,47 @@ from pydub.utils import audioop from dejavu.third_party import wavio -def unique_hash(filepath, blocksize=2**20): +def unique_hash(file_path: str, block_size: int = 2**20) -> str: """ Small function to generate a hash to uniquely generate a file. Inspired by MD5 version here: http://stackoverflow.com/a/1131255/712997 Works with large files. + + :param file_path: path to file. + :param block_size: read block size. + :return: a hash in an hexagesimal string form. """ s = sha1() - with open(filepath, "rb") as f: + with open(file_path, "rb") as f: while True: - buf = f.read(blocksize) + buf = f.read(block_size) if not buf: break s.update(buf) return s.hexdigest().upper() -def find_files(path, extensions): +def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]: + """ + Get all files that meet the specified extensions. + :param path: path to a directory with audio files. + :param extensions: file extensions to look for. + :return: a list of tuples with file name and its extension. + """ # Allow both with ".mp3" and without "mp3" to be used for extensions extensions = [e.replace(".", "") for e in extensions] + results = [] for dirpath, dirnames, files in os.walk(path): for extension in extensions: for f in fnmatch.filter(files, f"*.{extension}"): p = os.path.join(dirpath, f) - yield (p, extension) + results.append((p, extension)) + return results -def read(filename, limit=None): +def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]: """ Reads any file supported by pydub (ffmpeg) and returns the data contained within. If file reading fails due to input being a 24-bit wav file, @@ -47,11 +60,13 @@ def read(filename, limit=None): of the file by specifying the `limit` parameter. This is the amount of seconds from the start of the file. - returns: (channels, samplerate) + :param file_name: file to be read. + :param limit: number of seconds to limit. + :return: tuple list of (channels, sample_rate, content_file_hash). """ # pydub does not support 24-bit wav files, use wavio when this occurs try: - audiofile = AudioSegment.from_file(filename) + audiofile = AudioSegment.from_file(file_name) if limit: audiofile = audiofile[:limit * 1000] @@ -64,7 +79,7 @@ def read(filename, limit=None): audiofile.frame_rate except audioop.error: - _, _, audiofile = wavio.readwav(filename) + _, _, audiofile = wavio.readwav(file_name) if limit: audiofile = audiofile[:limit * 1000] @@ -76,12 +91,12 @@ def read(filename, limit=None): for chn in audiofile: channels.append(chn) - return channels, audiofile.frame_rate, unique_hash(filename) + return channels, audiofile.frame_rate, unique_hash(file_name) -def path_to_songname(path): +def get_audio_name_from_path(file_path: str) -> str: """ - Extracts song name from a filepath. Used to identify which songs - have already been fingerprinted on disk. + Extracts song name from a file path. + :param file_path: path to an audio file. """ - return os.path.splitext(os.path.basename(path))[0] + return os.path.splitext(os.path.basename(file_path))[0] diff --git a/dejavu/logic/fingerprint.py b/dejavu/logic/fingerprint.py index 46729f6..c2e09a5 100755 --- a/dejavu/logic/fingerprint.py +++ b/dejavu/logic/fingerprint.py @@ -93,6 +93,7 @@ def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): if PEAK_SORT: peaks.sort(key=itemgetter(1)) + hashes = [] for i in range(len(peaks)): for j in range(1, fan_value): if (i + j) < len(peaks): @@ -105,4 +106,7 @@ def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA: h = hashlib.sha1(f"{str(freq1)}|{str(freq2)}|{str(t_delta)}".encode('utf-8')) - yield (h.hexdigest()[0:FINGERPRINT_REDUCTION], t1) + + hashes.append((h.hexdigest()[0:FINGERPRINT_REDUCTION], t1)) + + return hashes diff --git a/dejavu/logic/recognizer/file_recognizer.py b/dejavu/logic/recognizer/file_recognizer.py index 8c019be..ded9650 100644 --- a/dejavu/logic/recognizer/file_recognizer.py +++ b/dejavu/logic/recognizer/file_recognizer.py @@ -1,24 +1,32 @@ -import time +from time import time +from typing import Dict import dejavu.logic.decoder as decoder from dejavu.base_classes.base_recognizer import BaseRecognizer +from dejavu.config.settings import (ALIGN_TIME, FINGERPRINT_TIME, QUERY_TIME, + RESULTS, TOTAL_TIME) class FileRecognizer(BaseRecognizer): def __init__(self, dejavu): super().__init__(dejavu) - def recognize_file(self, filename): - frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit) + def recognize_file(self, filename: str) -> Dict[str, any]: + channels, self.Fs, _ = decoder.read(filename, self.dejavu.limit) - t = time.time() - matches = self._recognize(*frames) - t = time.time() - t + t = time() + matches, fingerprint_time, query_time, align_time = self._recognize(*channels) + t = time() - t - for match in matches: - match['match_time'] = t + results = { + TOTAL_TIME: t, + FINGERPRINT_TIME: fingerprint_time, + QUERY_TIME: query_time, + ALIGN_TIME: align_time, + RESULTS: matches + } - return matches + return results - def recognize(self, filename): + def recognize(self, filename: str) -> Dict[str, any]: return self.recognize_file(filename) diff --git a/dejavu/tests/dejavu_test.py b/dejavu/tests/dejavu_test.py index 3a301aa..ca9e52c 100644 --- a/dejavu/tests/dejavu_test.py +++ b/dejavu/tests/dejavu_test.py @@ -12,10 +12,10 @@ import matplotlib.pyplot as plt import numpy as np from pydub import AudioSegment -from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS, - DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE, - MATCH_TIME, OFFSET, SONG_NAME) -from dejavu.logic.decoder import path_to_songname +from dejavu.config.settings import (DEFAULT_FS, DEFAULT_OVERLAP_RATIO, + DEFAULT_WINDOW_SIZE, HASHES_MATCHED, + OFFSET, RESULTS, SONG_NAME, TOTAL_TIME) +from dejavu.logic.decoder import get_audio_name_from_path class DejavuTest: @@ -115,8 +115,8 @@ class DejavuTest: col = self.get_column_id([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds][0]) # format: XXXX_offset_length.mp3, we also take into account underscores within XXXX - splits = path_to_songname(f).split("_") - song = "_".join(splits[0:len(path_to_songname(f).split("_")) - 2]) + splits = get_audio_name_from_path(f).split("_") + song = "_".join(splits[0:len(get_audio_name_from_path(f).split("_")) - 2]) line = self.get_line_id(song) result = subprocess.check_output([ "python", @@ -138,8 +138,8 @@ class DejavuTest: result = json.loads(result.decode('utf-8').replace("'", '"').replace(': b"', ':"')) # which song did we predict? We consider only the first match. - result = result[0] - song_result = result[SONG_NAME] + match = result[RESULTS][0] + song_result = match[SONG_NAME] log_msg(f'song: {song}') log_msg(f'song_result: {song_result}') @@ -153,22 +153,22 @@ class DejavuTest: log_msg('correct match') print(self.result_match) self.result_match[line][col] = 'yes' - self.result_query_duration[line][col] = round(result[MATCH_TIME], 3) - self.result_match_confidence[line][col] = result[CONFIDENCE] + self.result_query_duration[line][col] = round(result[TOTAL_TIME], 3) + self.result_match_confidence[line][col] = match[HASHES_MATCHED] # using replace in f for getting rid of underscores in name song_start_time = re.findall("_[^_]+", f.replace(song, "")) song_start_time = song_start_time[0].lstrip("_ ") - result_start_time = round((result[OFFSET] * DEFAULT_WINDOW_SIZE * + result_start_time = round((match[OFFSET] * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0) self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time) if abs(self.result_matching_times[line][col]) == 1: self.result_matching_times[line][col] = 0 - log_msg(f'query duration: {round(result[MATCH_TIME], 3)}') - log_msg(f'confidence: {result[CONFIDENCE]}') + log_msg(f'query duration: {round(result[TOTAL_TIME], 3)}') + log_msg(f'confidence: {match[HASHES_MATCHED]}') log_msg(f'song start_time: {song_start_time}') log_msg(f'result start time: {result_start_time}') diff --git a/example_script.py b/example_script.py index 8f90552..f29adec 100755 --- a/example_script.py +++ b/example_script.py @@ -1,8 +1,8 @@ import json from dejavu import Dejavu -from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer from dejavu.logic.recognizer.file_recognizer import FileRecognizer +from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer # load config from a JSON file (or anything outputting a python dictionary) with open("dejavu.cnf.SAMPLE") as f: @@ -17,18 +17,26 @@ if __name__ == '__main__': djv.fingerprint_directory("test", [".wav"]) # Recognize audio from a file - song = djv.recognize(FileRecognizer, "mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") - print(f"From file we recognized: {song}\n") + results = djv.recognize(FileRecognizer, "mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") + print(f"From file we recognized: {results}\n") # Or recognize audio from your microphone for `secs` seconds secs = 5 - song = djv.recognize(MicrophoneRecognizer, seconds=secs) - if song is None: + results = djv.recognize(MicrophoneRecognizer, seconds=secs) + if results is None: print("Nothing recognized -- did you play the song out loud so your mic could hear it? :)") else: - print(f"From mic with {secs} seconds we recognized: {song}\n") + print(f"From mic with {secs} seconds we recognized: {results}\n") # Or use a recognizer without the shortcut, in anyway you would like recognizer = FileRecognizer(djv) - song = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") - print(f"No shortcut, we recognized: {song}\n") + results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") + print(f"No shortcut, we recognized: {results}\n") + + # To list all fingerprinted songs in the db you can use the following: + # fingerprinted_songs = djv.get_fingerprinted_songs() + # print(fingerprinted_songs) + + # And to delete a song or a set of songs you can use the following: + # song_ids_to_delete = [1] + # djv.delete_songs_by_ids(song_ids_to_delete) diff --git a/requirements.txt b/requirements.txt index 1295bd1..19cc2b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,3 @@ scipy==1.3.1 matplotlib==3.1.1 mysql-connector-python==8.0.17 psycopg2==2.8.3 - diff --git a/setup.py b/setup.py index c94e0fb..d198e23 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup def parse_requirements(requirements):