mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 11:09:52 +00:00
More changes:
- added even more docstring to the solution. - changed maximum filter mask on fingerprints (now configurable)
This commit is contained in:
parent
9a1c71b349
commit
27bf6b380b
7 changed files with 93 additions and 36 deletions
|
@ -50,6 +50,7 @@ class Dejavu:
|
|||
def get_fingerprinted_songs(self) -> List[Dict[str, any]]:
|
||||
"""
|
||||
To pull all fingerprinted songs from the database.
|
||||
|
||||
:return: a list of fingerprinted audios from the database.
|
||||
"""
|
||||
return self.db.get_songs()
|
||||
|
@ -57,6 +58,7 @@ class Dejavu:
|
|||
def delete_songs_by_id(self, song_ids: List[int]) -> None:
|
||||
"""
|
||||
Deletes all audios given their ids.
|
||||
|
||||
:param song_ids: song ids to delete from the database.
|
||||
"""
|
||||
self.db.delete_songs_by_id(song_ids)
|
||||
|
@ -64,6 +66,7 @@ class Dejavu:
|
|||
def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None:
|
||||
"""
|
||||
Given a directory and a set of extensions it fingerprints all files that match each extension specified.
|
||||
|
||||
:param path: path to the directory.
|
||||
:param extensions: list of file extensions to consider.
|
||||
:param nprocesses: amount of processes to fingerprint the files within the directory.
|
||||
|
@ -119,6 +122,7 @@ class Dejavu:
|
|||
"""
|
||||
Given a path to a file the method generates hashes for it and stores them in the database
|
||||
for later being queried.
|
||||
|
||||
:param file_path: path to the file.
|
||||
:param song_name: song name associated to the audio file.
|
||||
"""
|
||||
|
@ -143,6 +147,7 @@ class Dejavu:
|
|||
def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]:
|
||||
f"""
|
||||
Generate the fingerprints for the given sample data (channel).
|
||||
|
||||
:param samples: list of ints which represents the channel info of the given audio file.
|
||||
:param Fs: sampling rate which defaults to {DEFAULT_FS}.
|
||||
:return: a list of tuples for hash and its corresponding offset, together with the generation time.
|
||||
|
@ -155,6 +160,7 @@ class Dejavu:
|
|||
def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]:
|
||||
"""
|
||||
Finds the corresponding matches on the fingerprinted audios for the given hashes.
|
||||
|
||||
:param hashes: list of tuples for hashes and their corresponding offsets
|
||||
:return: a tuple containing the matches found against the db, a dictionary which counts the different
|
||||
hashes matched for each song (with the song id as key), and the time that the query took.
|
||||
|
@ -171,6 +177,7 @@ class Dejavu:
|
|||
"""
|
||||
Finds hash matches that align in time with other matches and finds
|
||||
consensus about which hashes are "true" signal from the audio.
|
||||
|
||||
:param matches: matches from the database
|
||||
:param dedup_hashes: dictionary containing the hashes matched without duplicates for each song
|
||||
(key is the song id).
|
||||
|
|
|
@ -172,6 +172,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta):
|
|||
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
|
||||
"""
|
||||
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
|
||||
|
||||
:param song_ids: song ids to be deleted from the database.
|
||||
:param batch_size: number of query's batches.
|
||||
"""
|
||||
|
@ -181,6 +182,7 @@ class BaseDatabase(object, metaclass=abc.ABCMeta):
|
|||
def get_database(database_type: str = "mysql") -> BaseDatabase:
|
||||
"""
|
||||
Given a database type it returns a database instance for that type.
|
||||
|
||||
:param database_type: type of the database.
|
||||
:return: an instance of BaseDatabase depending on given database_type.
|
||||
"""
|
||||
|
|
|
@ -220,6 +220,7 @@ class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta):
|
|||
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
|
||||
"""
|
||||
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
|
||||
|
||||
:param song_ids: song ids to be deleted from the database.
|
||||
:param batch_size: number of query's batches.
|
||||
"""
|
||||
|
|
|
@ -48,6 +48,14 @@ FIELD_HASH = 'hash'
|
|||
FIELD_OFFSET = 'offset'
|
||||
|
||||
# FINGERPRINTS CONFIG:
|
||||
# This is used as connectivity parameter for scipy.generate_binary_structure function. This parameter
|
||||
# changes the morphology mask when looking for maximum peaks on the spectrogram matrix.
|
||||
# Possible values are: [1, 2]
|
||||
# Where 1 sets a diamond morphology which implies that diagonal elements are not considered as neighbors (this
|
||||
# is the value used in the original dejavu code).
|
||||
# And 2 sets a square mask, i.e. all elements are considered neighbors.
|
||||
CONNECTIVITY_MASK = 2
|
||||
|
||||
# Sampling rate, related to the Nyquist conditions, which affects
|
||||
# the range frequencies we can detect.
|
||||
DEFAULT_FS = 44100
|
||||
|
@ -60,8 +68,8 @@ DEFAULT_WINDOW_SIZE = 4096
|
|||
# matching, but potentially more fingerprints.
|
||||
DEFAULT_OVERLAP_RATIO = 0.5
|
||||
|
||||
# Degree to which a fingerprint can be paired with its neighbors --
|
||||
# higher will cause more fingerprints, but potentially better accuracy.
|
||||
# Degree to which a fingerprint can be paired with its neighbors. Higher values will
|
||||
# cause more fingerprints, but potentially better accuracy.
|
||||
DEFAULT_FAN_VALUE = 5 # 15 was the original value.
|
||||
|
||||
# Minimum amplitude in spectrogram in order to be considered a peak.
|
||||
|
|
|
@ -34,6 +34,7 @@ def unique_hash(file_path: str, block_size: int = 2**20) -> str:
|
|||
def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Get all files that meet the specified extensions.
|
||||
|
||||
:param path: path to a directory with audio files.
|
||||
:param extensions: file extensions to look for.
|
||||
:return: a list of tuples with file name and its extension.
|
||||
|
@ -97,6 +98,8 @@ def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
|
|||
def get_audio_name_from_path(file_path: str) -> str:
|
||||
"""
|
||||
Extracts song name from a file path.
|
||||
|
||||
:param file_path: path to an audio file.
|
||||
:return: file name
|
||||
"""
|
||||
return os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import hashlib
|
||||
from operator import itemgetter
|
||||
from typing import List, Tuple
|
||||
|
||||
import matplotlib.mlab as mlab
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -9,24 +10,30 @@ from scipy.ndimage.morphology import (binary_erosion,
|
|||
generate_binary_structure,
|
||||
iterate_structure)
|
||||
|
||||
from dejavu.config.settings import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE,
|
||||
DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
|
||||
DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION,
|
||||
MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA,
|
||||
from dejavu.config.settings import (CONNECTIVITY_MASK, DEFAULT_AMP_MIN,
|
||||
DEFAULT_FAN_VALUE, DEFAULT_FS,
|
||||
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
|
||||
FINGERPRINT_REDUCTION, MAX_HASH_TIME_DELTA,
|
||||
MIN_HASH_TIME_DELTA,
|
||||
PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT)
|
||||
|
||||
IDX_FREQ_I = 0
|
||||
IDX_TIME_J = 1
|
||||
|
||||
|
||||
def fingerprint(channel_samples,
|
||||
Fs=DEFAULT_FS,
|
||||
wsize=DEFAULT_WINDOW_SIZE,
|
||||
wratio=DEFAULT_OVERLAP_RATIO,
|
||||
fan_value=DEFAULT_FAN_VALUE,
|
||||
amp_min=DEFAULT_AMP_MIN):
|
||||
def fingerprint(channel_samples: List[int],
|
||||
Fs: int = DEFAULT_FS,
|
||||
wsize: int = DEFAULT_WINDOW_SIZE,
|
||||
wratio: float = DEFAULT_OVERLAP_RATIO,
|
||||
fan_value: int = DEFAULT_FAN_VALUE,
|
||||
amp_min: int = DEFAULT_AMP_MIN) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
FFT the channel, log transform output, find local maxima, then return locally sensitive hashes.
|
||||
|
||||
:param channel_samples: channel samples to fingerprint.
|
||||
:param Fs: audio sampling rate.
|
||||
:param wsize: FFT windows size.
|
||||
:param wratio: ratio by which each sequential window overlaps the last and the next window.
|
||||
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
|
||||
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
|
||||
:return: a list of hashes with their corresponding offsets.
|
||||
"""
|
||||
# FFT the signal and extract frequency components
|
||||
arr2D = mlab.specgram(
|
||||
|
@ -36,7 +43,7 @@ def fingerprint(channel_samples,
|
|||
window=mlab.window_hanning,
|
||||
noverlap=int(wsize * wratio))[0]
|
||||
|
||||
# Apply log transform since specgram() returns linear array. 0s are excluded to avoid np warning.
|
||||
# Apply log transform since specgram function returns linear array. 0s are excluded to avoid np warning.
|
||||
arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0))
|
||||
|
||||
local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
|
||||
|
@ -45,18 +52,45 @@ def fingerprint(channel_samples,
|
|||
return generate_hashes(local_maxima, fan_value=fan_value)
|
||||
|
||||
|
||||
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html#scipy.ndimage.iterate_structure
|
||||
struct = generate_binary_structure(2, 1)
|
||||
def get_2D_peaks(arr2D: np.array, plot: bool = False, amp_min: int = DEFAULT_AMP_MIN)\
|
||||
-> List[Tuple[List[int], List[int]]]:
|
||||
"""
|
||||
Extract maximum peaks from the spectogram matrix (arr2D).
|
||||
|
||||
:param arr2D: matrix representing the spectogram.
|
||||
:param plot: for plotting the results.
|
||||
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
|
||||
:return: a list composed by a list of frequencies and times.
|
||||
"""
|
||||
# Original code from the repo is using a morphology mask that does not consider diagonal elements
|
||||
# as neighbors (basically a diamond figure) and then applies a dilation over it, so what I'm proposing
|
||||
# is to change from the current diamond figure to a just a normal square one:
|
||||
# F T F T T T
|
||||
# T T T ==> T T T
|
||||
# F T F T T T
|
||||
# In my local tests time performance of the square mask was ~3 times faster
|
||||
# respect to the diamond one, without hurting accuracy of the predictions.
|
||||
# I've made now the mask shape configurable in order to allow both ways of find maximum peaks.
|
||||
# That being said, we generate the mask by using the following function
|
||||
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.generate_binary_structure.html
|
||||
struct = generate_binary_structure(2, CONNECTIVITY_MASK)
|
||||
|
||||
# And then we apply dilation using the following function
|
||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html
|
||||
# Take into account that if PEAK_NEIGHBORHOOD_SIZE is 2 you can avoid the use of the scipy functions and just
|
||||
# change it by the following code:
|
||||
# neighborhood = np.ones((PEAK_NEIGHBORHOOD_SIZE * 2 + 1, PEAK_NEIGHBORHOOD_SIZE * 2 + 1), dtype=bool)
|
||||
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
||||
|
||||
# find local maxima using our filter shape
|
||||
# find local maxima using our filter mask
|
||||
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
||||
|
||||
# Applying erosion, the dejavu documentation does not talk about this step.
|
||||
background = (arr2D == 0)
|
||||
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
|
||||
|
||||
# Boolean mask of arr2D with True at peaks (Fixed deprecated boolean operator by changing '-' to '^')
|
||||
detected_peaks = local_max ^ eroded_background
|
||||
# Boolean mask of arr2D with True at peaks (applying XOR on both matrices).
|
||||
detected_peaks = local_max != eroded_background
|
||||
|
||||
# extract peaks
|
||||
amps = arr2D[detected_peaks]
|
||||
|
@ -64,6 +98,7 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
|||
|
||||
# filter peaks
|
||||
amps = amps.flatten()
|
||||
|
||||
# get indices for frequency and time
|
||||
filter_idxs = np.where(amps > amp_min)
|
||||
|
||||
|
@ -84,12 +119,21 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
|||
return list(zip(freqs_filter, times_filter))
|
||||
|
||||
|
||||
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
||||
def generate_hashes(peaks: List[Tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
Hash list structure:
|
||||
sha1_hash[0:FINGERPRINT_REDUCTION] time_offset
|
||||
[(e05b341a9b77a51fd26, 32), ... ]
|
||||
|
||||
:param peaks: list of peak frequencies and times.
|
||||
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
|
||||
:return: a list of hashes with their corresponding offsets.
|
||||
"""
|
||||
# frequencies are in the first position of the tuples
|
||||
idx_freq = 0
|
||||
# times are in the second position of the tuples
|
||||
idx_time = 1
|
||||
|
||||
if PEAK_SORT:
|
||||
peaks.sort(key=itemgetter(1))
|
||||
|
||||
|
@ -98,10 +142,10 @@ def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
|||
for j in range(1, fan_value):
|
||||
if (i + j) < len(peaks):
|
||||
|
||||
freq1 = peaks[i][IDX_FREQ_I]
|
||||
freq2 = peaks[i + j][IDX_FREQ_I]
|
||||
t1 = peaks[i][IDX_TIME_J]
|
||||
t2 = peaks[i + j][IDX_TIME_J]
|
||||
freq1 = peaks[i][idx_freq]
|
||||
freq2 = peaks[i + j][idx_freq]
|
||||
t1 = peaks[i][idx_time]
|
||||
t2 = peaks[i + j][idx_time]
|
||||
t_delta = t2 - t1
|
||||
|
||||
if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:
|
||||
|
|
|
@ -32,11 +32,3 @@ if __name__ == '__main__':
|
|||
recognizer = FileRecognizer(djv)
|
||||
results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
|
||||
print(f"No shortcut, we recognized: {results}\n")
|
||||
|
||||
# To list all fingerprinted songs in the db you can use the following:
|
||||
# fingerprinted_songs = djv.get_fingerprinted_songs()
|
||||
# print(fingerprinted_songs)
|
||||
|
||||
# And to delete a song or a set of songs you can use the following:
|
||||
# song_ids_to_delete = [1]
|
||||
# djv.delete_songs_by_ids(song_ids_to_delete)
|
||||
|
|
Loading…
Reference in a new issue