More changes:

- added even more docstring to the solution.
 - changed maximum filter mask on fingerprints (now configurable)
This commit is contained in:
mrepetto 2019-10-11 19:49:38 -03:00
parent 9a1c71b349
commit 27bf6b380b
7 changed files with 93 additions and 36 deletions

View file

@ -50,6 +50,7 @@ class Dejavu:
def get_fingerprinted_songs(self) -> List[Dict[str, any]]: def get_fingerprinted_songs(self) -> List[Dict[str, any]]:
""" """
To pull all fingerprinted songs from the database. To pull all fingerprinted songs from the database.
:return: a list of fingerprinted audios from the database. :return: a list of fingerprinted audios from the database.
""" """
return self.db.get_songs() return self.db.get_songs()
@ -57,6 +58,7 @@ class Dejavu:
def delete_songs_by_id(self, song_ids: List[int]) -> None: def delete_songs_by_id(self, song_ids: List[int]) -> None:
""" """
Deletes all audios given their ids. Deletes all audios given their ids.
:param song_ids: song ids to delete from the database. :param song_ids: song ids to delete from the database.
""" """
self.db.delete_songs_by_id(song_ids) self.db.delete_songs_by_id(song_ids)
@ -64,6 +66,7 @@ class Dejavu:
def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None: def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None:
""" """
Given a directory and a set of extensions it fingerprints all files that match each extension specified. Given a directory and a set of extensions it fingerprints all files that match each extension specified.
:param path: path to the directory. :param path: path to the directory.
:param extensions: list of file extensions to consider. :param extensions: list of file extensions to consider.
:param nprocesses: amount of processes to fingerprint the files within the directory. :param nprocesses: amount of processes to fingerprint the files within the directory.
@ -119,6 +122,7 @@ class Dejavu:
""" """
Given a path to a file the method generates hashes for it and stores them in the database Given a path to a file the method generates hashes for it and stores them in the database
for later being queried. for later being queried.
:param file_path: path to the file. :param file_path: path to the file.
:param song_name: song name associated to the audio file. :param song_name: song name associated to the audio file.
""" """
@ -143,6 +147,7 @@ class Dejavu:
def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]: def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]:
f""" f"""
Generate the fingerprints for the given sample data (channel). Generate the fingerprints for the given sample data (channel).
:param samples: list of ints which represents the channel info of the given audio file. :param samples: list of ints which represents the channel info of the given audio file.
:param Fs: sampling rate which defaults to {DEFAULT_FS}. :param Fs: sampling rate which defaults to {DEFAULT_FS}.
:return: a list of tuples for hash and its corresponding offset, together with the generation time. :return: a list of tuples for hash and its corresponding offset, together with the generation time.
@ -155,6 +160,7 @@ class Dejavu:
def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]: def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]:
""" """
Finds the corresponding matches on the fingerprinted audios for the given hashes. Finds the corresponding matches on the fingerprinted audios for the given hashes.
:param hashes: list of tuples for hashes and their corresponding offsets :param hashes: list of tuples for hashes and their corresponding offsets
:return: a tuple containing the matches found against the db, a dictionary which counts the different :return: a tuple containing the matches found against the db, a dictionary which counts the different
hashes matched for each song (with the song id as key), and the time that the query took. hashes matched for each song (with the song id as key), and the time that the query took.
@ -171,6 +177,7 @@ class Dejavu:
""" """
Finds hash matches that align in time with other matches and finds Finds hash matches that align in time with other matches and finds
consensus about which hashes are "true" signal from the audio. consensus about which hashes are "true" signal from the audio.
:param matches: matches from the database :param matches: matches from the database
:param dedup_hashes: dictionary containing the hashes matched without duplicates for each song :param dedup_hashes: dictionary containing the hashes matched without duplicates for each song
(key is the song id). (key is the song id).

View file

@ -172,6 +172,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta):
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None: def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
""" """
Given a list of song ids it deletes all songs specified and their corresponding fingerprints. Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
:param song_ids: song ids to be deleted from the database. :param song_ids: song ids to be deleted from the database.
:param batch_size: number of query's batches. :param batch_size: number of query's batches.
""" """
@ -181,6 +182,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta):
def get_database(database_type: str = "mysql") -> BaseDatabase: def get_database(database_type: str = "mysql") -> BaseDatabase:
""" """
Given a database type it returns a database instance for that type. Given a database type it returns a database instance for that type.
:param database_type: type of the database. :param database_type: type of the database.
:return: an instance of BaseDatabase depending on given database_type. :return: an instance of BaseDatabase depending on given database_type.
""" """

View file

@ -220,6 +220,7 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta):
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None: def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
""" """
Given a list of song ids it deletes all songs specified and their corresponding fingerprints. Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
:param song_ids: song ids to be deleted from the database. :param song_ids: song ids to be deleted from the database.
:param batch_size: number of query's batches. :param batch_size: number of query's batches.
""" """

View file

