From 27bf6b380bf9b5f0f1710efda0c98517ca1181ef Mon Sep 17 00:00:00 2001 From: mrepetto Date: Fri, 11 Oct 2019 19:49:38 -0300 Subject: [PATCH] More changes: - added even more docstring to the solution. - changed maximum filter mask on fingerprints (now configurable) --- dejavu/__init__.py | 7 ++ dejavu/base_classes/base_database.py | 2 + dejavu/base_classes/common_database.py | 1 + dejavu/config/settings.py | 12 +++- dejavu/logic/decoder.py | 3 + dejavu/logic/fingerprint.py | 96 +++++++++++++++++++------- example_script.py | 8 --- 7 files changed, 93 insertions(+), 36 deletions(-) diff --git a/dejavu/__init__.py b/dejavu/__init__.py index b450452..5173824 100755 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -50,6 +50,7 @@ class Dejavu: def get_fingerprinted_songs(self) -> List[Dict[str, any]]: """ To pull all fingerprinted songs from the database. + :return: a list of fingerprinted audios from the database. """ return self.db.get_songs() @@ -57,6 +58,7 @@ class Dejavu: def delete_songs_by_id(self, song_ids: List[int]) -> None: """ Deletes all audios given their ids. + :param song_ids: song ids to delete from the database. """ self.db.delete_songs_by_id(song_ids) @@ -64,6 +66,7 @@ class Dejavu: def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None: """ Given a directory and a set of extensions it fingerprints all files that match each extension specified. + :param path: path to the directory. :param extensions: list of file extensions to consider. :param nprocesses: amount of processes to fingerprint the files within the directory. @@ -119,6 +122,7 @@ class Dejavu: """ Given a path to a file the method generates hashes for it and stores them in the database for later being queried. + :param file_path: path to the file. :param song_name: song name associated to the audio file. """ @@ -143,6 +147,7 @@ class Dejavu: def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]: f""" Generate the fingerprints for the given sample data (channel). + :param samples: list of ints which represents the channel info of the given audio file. :param Fs: sampling rate which defaults to {DEFAULT_FS}. :return: a list of tuples for hash and its corresponding offset, together with the generation time. @@ -155,6 +160,7 @@ class Dejavu: def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]: """ Finds the corresponding matches on the fingerprinted audios for the given hashes. + :param hashes: list of tuples for hashes and their corresponding offsets :return: a tuple containing the matches found against the db, a dictionary which counts the different hashes matched for each song (with the song id as key), and the time that the query took. @@ -171,6 +177,7 @@ class Dejavu: """ Finds hash matches that align in time with other matches and finds consensus about which hashes are "true" signal from the audio. + :param matches: matches from the database :param dedup_hashes: dictionary containing the hashes matched without duplicates for each song (key is the song id). diff --git a/dejavu/base_classes/base_database.py b/dejavu/base_classes/base_database.py index 1aef252..839a72a 100755 --- a/dejavu/base_classes/base_database.py +++ b/dejavu/base_classes/base_database.py @@ -172,6 +172,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta): def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None: """ Given a list of song ids it deletes all songs specified and their corresponding fingerprints. + :param song_ids: song ids to be deleted from the database. :param batch_size: number of query's batches. """ @@ -181,6 +182,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta): def get_database(database_type: str = "mysql") -> BaseDatabase: """ Given a database type it returns a database instance for that type. + :param database_type: type of the database. :return: an instance of BaseDatabase depending on given database_type. """ diff --git a/dejavu/base_classes/common_database.py b/dejavu/base_classes/common_database.py index 542869d..517cda7 100644 --- a/dejavu/base_classes/common_database.py +++ b/dejavu/base_classes/common_database.py @@ -220,6 +220,7 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta): def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None: """ Given a list of song ids it deletes all songs specified and their corresponding fingerprints. + :param song_ids: song ids to be deleted from the database. :param batch_size: number of query's batches. """ diff --git a/dejavu/config/settings.py b/dejavu/config/settings.py index 58f294d..0e20569 100644 --- a/dejavu/config/settings.py +++ b/dejavu/config/settings.py @@ -48,6 +48,14 @@ FIELD_HASH = 'hash' FIELD_OFFSET = 'offset' # FINGERPRINTS CONFIG: +# This is used as connectivity parameter for scipy.generate_binary_structure function. This parameter +# changes the morphology mask when looking for maximum peaks on the spectrogram matrix. +# Possible values are: [1, 2] +# Where 1 sets a diamond morphology which implies that diagonal elements are not considered as neighbors (this +# is the value used in the original dejavu code). +# And 2 sets a square mask, i.e. all elements are considered neighbors. +CONNECTIVITY_MASK = 2 + # Sampling rate, related to the Nyquist conditions, which affects # the range frequencies we can detect. DEFAULT_FS = 44100 @@ -60,8 +68,8 @@ DEFAULT_WINDOW_SIZE = 4096 # matching, but potentially more fingerprints. DEFAULT_OVERLAP_RATIO = 0.5 -# Degree to which a fingerprint can be paired with its neighbors -- -# higher will cause more fingerprints, but potentially better accuracy. +# Degree to which a fingerprint can be paired with its neighbors. Higher values will +# cause more fingerprints, but potentially better accuracy. DEFAULT_FAN_VALUE = 5 # 15 was the original value. # Minimum amplitude in spectrogram in order to be considered a peak. diff --git a/dejavu/logic/decoder.py b/dejavu/logic/decoder.py index 615cba0..ccafa26 100755 --- a/dejavu/logic/decoder.py +++ b/dejavu/logic/decoder.py @@ -34,6 +34,7 @@ def unique_hash(file_path: str, block_size: int = 2**20) -> str: def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]: """ Get all files that meet the specified extensions. + :param path: path to a directory with audio files. :param extensions: file extensions to look for. :return: a list of tuples with file name and its extension. @@ -97,6 +98,8 @@ def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]: def get_audio_name_from_path(file_path: str) -> str: """ Extracts song name from a file path. + :param file_path: path to an audio file. + :return: file name """ return os.path.splitext(os.path.basename(file_path))[0] diff --git a/dejavu/logic/fingerprint.py b/dejavu/logic/fingerprint.py index c2e09a5..c3089aa 100755 --- a/dejavu/logic/fingerprint.py +++ b/dejavu/logic/fingerprint.py @@ -1,5 +1,6 @@ import hashlib from operator import itemgetter +from typing import List, Tuple import matplotlib.mlab as mlab import matplotlib.pyplot as plt @@ -9,24 +10,30 @@ from scipy.ndimage.morphology import (binary_erosion, generate_binary_structure, iterate_structure) -from dejavu.config.settings import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE, - DEFAULT_FS, DEFAULT_OVERLAP_RATIO, - DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION, - MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA, +from dejavu.config.settings import (CONNECTIVITY_MASK, DEFAULT_AMP_MIN, + DEFAULT_FAN_VALUE, DEFAULT_FS, + DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE, + FINGERPRINT_REDUCTION, MAX_HASH_TIME_DELTA, + MIN_HASH_TIME_DELTA, PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT) -IDX_FREQ_I = 0 -IDX_TIME_J = 1 - -def fingerprint(channel_samples, - Fs=DEFAULT_FS, - wsize=DEFAULT_WINDOW_SIZE, - wratio=DEFAULT_OVERLAP_RATIO, - fan_value=DEFAULT_FAN_VALUE, - amp_min=DEFAULT_AMP_MIN): +def fingerprint(channel_samples: List[int], + Fs: int = DEFAULT_FS, + wsize: int = DEFAULT_WINDOW_SIZE, + wratio: float = DEFAULT_OVERLAP_RATIO, + fan_value: int = DEFAULT_FAN_VALUE, + amp_min: int = DEFAULT_AMP_MIN) -> List[Tuple[str, int]]: """ FFT the channel, log transform output, find local maxima, then return locally sensitive hashes. + + :param channel_samples: channel samples to fingerprint. + :param Fs: audio sampling rate. + :param wsize: FFT windows size. + :param wratio: ratio by which each sequential window overlaps the last and the next window. + :param fan_value: degree to which a fingerprint can be paired with its neighbors. + :param amp_min: minimum amplitude in spectrogram in order to be considered a peak. + :return: a list of hashes with their corresponding offsets. """ # FFT the signal and extract frequency components arr2D = mlab.specgram( @@ -36,7 +43,7 @@ def fingerprint(channel_samples, window=mlab.window_hanning, noverlap=int(wsize * wratio))[0] - # Apply log transform since specgram() returns linear array. 0s are excluded to avoid np warning. + # Apply log transform since specgram function returns linear array. 0s are excluded to avoid np warning. arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0)) local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min) @@ -45,18 +52,45 @@ def fingerprint(channel_samples, return generate_hashes(local_maxima, fan_value=fan_value) -def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN): - # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html#scipy.ndimage.iterate_structure - struct = generate_binary_structure(2, 1) +def get_2D_peaks(arr2D: np.array, plot: bool = False, amp_min: int = DEFAULT_AMP_MIN)\ + -> List[Tuple[List[int], List[int]]]: + """ + Extract maximum peaks from the spectogram matrix (arr2D). + + :param arr2D: matrix representing the spectogram. + :param plot: for plotting the results. + :param amp_min: minimum amplitude in spectrogram in order to be considered a peak. + :return: a list composed by a list of frequencies and times. + """ + # Original code from the repo is using a morphology mask that does not consider diagonal elements + # as neighbors (basically a diamond figure) and then applies a dilation over it, so what I'm proposing + # is to change from the current diamond figure to a just a normal square one: + # F T F T T T + # T T T ==> T T T + # F T F T T T + # In my local tests time performance of the square mask was ~3 times faster + # respect to the diamond one, without hurting accuracy of the predictions. + # I've made now the mask shape configurable in order to allow both ways of find maximum peaks. + # That being said, we generate the mask by using the following function + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.generate_binary_structure.html + struct = generate_binary_structure(2, CONNECTIVITY_MASK) + + # And then we apply dilation using the following function + # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html + # Take into account that if PEAK_NEIGHBORHOOD_SIZE is 2 you can avoid the use of the scipy functions and just + # change it by the following code: + # neighborhood = np.ones((PEAK_NEIGHBORHOOD_SIZE * 2 + 1, PEAK_NEIGHBORHOOD_SIZE * 2 + 1), dtype=bool) neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) - # find local maxima using our filter shape + # find local maxima using our filter mask local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D + + # Applying erosion, the dejavu documentation does not talk about this step. background = (arr2D == 0) eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) - # Boolean mask of arr2D with True at peaks (Fixed deprecated boolean operator by changing '-' to '^') - detected_peaks = local_max ^ eroded_background + # Boolean mask of arr2D with True at peaks (applying XOR on both matrices). + detected_peaks = local_max != eroded_background # extract peaks amps = arr2D[detected_peaks] @@ -64,6 +98,7 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN): # filter peaks amps = amps.flatten() + # get indices for frequency and time filter_idxs = np.where(amps > amp_min) @@ -84,12 +119,21 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN): return list(zip(freqs_filter, times_filter)) -def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): +def generate_hashes(peaks: List[Tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE) -> List[Tuple[str, int]]: """ Hash list structure: sha1_hash[0:FINGERPRINT_REDUCTION] time_offset - [(e05b341a9b77a51fd26, 32), ... ] + [(e05b341a9b77a51fd26, 32), ... ] + + :param peaks: list of peak frequencies and times. + :param fan_value: degree to which a fingerprint can be paired with its neighbors. + :return: a list of hashes with their corresponding offsets. """ + # frequencies are in the first position of the tuples + idx_freq = 0 + # times are in the second position of the tuples + idx_time = 1 + if PEAK_SORT: peaks.sort(key=itemgetter(1)) @@ -98,10 +142,10 @@ def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): for j in range(1, fan_value): if (i + j) < len(peaks): - freq1 = peaks[i][IDX_FREQ_I] - freq2 = peaks[i + j][IDX_FREQ_I] - t1 = peaks[i][IDX_TIME_J] - t2 = peaks[i + j][IDX_TIME_J] + freq1 = peaks[i][idx_freq] + freq2 = peaks[i + j][idx_freq] + t1 = peaks[i][idx_time] + t2 = peaks[i + j][idx_time] t_delta = t2 - t1 if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA: diff --git a/example_script.py b/example_script.py index f29adec..19cdf15 100755 --- a/example_script.py +++ b/example_script.py @@ -32,11 +32,3 @@ if __name__ == '__main__': recognizer = FileRecognizer(djv) results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") print(f"No shortcut, we recognized: {results}\n") - - # To list all fingerprinted songs in the db you can use the following: - # fingerprinted_songs = djv.get_fingerprinted_songs() - # print(fingerprinted_songs) - - # And to delete a song or a set of songs you can use the following: - # song_ids_to_delete = [1] - # djv.delete_songs_by_ids(song_ids_to_delete)