More changes:

- added even more docstring to the solution.
 - changed maximum filter mask on fingerprints (now configurable)
This commit is contained in:
mrepetto 2019-10-11 19:49:38 -03:00
parent 9a1c71b349
commit 27bf6b380b
7 changed files with 93 additions and 36 deletions

View file

@ -50,6 +50,7 @@ class Dejavu:
def get_fingerprinted_songs(self) -> List[Dict[str, any]]:
"""
To pull all fingerprinted songs from the database.
:return: a list of fingerprinted audios from the database.
"""
return self.db.get_songs()
@ -57,6 +58,7 @@ class Dejavu:
def delete_songs_by_id(self, song_ids: List[int]) -> None:
"""
Deletes all audios given their ids.
:param song_ids: song ids to delete from the database.
"""
self.db.delete_songs_by_id(song_ids)
@ -64,6 +66,7 @@ class Dejavu:
def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None:
"""
Given a directory and a set of extensions it fingerprints all files that match each extension specified.
:param path: path to the directory.
:param extensions: list of file extensions to consider.
:param nprocesses: amount of processes to fingerprint the files within the directory.
@ -119,6 +122,7 @@ class Dejavu:
"""
Given a path to a file the method generates hashes for it and stores them in the database
for later being queried.
:param file_path: path to the file.
:param song_name: song name associated to the audio file.
"""
@ -143,6 +147,7 @@ class Dejavu:
def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]:
f"""
Generate the fingerprints for the given sample data (channel).
:param samples: list of ints which represents the channel info of the given audio file.
:param Fs: sampling rate which defaults to {DEFAULT_FS}.
:return: a list of tuples for hash and its corresponding offset, together with the generation time.
@ -155,6 +160,7 @@ class Dejavu:
def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]:
"""
Finds the corresponding matches on the fingerprinted audios for the given hashes.
:param hashes: list of tuples for hashes and their corresponding offsets
:return: a tuple containing the matches found against the db, a dictionary which counts the different
hashes matched for each song (with the song id as key), and the time that the query took.
@ -171,6 +177,7 @@ class Dejavu:
"""
Finds hash matches that align in time with other matches and finds
consensus about which hashes are "true" signal from the audio.
:param matches: matches from the database
:param dedup_hashes: dictionary containing the hashes matched without duplicates for each song
(key is the song id).

View file

@ -172,6 +172,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta):
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
"""
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
:param song_ids: song ids to be deleted from the database.
:param batch_size: number of query's batches.
"""
@ -181,6 +182,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta):
def get_database(database_type: str = "mysql") -> BaseDatabase:
"""
Given a database type it returns a database instance for that type.
:param database_type: type of the database.
:return: an instance of BaseDatabase depending on given database_type.
"""

View file

@ -220,6 +220,7 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta):
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
"""
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
:param song_ids: song ids to be deleted from the database.
:param batch_size: number of query's batches.
"""

View file

@ -48,6 +48,14 @@ FIELD_HASH = 'hash'
FIELD_OFFSET = 'offset'
# FINGERPRINTS CONFIG:
# This is used as connectivity parameter for scipy.generate_binary_structure function. This parameter
# changes the morphology mask when looking for maximum peaks on the spectrogram matrix.
# Possible values are: [1, 2]
# Where 1 sets a diamond morphology which implies that diagonal elements are not considered as neighbors (this
# is the value used in the original dejavu code).
# And 2 sets a square mask, i.e. all elements are considered neighbors.
CONNECTIVITY_MASK = 2
# Sampling rate, related to the Nyquist conditions, which affects
# the range frequencies we can detect.
DEFAULT_FS = 44100
@ -60,8 +68,8 @@ DEFAULT_WINDOW_SIZE = 4096
# matching, but potentially more fingerprints.
DEFAULT_OVERLAP_RATIO = 0.5
# Degree to which a fingerprint can be paired with its neighbors --
# higher will cause more fingerprints, but potentially better accuracy.
# Degree to which a fingerprint can be paired with its neighbors. Higher values will
# cause more fingerprints, but potentially better accuracy.
DEFAULT_FAN_VALUE = 5 # 15 was the original value.
# Minimum amplitude in spectrogram in order to be considered a peak.

View file