@ -48,6 +48,14 @@ FIELD_HASH = 'hash'
FIELD_OFFSET = 'offset' FIELD_OFFSET = 'offset'
# FINGERPRINTS CONFIG: # FINGERPRINTS CONFIG:
# This is used as connectivity parameter for scipy.generate_binary_structure function. This parameter
# changes the morphology mask when looking for maximum peaks on the spectrogram matrix.
# Possible values are: [1, 2]
# Where 1 sets a diamond morphology which implies that diagonal elements are not considered as neighbors (this
# is the value used in the original dejavu code).
# And 2 sets a square mask, i.e. all elements are considered neighbors.
CONNECTIVITY_MASK = 2
# Sampling rate, related to the Nyquist conditions, which affects # Sampling rate, related to the Nyquist conditions, which affects
# the range frequencies we can detect. # the range frequencies we can detect.
DEFAULT_FS = 44100 DEFAULT_FS = 44100
@ -60,8 +68,8 @@ DEFAULT_WINDOW_SIZE = 4096
# matching, but potentially more fingerprints. # matching, but potentially more fingerprints.
DEFAULT_OVERLAP_RATIO = 0.5 DEFAULT_OVERLAP_RATIO = 0.5
# Degree to which a fingerprint can be paired with its neighbors -- # Degree to which a fingerprint can be paired with its neighbors. Higher values will
# higher will cause more fingerprints, but potentially better accuracy. # cause more fingerprints, but potentially better accuracy.
DEFAULT_FAN_VALUE = 5 # 15 was the original value. DEFAULT_FAN_VALUE = 5 # 15 was the original value.
# Minimum amplitude in spectrogram in order to be considered a peak. # Minimum amplitude in spectrogram in order to be considered a peak.

View file

@ -34,6 +34,7 @@ def unique_hash(file_path: str, block_size: int = 2**20) -> str:
def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]: def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]:
""" """
Get all files that meet the specified extensions. Get all files that meet the specified extensions.
:param path: path to a directory with audio files. :param path: path to a directory with audio files.
:param extensions: file extensions to look for. :param extensions: file extensions to look for.
:return: a list of tuples with file name and its extension. :return: a list of tuples with file name and its extension.
@ -97,6 +98,8 @@ def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
def get_audio_name_from_path(file_path: str) -> str: def get_audio_name_from_path(file_path: str) -> str:
""" """
Extracts song name from a file path. Extracts song name from a file path.
:param file_path: path to an audio file. :param file_path: path to an audio file.
:return: file name
""" """
return os.path.splitext(os.path.basename(file_path))[0] return os.path.splitext(os.path.basename(file_path))[0]

View file

