mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 11:09:52 +00:00
More changes:
- added even more docstring to the solution. - changed maximum filter mask on fingerprints (now configurable)
This commit is contained in:
parent
9a1c71b349
commit
27bf6b380b
7 changed files with 93 additions and 36 deletions
|
@ -50,6 +50,7 @@ class Dejavu:
|
||||||
def get_fingerprinted_songs(self) -> List[Dict[str, any]]:
|
def get_fingerprinted_songs(self) -> List[Dict[str, any]]:
|
||||||
"""
|
"""
|
||||||
To pull all fingerprinted songs from the database.
|
To pull all fingerprinted songs from the database.
|
||||||
|
|
||||||
:return: a list of fingerprinted audios from the database.
|
:return: a list of fingerprinted audios from the database.
|
||||||
"""
|
"""
|
||||||
return self.db.get_songs()
|
return self.db.get_songs()
|
||||||
|
@ -57,6 +58,7 @@ class Dejavu:
|
||||||
def delete_songs_by_id(self, song_ids: List[int]) -> None:
|
def delete_songs_by_id(self, song_ids: List[int]) -> None:
|
||||||
"""
|
"""
|
||||||
Deletes all audios given their ids.
|
Deletes all audios given their ids.
|
||||||
|
|
||||||
:param song_ids: song ids to delete from the database.
|
:param song_ids: song ids to delete from the database.
|
||||||
"""
|
"""
|
||||||
self.db.delete_songs_by_id(song_ids)
|
self.db.delete_songs_by_id(song_ids)
|
||||||
|
@ -64,6 +66,7 @@ class Dejavu:
|
||||||
def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None:
|
def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None:
|
||||||
"""
|
"""
|
||||||
Given a directory and a set of extensions it fingerprints all files that match each extension specified.
|
Given a directory and a set of extensions it fingerprints all files that match each extension specified.
|
||||||
|
|
||||||
:param path: path to the directory.
|
:param path: path to the directory.
|
||||||
:param extensions: list of file extensions to consider.
|
:param extensions: list of file extensions to consider.
|
||||||
:param nprocesses: amount of processes to fingerprint the files within the directory.
|
:param nprocesses: amount of processes to fingerprint the files within the directory.
|
||||||
|
@ -119,6 +122,7 @@ class Dejavu:
|
||||||
"""
|
"""
|
||||||
Given a path to a file the method generates hashes for it and stores them in the database
|
Given a path to a file the method generates hashes for it and stores them in the database
|
||||||
for later being queried.
|
for later being queried.
|
||||||
|
|
||||||
:param file_path: path to the file.
|
:param file_path: path to the file.
|
||||||
:param song_name: song name associated to the audio file.
|
:param song_name: song name associated to the audio file.
|
||||||
"""
|
"""
|
||||||
|
@ -143,6 +147,7 @@ class Dejavu:
|
||||||
def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]:
|
def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]:
|
||||||
f"""
|
f"""
|
||||||
Generate the fingerprints for the given sample data (channel).
|
Generate the fingerprints for the given sample data (channel).
|
||||||
|
|
||||||
:param samples: list of ints which represents the channel info of the given audio file.
|
:param samples: list of ints which represents the channel info of the given audio file.
|
||||||
:param Fs: sampling rate which defaults to {DEFAULT_FS}.
|
:param Fs: sampling rate which defaults to {DEFAULT_FS}.
|
||||||
:return: a list of tuples for hash and its corresponding offset, together with the generation time.
|
:return: a list of tuples for hash and its corresponding offset, together with the generation time.
|
||||||
|
@ -155,6 +160,7 @@ class Dejavu:
|
||||||
def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]:
|
def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]:
|
||||||
"""
|
"""
|
||||||
Finds the corresponding matches on the fingerprinted audios for the given hashes.
|
Finds the corresponding matches on the fingerprinted audios for the given hashes.
|
||||||
|
|
||||||
:param hashes: list of tuples for hashes and their corresponding offsets
|
:param hashes: list of tuples for hashes and their corresponding offsets
|
||||||
:return: a tuple containing the matches found against the db, a dictionary which counts the different
|
:return: a tuple containing the matches found against the db, a dictionary which counts the different
|
||||||
hashes matched for each song (with the song id as key), and the time that the query took.
|
hashes matched for each song (with the song id as key), and the time that the query took.
|
||||||
|
@ -171,6 +177,7 @@ class Dejavu:
|
||||||
"""
|
"""
|
||||||
Finds hash matches that align in time with other matches and finds
|
Finds hash matches that align in time with other matches and finds
|
||||||
consensus about which hashes are "true" signal from the audio.
|
consensus about which hashes are "true" signal from the audio.
|
||||||
|
|
||||||
:param matches: matches from the database
|
:param matches: matches from the database
|
||||||
:param dedup_hashes: dictionary containing the hashes matched without duplicates for each song
|
:param dedup_hashes: dictionary containing the hashes matched without duplicates for each song
|
||||||
(key is the song id).
|
(key is the song id).
|
||||||
|
|
|
@ -172,6 +172,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta):
|
||||||
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
|
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
|
||||||
"""
|
"""
|
||||||
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
|
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
|
||||||
|
|
||||||
:param song_ids: song ids to be deleted from the database.
|
:param song_ids: song ids to be deleted from the database.
|
||||||
:param batch_size: number of query's batches.
|
:param batch_size: number of query's batches.
|
||||||
"""
|
"""
|
||||||
|
@ -181,6 +182,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta):
|
||||||
def get_database(database_type: str = "mysql") -> BaseDatabase:
|
def get_database(database_type: str = "mysql") -> BaseDatabase:
|
||||||
"""
|
"""
|
||||||
Given a database type it returns a database instance for that type.
|
Given a database type it returns a database instance for that type.
|
||||||
|
|
||||||
:param database_type: type of the database.
|
:param database_type: type of the database.
|
||||||
:return: an instance of BaseDatabase depending on given database_type.
|
:return: an instance of BaseDatabase depending on given database_type.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -220,6 +220,7 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta):
|
||||||
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
|
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
|
||||||
"""
|
"""
|
||||||
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
|
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
|
||||||
|
|
||||||
:param song_ids: song ids to be deleted from the database.
|
:param song_ids: song ids to be deleted from the database.
|
||||||
:param batch_size: number of query's batches.
|
:param batch_size: number of query's batches.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -48,6 +48,14 @@ FIELD_HASH = 'hash'
|
||||||
FIELD_OFFSET = 'offset'
|
FIELD_OFFSET = 'offset'
|
||||||
|
|
||||||
# FINGERPRINTS CONFIG:
|
# FINGERPRINTS CONFIG:
|
||||||
|
# This is used as the connectivity parameter for the scipy.ndimage.generate_binary_structure function. This parameter
|
||||||
|
# changes the morphology mask when looking for maximum peaks on the spectrogram matrix.
|
||||||
|
# Possible values are: [1, 2]
|
||||||
|
# Where 1 sets a diamond morphology which implies that diagonal elements are not considered as neighbors (this
|
||||||
|
# is the value used in the original dejavu code).
|
||||||
|
# And 2 sets a square mask, i.e. all elements are considered neighbors.
|
||||||
|
CONNECTIVITY_MASK = 2
|
||||||
|
|
||||||
# Sampling rate, related to the Nyquist conditions, which affects
|
# Sampling rate, related to the Nyquist conditions, which affects
|
||||||
# the range frequencies we can detect.
|
# the range frequencies we can detect.
|
||||||
DEFAULT_FS = 44100
|
DEFAULT_FS = 44100
|
||||||
|
@ -60,8 +68,8 @@ DEFAULT_WINDOW_SIZE = 4096
|
||||||
# matching, but potentially more fingerprints.
|
# matching, but potentially more fingerprints.
|
||||||
DEFAULT_OVERLAP_RATIO = 0.5
|
DEFAULT_OVERLAP_RATIO = 0.5
|
||||||
|
|
||||||
# Degree to which a fingerprint can be paired with its neighbors --
|
# Degree to which a fingerprint can be paired with its neighbors. Higher values will
|
||||||
# higher will cause more fingerprints, but potentially better accuracy.
|
# cause more fingerprints, but potentially better accuracy.
|
||||||
DEFAULT_FAN_VALUE = 5 # 15 was the original value.
|
DEFAULT_FAN_VALUE = 5 # 15 was the original value.
|
||||||
|
|
||||||
# Minimum amplitude in spectrogram in order to be considered a peak.
|
# Minimum amplitude in spectrogram in order to be considered a peak.
|
||||||
|
|
|
@ -34,6 +34,7 @@ def unique_hash(file_path: str, block_size: int = 2**20) -> str:
|
||||||
def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]:
|
def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
Get all files that meet the specified extensions.
|
Get all files that meet the specified extensions.
|
||||||
|
|
||||||
:param path: path to a directory with audio files.
|
:param path: path to a directory with audio files.
|
||||||
:param extensions: file extensions to look for.
|
:param extensions: file extensions to look for.
|
||||||
:return: a list of tuples with file name and its extension.
|
:return: a list of tuples with file name and its extension.
|
||||||
|
@ -97,6 +98,8 @@ def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
|
||||||
def get_audio_name_from_path(file_path: str) -> str:
|
def get_audio_name_from_path(file_path: str) -> str:
|
||||||
"""
|
"""
|
||||||
Extracts song name from a file path.
|
Extracts song name from a file path.
|
||||||
|
|
||||||
:param file_path: path to an audio file.
|
:param file_path: path to an audio file.
|
||||||
|
:return: file name
|
||||||
"""
|
"""
|
||||||
return os.path.splitext(os.path.basename(file_path))[0]
|
return os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import matplotlib.mlab as mlab
|
import matplotlib.mlab as mlab
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
@ -9,24 +10,30 @@ from scipy.ndimage.morphology import (binary_erosion,
|
||||||
generate_binary_structure,
|
generate_binary_structure,
|
||||||
iterate_structure)
|
iterate_structure)
|
||||||
|
|
||||||
from dejavu.config.settings import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE,
|
from dejavu.config.settings import (CONNECTIVITY_MASK, DEFAULT_AMP_MIN,
|
||||||
DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
|
DEFAULT_FAN_VALUE, DEFAULT_FS,
|
||||||
DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION,
|
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
|
||||||
MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA,
|
FINGERPRINT_REDUCTION, MAX_HASH_TIME_DELTA,
|
||||||
|
MIN_HASH_TIME_DELTA,
|
||||||
PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT)
|
PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT)
|
||||||
|
|
||||||
IDX_FREQ_I = 0
|
|
||||||
IDX_TIME_J = 1
|
|
||||||
|
|
||||||
|
def fingerprint(channel_samples: List[int],
|
||||||
def fingerprint(channel_samples,
|
Fs: int = DEFAULT_FS,
|
||||||
Fs=DEFAULT_FS,
|
wsize: int = DEFAULT_WINDOW_SIZE,
|
||||||
wsize=DEFAULT_WINDOW_SIZE,
|
wratio: float = DEFAULT_OVERLAP_RATIO,
|
||||||
wratio=DEFAULT_OVERLAP_RATIO,
|
fan_value: int = DEFAULT_FAN_VALUE,
|
||||||
fan_value=DEFAULT_FAN_VALUE,
|
amp_min: int = DEFAULT_AMP_MIN) -> List[Tuple[str, int]]:
|
||||||
amp_min=DEFAULT_AMP_MIN):
|
|
||||||
"""
|
"""
|
||||||
FFT the channel, log transform output, find local maxima, then return locally sensitive hashes.
|
FFT the channel, log transform output, find local maxima, then return locally sensitive hashes.
|
||||||
|
|
||||||
|
:param channel_samples: channel samples to fingerprint.
|
||||||
|
:param Fs: audio sampling rate.
|
||||||
|
:param wsize: FFT window size.
|
||||||
|
:param wratio: ratio by which each sequential window overlaps the last and the next window.
|
||||||
|
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
|
||||||
|
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
|
||||||
|
:return: a list of hashes with their corresponding offsets.
|
||||||
"""
|
"""
|
||||||
# FFT the signal and extract frequency components
|
# FFT the signal and extract frequency components
|
||||||
arr2D = mlab.specgram(
|
arr2D = mlab.specgram(
|
||||||
|
@ -36,7 +43,7 @@ def fingerprint(channel_samples,
|
||||||
window=mlab.window_hanning,
|
window=mlab.window_hanning,
|
||||||
noverlap=int(wsize * wratio))[0]
|
noverlap=int(wsize * wratio))[0]
|
||||||
|
|
||||||
# Apply log transform since specgram() returns linear array. 0s are excluded to avoid np warning.
|
# Apply log transform since specgram function returns linear array. 0s are excluded to avoid np warning.
|
||||||
arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0))
|
arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0))
|
||||||
|
|
||||||
local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
|
local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
|
||||||
|
@ -45,18 +52,45 @@ def fingerprint(channel_samples,
|
||||||
return generate_hashes(local_maxima, fan_value=fan_value)
|
return generate_hashes(local_maxima, fan_value=fan_value)
|
||||||
|
|
||||||
|
|
||||||
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
def get_2D_peaks(arr2D: np.array, plot: bool = False, amp_min: int = DEFAULT_AMP_MIN)\
|
||||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html#scipy.ndimage.iterate_structure
|
-> List[Tuple[List[int], List[int]]]:
|
||||||
struct = generate_binary_structure(2, 1)
|
"""
|
||||||
|
Extract maximum peaks from the spectrogram matrix (arr2D).
|
||||||
|
|
||||||
|
:param arr2D: matrix representing the spectrogram.
|
||||||
|
:param plot: for plotting the results.
|
||||||
|
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
|
||||||
|
:return: a list of (frequency index, time index) tuples, one per peak whose amplitude exceeds amp_min.
|
||||||
|
"""
|
||||||
|
# Original code from the repo is using a morphology mask that does not consider diagonal elements
|
||||||
|
# as neighbors (basically a diamond figure) and then applies a dilation over it, so what I'm proposing
|
||||||
|
# is to change from the current diamond figure to just a normal square one:
|
||||||
|
# F T F T T T
|
||||||
|
# T T T ==> T T T
|
||||||
|
# F T F T T T
|
||||||
|
# In my local tests time performance of the square mask was ~3 times faster
|
||||||
|
# with respect to the diamond one, without hurting accuracy of the predictions.
|
||||||
|
# I've now made the mask shape configurable in order to allow both ways of finding maximum peaks.
|
||||||
|
# That being said, we generate the mask by using the following function
|
||||||
|
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.generate_binary_structure.html
|
||||||
|
struct = generate_binary_structure(2, CONNECTIVITY_MASK)
|
||||||
|
|
||||||
|
# And then we apply dilation using the following function
|
||||||
|
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html
|
||||||
|
# Take into account that if CONNECTIVITY_MASK is 2 (square mask) you can avoid the use of the scipy functions and just
|
||||||
|
# change it by the following code:
|
||||||
|
# neighborhood = np.ones((PEAK_NEIGHBORHOOD_SIZE * 2 + 1, PEAK_NEIGHBORHOOD_SIZE * 2 + 1), dtype=bool)
|
||||||
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
||||||
|
|
||||||
# find local maxima using our filter shape
|
# find local maxima using our filter mask
|
||||||
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
||||||
|
|
||||||
|
# Applying erosion, the dejavu documentation does not talk about this step.
|
||||||
background = (arr2D == 0)
|
background = (arr2D == 0)
|
||||||
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
|
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
|
||||||
|
|
||||||
# Boolean mask of arr2D with True at peaks (Fixed deprecated boolean operator by changing '-' to '^')
|
# Boolean mask of arr2D with True at peaks (applying XOR on both matrices).
|
||||||
detected_peaks = local_max ^ eroded_background
|
detected_peaks = local_max != eroded_background
|
||||||
|
|
||||||
# extract peaks
|
# extract peaks
|
||||||
amps = arr2D[detected_peaks]
|
amps = arr2D[detected_peaks]
|
||||||
|
@ -64,6 +98,7 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||||
|
|
||||||
# filter peaks
|
# filter peaks
|
||||||
amps = amps.flatten()
|
amps = amps.flatten()
|
||||||
|
|
||||||
# get indices for frequency and time
|
# get indices for frequency and time
|
||||||
filter_idxs = np.where(amps > amp_min)
|
filter_idxs = np.where(amps > amp_min)
|
||||||
|
|
||||||
|
@ -84,12 +119,21 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||||
return list(zip(freqs_filter, times_filter))
|
return list(zip(freqs_filter, times_filter))
|
||||||
|
|
||||||
|
|
||||||
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
def generate_hashes(peaks: List[Tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE) -> List[Tuple[str, int]]:
|
||||||
"""
|
"""
|
||||||
Hash list structure:
|
Hash list structure:
|
||||||
sha1_hash[0:FINGERPRINT_REDUCTION] time_offset
|
sha1_hash[0:FINGERPRINT_REDUCTION] time_offset
|
||||||
[(e05b341a9b77a51fd26, 32), ... ]
|
[(e05b341a9b77a51fd26, 32), ... ]
|
||||||
|
|
||||||
|
:param peaks: list of peak frequencies and times.
|
||||||
|
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
|
||||||
|
:return: a list of hashes with their corresponding offsets.
|
||||||
"""
|
"""
|
||||||
|
# frequencies are in the first position of the tuples
|
||||||
|
idx_freq = 0
|
||||||
|
# times are in the second position of the tuples
|
||||||
|
idx_time = 1
|
||||||
|
|
||||||
if PEAK_SORT:
|
if PEAK_SORT:
|
||||||
peaks.sort(key=itemgetter(1))
|
peaks.sort(key=itemgetter(1))
|
||||||
|
|
||||||
|
@ -98,10 +142,10 @@ def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
||||||
for j in range(1, fan_value):
|
for j in range(1, fan_value):
|
||||||
if (i + j) < len(peaks):
|
if (i + j) < len(peaks):
|
||||||
|
|
||||||
freq1 = peaks[i][IDX_FREQ_I]
|
freq1 = peaks[i][idx_freq]
|
||||||
freq2 = peaks[i + j][IDX_FREQ_I]
|
freq2 = peaks[i + j][idx_freq]
|
||||||
t1 = peaks[i][IDX_TIME_J]
|
t1 = peaks[i][idx_time]
|
||||||
t2 = peaks[i + j][IDX_TIME_J]
|
t2 = peaks[i + j][idx_time]
|
||||||
t_delta = t2 - t1
|
t_delta = t2 - t1
|
||||||
|
|
||||||
if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:
|
if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:
|
||||||
|
|
|
@ -32,11 +32,3 @@ if __name__ == '__main__':
|
||||||
recognizer = FileRecognizer(djv)
|
recognizer = FileRecognizer(djv)
|
||||||
results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
|
results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
|
||||||
print(f"No shortcut, we recognized: {results}\n")
|
print(f"No shortcut, we recognized: {results}\n")
|
||||||
|
|
||||||
# To list all fingerprinted songs in the db you can use the following:
|
|
||||||
# fingerprinted_songs = djv.get_fingerprinted_songs()
|
|
||||||
# print(fingerprinted_songs)
|
|
||||||
|
|
||||||
# And to delete a song or a set of songs you can use the following:
|
|
||||||
# song_ids_to_delete = [1]
|
|
||||||
# djv.delete_songs_by_ids(song_ids_to_delete)
|
|
||||||
|
|
Loading…
Reference in a new issue