@ -34,6 +34,7 @@ def unique_hash(file_path: str, block_size: int = 2**20) -> str:
def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]:
"""
Get all files that meet the specified extensions.
:param path: path to a directory with audio files.
:param extensions: file extensions to look for.
:return: a list of tuples with file name and its extension.
@ -97,6 +98,8 @@ def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
def get_audio_name_from_path(file_path: str) -> str:
"""
Extracts song name from a file path.
:param file_path: path to an audio file.
:return: file name
"""
return os.path.splitext(os.path.basename(file_path))[0]

View file

@ -1,5 +1,6 @@
import hashlib
from operator import itemgetter
from typing import List, Tuple
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
@ -9,24 +10,30 @@ from scipy.ndimage.morphology import (binary_erosion,
generate_binary_structure,
iterate_structure)
from dejavu.config.settings import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE,
DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION,
MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA,
from dejavu.config.settings import (CONNECTIVITY_MASK, DEFAULT_AMP_MIN,
DEFAULT_FAN_VALUE, DEFAULT_FS,
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
FINGERPRINT_REDUCTION, MAX_HASH_TIME_DELTA,
MIN_HASH_TIME_DELTA,
PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT)
IDX_FREQ_I = 0
IDX_TIME_J = 1
def fingerprint(channel_samples,
Fs=DEFAULT_FS,
wsize=DEFAULT_WINDOW_SIZE,
wratio=DEFAULT_OVERLAP_RATIO,
fan_value=DEFAULT_FAN_VALUE,
amp_min=DEFAULT_AMP_MIN):
def fingerprint(channel_samples: List[int],
Fs: int = DEFAULT_FS,
wsize: int = DEFAULT_WINDOW_SIZE,
wratio: float = DEFAULT_OVERLAP_RATIO,
fan_value: int = DEFAULT_FAN_VALUE,
amp_min: int = DEFAULT_AMP_MIN) -> List[Tuple[str, int]]:
"""
FFT the channel, log transform output, find local maxima, then return locally sensitive hashes.
:param channel_samples: channel samples to fingerprint.
:param Fs: audio sampling rate.
:param wsize: FFT windows size.
:param wratio: ratio by which each sequential window overlaps the last and the next window.
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
:return: a list of hashes with their corresponding offsets.
"""
# FFT the signal and extract frequency components
arr2D = mlab.specgram(
@ -36,7 +43,7 @@ def fingerprint(channel_samples,
window=mlab.window_hanning,
noverlap=int(wsize * wratio))[0]
# Apply log transform since specgram() returns linear array. 0s are excluded to avoid np warning.
# Apply log transform since specgram function returns linear array. 0s are excluded to avoid np warning.
arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0))
local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
@ -45,18 +52,45 @@ def fingerprint(channel_samples,
return generate_hashes(local_maxima, fan_value=fan_value)
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html#scipy.ndimage.iterate_structure
struct = generate_binary_structure(2, 1)
def get_2D_peaks(arr2D: np.array, plot: bool = False, amp_min: int = DEFAULT_AMP_MIN)\
-> List[Tuple[List[int], List[int]]]:
"""
Extract maximum peaks from the spectogram matrix (arr2D).
:param arr2D: matrix representing the spectogram.
:param plot: for plotting the results.
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
:return: a list composed by a list of frequencies and times.
"""
# Original code from the repo is using a morphology mask that does not consider diagonal elements
# as neighbors (basically a diamond figure) and then applies a dilation over it, so what I'm proposing
# is to change from the current diamond figure to a just a normal square one:
# F T F T T T
# T T T ==> T T T
# F T F T T T
# In my local tests time performance of the square mask was ~3 times faster
# respect to the diamond one, without hurting accuracy of the predictions.
# I've made now the mask shape configurable in order to allow both ways of find maximum peaks.
# That being said, we generate the mask by using the following function
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.generate_binary_structure.html
struct = generate_binary_structure(2, CONNECTIVITY_MASK)
# And then we apply dilation using the following function
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html
# Take into account that if PEAK_NEIGHBORHOOD_SIZE is 2 you can avoid the use of the scipy functions and just
# change it by the following code:
# neighborhood = np.ones((PEAK_NEIGHBORHOOD_SIZE * 2 + 1, PEAK_NEIGHBORHOOD_SIZE * 2 + 1), dtype=bool)
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
# find local maxima using our filter shape
# find local maxima using our filter mask
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
# Applying erosion, the dejavu documentation does not talk about this step.
background = (arr2D == 0)
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
# Boolean mask of arr2D with True at peaks (Fixed deprecated boolean operator by changing '-' to '^')
detected_peaks = local_max ^ eroded_background
# Boolean mask of arr2D with True at peaks (applying XOR on both matrices).
detected_peaks = local_max != eroded_background
# extract peaks
amps = arr2D[detected_peaks]
@ -64,6 +98,7 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
# filter peaks
amps = amps.flatten()
# get indices for frequency and time
filter_idxs = np.where(amps > amp_min)
@ -84,12 +119,21 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
return list(zip(freqs_filter, times_filter))
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
def generate_hashes(peaks: List[Tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE) -> List[Tuple[str, int]]:
"""
Hash list structure:
sha1_hash[0:FINGERPRINT_REDUCTION] time_offset
[(e05b341a9b77a51fd26, 32), ... ]
:param peaks: list of peak frequencies and times.
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
:return: a list of hashes with their corresponding offsets.
"""
# frequencies are in the first position of the tuples
idx_freq = 0
# times are in the second position of the tuples
idx_time = 1
if PEAK_SORT:
peaks.sort(key=itemgetter(1))
@ -98,10 +142,10 @@ def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
for j in range(1, fan_value):
if (i + j) < len(peaks):
freq1 = peaks[i][IDX_FREQ_I]
freq2 = peaks[i + j][IDX_FREQ_I]
t1 = peaks[i][IDX_TIME_J]
t2 = peaks[i + j][IDX_TIME_J]
freq1 = peaks[i][idx_freq]
freq2 = peaks[i + j][idx_freq]
t1 = peaks[i][idx_time]
t2 = peaks[i + j][idx_time]
t_delta = t2 - t1
if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:

View file

@ -32,11 +32,3 @@ if __name__ == '__main__':
recognizer = FileRecognizer(djv)
results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
print(f"No shortcut, we recognized: {results}\n")
# To list all fingerprinted songs in the db you can use the following:
# fingerprinted_songs = djv.get_fingerprinted_songs()
# print(fingerprinted_songs)
# And to delete a song or a set of songs you can use the following:
# song_ids_to_delete = [1]
# djv.delete_songs_by_ids(song_ids_to_delete)