@ -1,5 +1,6 @@
import hashlib import hashlib
from operator import itemgetter from operator import itemgetter
from typing import List, Tuple
import matplotlib.mlab as mlab import matplotlib.mlab as mlab
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -9,24 +10,30 @@ from scipy.ndimage.morphology import (binary_erosion,
generate_binary_structure, generate_binary_structure,
iterate_structure) iterate_structure)
from dejavu.config.settings import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE, from dejavu.config.settings import (CONNECTIVITY_MASK, DEFAULT_AMP_MIN,
DEFAULT_FS, DEFAULT_OVERLAP_RATIO, DEFAULT_FAN_VALUE, DEFAULT_FS,
DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION, DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA, FINGERPRINT_REDUCTION, MAX_HASH_TIME_DELTA,
MIN_HASH_TIME_DELTA,
PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT) PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT)
IDX_FREQ_I = 0
IDX_TIME_J = 1
def fingerprint(channel_samples: List[int],
def fingerprint(channel_samples, Fs: int = DEFAULT_FS,
Fs=DEFAULT_FS, wsize: int = DEFAULT_WINDOW_SIZE,
wsize=DEFAULT_WINDOW_SIZE, wratio: float = DEFAULT_OVERLAP_RATIO,
wratio=DEFAULT_OVERLAP_RATIO, fan_value: int = DEFAULT_FAN_VALUE,
fan_value=DEFAULT_FAN_VALUE, amp_min: int = DEFAULT_AMP_MIN) -> List[Tuple[str, int]]:
amp_min=DEFAULT_AMP_MIN):
""" """
FFT the channel, log transform output, find local maxima, then return locally sensitive hashes. FFT the channel, log transform output, find local maxima, then return locally sensitive hashes.
:param channel_samples: channel samples to fingerprint.
:param Fs: audio sampling rate.
:param wsize: FFT windows size.
:param wratio: ratio by which each sequential window overlaps the last and the next window.
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
:return: a list of hashes with their corresponding offsets.
""" """
# FFT the signal and extract frequency components # FFT the signal and extract frequency components
arr2D = mlab.specgram( arr2D = mlab.specgram(
@ -36,7 +43,7 @@ def fingerprint(channel_samples,
window=mlab.window_hanning, window=mlab.window_hanning,
noverlap=int(wsize * wratio))[0] noverlap=int(wsize * wratio))[0]
# Apply log transform since specgram() returns linear array. 0s are excluded to avoid np warning. # Apply log transform since specgram function returns linear array. 0s are excluded to avoid np warning.
arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0)) arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0))
local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min) local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
@ -45,18 +52,45 @@ def fingerprint(channel_samples,
return generate_hashes(local_maxima, fan_value=fan_value) return generate_hashes(local_maxima, fan_value=fan_value)
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN): def get_2D_peaks(arr2D: np.array, plot: bool = False, amp_min: int = DEFAULT_AMP_MIN)\
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html#scipy.ndimage.iterate_structure -> List[Tuple[List[int], List[int]]]:
struct = generate_binary_structure(2, 1) """
Extract maximum peaks from the spectogram matrix (arr2D).
:param arr2D: matrix representing the spectogram.
:param plot: for plotting the results.
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
:return: a list composed by a list of frequencies and times.
"""
# Original code from the repo is using a morphology mask that does not consider diagonal elements
# as neighbors (basically a diamond figure) and then applies a dilation over it, so what I'm proposing
# is to change from the current diamond figure to a just a normal square one:
# F T F T T T
# T T T ==> T T T
# F T F T T T
# In my local tests time performance of the square mask was ~3 times faster
# respect to the diamond one, without hurting accuracy of the predictions.
# I've made now the mask shape configurable in order to allow both ways of find maximum peaks.
# That being said, we generate the mask by using the following function
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.generate_binary_structure.html
struct = generate_binary_structure(2, CONNECTIVITY_MASK)
# And then we apply dilation using the following function
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html
# Take into account that if PEAK_NEIGHBORHOOD_SIZE is 2 you can avoid the use of the scipy functions and just
# change it by the following code:
# neighborhood = np.ones((PEAK_NEIGHBORHOOD_SIZE * 2 + 1, PEAK_NEIGHBORHOOD_SIZE * 2 + 1), dtype=bool)
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
# find local maxima using our filter shape # find local maxima using our filter mask
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
# Applying erosion, the dejavu documentation does not talk about this step.
background = (arr2D == 0) background = (arr2D == 0)
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
# Boolean mask of arr2D with True at peaks (Fixed deprecated boolean operator by changing '-' to '^') # Boolean mask of arr2D with True at peaks (applying XOR on both matrices).
detected_peaks = local_max ^ eroded_background detected_peaks = local_max != eroded_background
# extract peaks # extract peaks
amps = arr2D[detected_peaks] amps = arr2D[detected_peaks]
@ -64,6 +98,7 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
# filter peaks # filter peaks
amps = amps.flatten() amps = amps.flatten()
# get indices for frequency and time # get indices for frequency and time
filter_idxs = np.where(amps > amp_min) filter_idxs = np.where(amps > amp_min)
@ -84,12 +119,21 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
return list(zip(freqs_filter, times_filter)) return list(zip(freqs_filter, times_filter))
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): def generate_hashes(peaks: List[Tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE) -> List[Tuple[str, int]]:
""" """
Hash list structure: Hash list structure:
sha1_hash[0:FINGERPRINT_REDUCTION] time_offset sha1_hash[0:FINGERPRINT_REDUCTION] time_offset
[(e05b341a9b77a51fd26, 32), ... ] [(e05b341a9b77a51fd26, 32), ... ]
:param peaks: list of peak frequencies and times.
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
:return: a list of hashes with their corresponding offsets.
""" """
# frequencies are in the first position of the tuples
idx_freq = 0
# times are in the second position of the tuples
idx_time = 1
if PEAK_SORT: if PEAK_SORT:
peaks.sort(key=itemgetter(1)) peaks.sort(key=itemgetter(1))
@ -98,10 +142,10 @@ def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
for j in range(1, fan_value): for j in range(1, fan_value):
if (i + j) < len(peaks): if (i + j) < len(peaks):
freq1 = peaks[i][IDX_FREQ_I] freq1 = peaks[i][idx_freq]
freq2 = peaks[i + j][IDX_FREQ_I] freq2 = peaks[i + j][idx_freq]
t1 = peaks[i][IDX_TIME_J] t1 = peaks[i][idx_time]
t2 = peaks[i + j][IDX_TIME_J] t2 = peaks[i + j][idx_time]
t_delta = t2 - t1 t_delta = t2 - t1
if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA: if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:

View file

@ -32,11 +32,3 @@ if __name__ == '__main__':
recognizer = FileRecognizer(djv) recognizer = FileRecognizer(djv)
results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3") results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
print(f"No shortcut, we recognized: {results}\n") print(f"No shortcut, we recognized: {results}\n")
# To list all fingerprinted songs in the db you can use the following:
# fingerprinted_songs = djv.get_fingerprinted_songs()
# print(fingerprinted_songs)
# And to delete a song or a set of songs you can use the following:
# song_ids_to_delete = [1]
# djv.delete_songs_by_ids(song_ids_to_delete)