- Extracted common logic from database handlers.

- Fixed tests.
- Refactored solution architecture once more.
- Refactored solution hierarchy.
- Adding audios for testing.
- Solved flake8 issues and reorganized several imports in the process.
This commit is contained in:
mrepetto 2019-09-25 17:38:25 -03:00
parent 6bc15ab8d4
commit e046eeee93
32 changed files with 1154 additions and 1058 deletions

4
.gitignore vendored
View file

@ -1,7 +1,3 @@
*.pyc *.pyc
wav
mp3
*.wav
*.mp3
.DS_Store .DS_Store
*.cnf *.cnf

View file

@ -4,5 +4,6 @@
"user": "root", "user": "root",
"password": "rootpass", "password": "rootpass",
"database": "dejavu" "database": "dejavu"
} },
"database_type": "mysql"
} }

View file

@ -1,16 +1,12 @@
#!/usr/bin/python
import argparse import argparse
import json import json
import os from os.path import isdir
import sys import sys
import warnings
from argparse import RawTextHelpFormatter from argparse import RawTextHelpFormatter
from dejavu import Dejavu from dejavu import Dejavu
from dejavu.recognize import FileRecognizer, MicrophoneRecognizer from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer
from dejavu.logic.recognizer.file_recognizer import FileRecognizer
warnings.filterwarnings("ignore")
DEFAULT_CONFIG_FILE = "dejavu.cnf.SAMPLE" DEFAULT_CONFIG_FILE = "dejavu.cnf.SAMPLE"
@ -58,7 +54,6 @@ if __name__ == '__main__':
config_file = args.config config_file = args.config
if config_file is None: if config_file is None:
config_file = DEFAULT_CONFIG_FILE config_file = DEFAULT_CONFIG_FILE
# print ("Using default config file: {config_file}")
djv = init(config_file) djv = init(config_file)
if args.fingerprint: if args.fingerprint:
@ -71,22 +66,19 @@ if __name__ == '__main__':
elif len(args.fingerprint) == 1: elif len(args.fingerprint) == 1:
filepath = args.fingerprint[0] filepath = args.fingerprint[0]
if os.path.isdir(filepath): if isdir(filepath):
print("Please specify an extension if you'd like to fingerprint a directory!") print("Please specify an extension if you'd like to fingerprint a directory!")
sys.exit(1) sys.exit(1)
djv.fingerprint_file(filepath) djv.fingerprint_file(filepath)
elif args.recognize: elif args.recognize:
# Recognize audio source # Recognize audio source
song = None songs = None
source = args.recognize[0] source = args.recognize[0]
opt_arg = args.recognize[1] opt_arg = args.recognize[1]
if source in ('mic', 'microphone'): if source in ('mic', 'microphone'):
song = djv.recognize(MicrophoneRecognizer, seconds=opt_arg) songs = djv.recognize(MicrophoneRecognizer, seconds=opt_arg)
elif source == 'file': elif source == 'file':
song = djv.recognize(FileRecognizer, opt_arg) songs = djv.recognize(FileRecognizer, opt_arg)
decoded_song = repr(song).decode('string_escape') print(songs)
print(decoded_song)
sys.exit(0)

View file

@ -3,13 +3,13 @@ import os
import sys import sys
import traceback import traceback
import dejavu.decoder as decoder import dejavu.logic.decoder as decoder
from dejavu.config.config import (CONFIDENCE, DEFAULT_FS, from dejavu.base_classes.base_database import get_database
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE, from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS,
FIELD_FILE_SHA1, OFFSET, OFFSET_SECS, DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
SONG_ID, SONG_NAME, TOPN) FIELD_FILE_SHA1, OFFSET, OFFSET_SECS,
from dejavu.database import get_database SONG_ID, SONG_NAME, TOPN)
from dejavu.fingerprint import fingerprint from dejavu.logic.fingerprint import fingerprint
class Dejavu: class Dejavu:
@ -71,7 +71,7 @@ class Dejavu:
continue continue
except StopIteration: except StopIteration:
break break
except: except Exception:
print("Failed fingerprinting") print("Failed fingerprinting")
# Print traceback because we can't reraise it here # Print traceback because we can't reraise it here
traceback.print_exc(file=sys.stdout) traceback.print_exc(file=sys.stdout)
@ -119,6 +119,7 @@ class Dejavu:
diff_counter = {} diff_counter = {}
largest_count = 0 largest_count = 0
# TODO: review logic to get topn results.
for tup in matches: for tup in matches:
sid, diff = tup sid, diff = tup
if diff not in diff_counter: if diff not in diff_counter:

View file

View file

@ -1,9 +1,11 @@
import abc import abc
import importlib import importlib
from dejavu.config.config import DATABASES from typing import Dict
from dejavu.config.settings import DATABASES
class Database(object, metaclass=abc.ABCMeta): class BaseDatabase(object, metaclass=abc.ABCMeta):
# Name of your Database subclass, this is used in configuration # Name of your Database subclass, this is used in configuration
# to refer to your class # to refer to your class
type = None type = None
@ -70,17 +72,16 @@ class Database(object, metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_songs(self): def get_songs(self) -> Dict[str, str]:
""" """
Returns all fully fingerprinted songs in the database. Returns all fully fingerprinted songs in the database. Result must be a Dictionary.
""" """
pass pass
@abc.abstractmethod @abc.abstractmethod
def get_song_by_id(self, sid): def get_song_by_id(self, sid) -> Dict[str, str]:
""" """
Return a song by its identifier Return a song by its identifier. Result must be a Dictionary.
sid: Song identifier sid: Song identifier
""" """
pass pass
@ -124,7 +125,7 @@ class Database(object, metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
def insert_hashes(self, sid, hashes): def insert_hashes(self, sid, hashes, batch=1000):
""" """
Insert a multitude of fingerprints. Insert a multitude of fingerprints.
@ -133,7 +134,6 @@ class Database(object, metaclass=abc.ABCMeta):
- hash: Part of a sha1 hash, in hexadecimal format - hash: Part of a sha1 hash, in hexadecimal format
- offset: Offset this hash was created from/at. - offset: Offset this hash was created from/at.
""" """
pass
@abc.abstractmethod @abc.abstractmethod
def return_matches(self, hashes): def return_matches(self, hashes):

View file

@ -0,0 +1,19 @@
import abc
from dejavu.config.settings import DEFAULT_FS
class BaseRecognizer(object, metaclass=abc.ABCMeta):
def __init__(self, dejavu):
self.dejavu = dejavu
self.Fs = DEFAULT_FS
def _recognize(self, *data):
matches = []
for d in data:
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
return self.dejavu.align_matches(matches)
@abc.abstractmethod
def recognize(self):
pass # base class does nothing

View file

@ -0,0 +1,194 @@
import abc
from dejavu.base_classes.base_database import BaseDatabase
class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta):
# Since several methods across different databases are actually just the same
# I've built this class with the idea to reuse that logic instead of copy pasting
# over and over the same code.
def __init__(self):
super().__init__()
def before_fork(self):
"""
Called before the database instance is given to the new process
"""
pass
def after_fork(self):
"""
Called after the database instance has been given to the new process
This will be called in the new process.
"""
pass
def setup(self):
"""
Called on creation or shortly afterwards.
"""
with self.cursor() as cur:
cur.execute(self.CREATE_SONGS_TABLE)
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
cur.execute(self.DELETE_UNFINGERPRINTED)
def empty(self):
"""
Called when the database should be cleared of all data.
"""
with self.cursor() as cur:
cur.execute(self.DROP_FINGERPRINTS)
cur.execute(self.DROP_SONGS)
self.setup()
def delete_unfingerprinted_songs(self):
"""
Called to remove any song entries that do not have any fingerprints
associated with them.
"""
with self.cursor() as cur:
cur.execute(self.DELETE_UNFINGERPRINTED)
def get_num_songs(self):
"""
Returns the amount of songs in the database.
"""
with self.cursor() as cur:
cur.execute(self.SELECT_UNIQUE_SONG_IDS)
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
return count
def get_num_fingerprints(self):
"""
Returns the number of fingerprints in the database.
"""
with self.cursor() as cur:
cur.execute(self.SELECT_NUM_FINGERPRINTS)
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
cur.close()
return count
def set_song_fingerprinted(self, sid):
"""
Sets a specific song as having all fingerprints in the database.
sid: Song identifier
"""
with self.cursor() as cur:
cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,))
def get_songs(self):
"""
Returns all fully fingerprinted songs in the database. Result must be a Dictionary.
"""
with self.cursor(dictionary=True) as cur:
cur.execute(self.SELECT_SONGS)
for row in cur:
yield row
def get_song_by_id(self, sid):
"""
Return a song by its identifier. Result must be a Dictionary.
sid: Song identifier
"""
with self.cursor(dictionary=True) as cur:
cur.execute(self.SELECT_SONG, (sid,))
return cur.fetchone()
def insert(self, hash, sid, offset):
"""
Inserts a single fingerprint into the database.
hash: Part of a sha1 hash, in hexadecimal format
sid: Song identifier this fingerprint is off
offset: The offset this hash is from
"""
with self.cursor() as cur:
cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset))
@abc.abstractmethod
def insert_song(self, song_name):
"""
Inserts a song name into the database, returns the new
identifier of the song.
song_name: The name of the song.
"""
pass
def query(self, fingerprint):
"""
Returns all matching fingerprint entries associated with
the given hash as parameter.
hash: Part of a sha1 hash, in hexadecimal format
"""
if fingerprint:
with self.cursor() as cur:
cur.execute(self.SELECT, (fingerprint,))
for sid, offset in cur:
yield (sid, offset)
else: # select all if no key
with self.cursor() as cur:
cur.execute(self.SELECT_ALL)
for sid, offset in cur:
yield (sid, offset)
def get_iterable_kv_pairs(self):
"""
Returns all fingerprints in the database.
"""
return self.query(None)
def insert_hashes(self, sid, hashes, batch=1000):
"""
Insert a multitude of fingerprints.
sid: Song identifier the fingerprints belong to
hashes: A sequence of tuples in the format (hash, offset)
- hash: Part of a sha1 hash, in hexadecimal format
- offset: Offset this hash was created from/at.
"""
values = [(sid, hsh, int(offset)) for hsh, offset in hashes]
with self.cursor() as cur:
for index in range(0, len(hashes), batch):
cur.executemany(self.INSERT_FINGERPRINT, values[index: index + batch])
def return_matches(self, hashes, batch=1000):
"""
Searches the database for pairs of (hash, offset) values.
hashes: A sequence of tuples in the format (hash, offset)
- hash: Part of a sha1 hash, in hexadecimal format
- offset: Offset this hash was created from/at.
Returns a sequence of (sid, offset_difference) tuples.
sid: Song identifier
offset_difference: (offset - database_offset)
"""
# Create a dictionary of hash => offset pairs for later lookups
mapper = {}
for hsh, offset in hashes:
mapper[hsh.upper()] = offset
# Get an iterable of all the hashes we need
values = list(mapper.keys())
with self.cursor() as cur:
for index in range(0, len(values), batch):
# Create our IN part of the query
query = self.SELECT_MULTIPLE
query = query % ', '.join([self.IN_MATCH] * len(values[index: index + batch]))
cur.execute(query, values[index: index + batch])
for hsh, sid, offset in cur:
# (sid, db_offset - song_sampled_offset)
yield (sid, offset - mapper[hsh])

View file

@ -3,13 +3,114 @@ import queue
import mysql.connector import mysql.connector
from mysql.connector.errors import DatabaseError from mysql.connector.errors import DatabaseError
import dejavu.database_handler.mysql_queries as queries from dejavu.base_classes.common_database import CommonDatabase
from dejavu.database import Database from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
SONGS_TABLENAME)
class MySQLDatabase(Database): class MySQLDatabase(CommonDatabase):
type = "mysql" type = "mysql"
# CREATES
CREATE_SONGS_TABLE = f"""
CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` (
`{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT
, `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL
, `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0
, `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
, CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`)
, CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`)
) ENGINE=INNODB;
"""
CREATE_FINGERPRINTS_TABLE = f"""
CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` (
`{FIELD_HASH}` BINARY(10) NOT NULL
, `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL
, `{FIELD_OFFSET}` INT UNSIGNED NOT NULL
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
, INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`)
, CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}`
UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`)
, CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`)
REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE
) ENGINE=INNODB;
"""
# INSERTS (IGNORES DUPLICATES)
INSERT_FINGERPRINT = f"""
INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` (
`{FIELD_SONG_ID}`
, `{FIELD_HASH}`
, `{FIELD_OFFSET}`)
VALUES (%s, UNHEX(%s), %s);
"""
INSERT_SONG = f"""
INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`)
VALUES (%s, UNHEX(%s));
"""
# SELECTS
SELECT = f"""
SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
FROM `{FINGERPRINTS_TABLENAME}`
WHERE `{FIELD_HASH}` = UNHEX(%s);
"""
SELECT_MULTIPLE = f"""
SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
FROM `{FINGERPRINTS_TABLENAME}`
WHERE `{FIELD_HASH}` IN (%s);
"""
SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;"
SELECT_SONG = f"""
SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
FROM `{SONGS_TABLENAME}`
WHERE `{FIELD_SONG_ID}` = %s;
"""
SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;"
SELECT_UNIQUE_SONG_IDS = f"""
SELECT COUNT(`{FIELD_SONG_ID}`) AS n
FROM `{SONGS_TABLENAME}`
WHERE `{FIELD_FINGERPRINTED}` = 1;
"""
SELECT_SONGS = f"""
SELECT
`{FIELD_SONG_ID}`
, `{FIELD_SONGNAME}`
, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
FROM `{SONGS_TABLENAME}`
WHERE `{FIELD_FINGERPRINTED}` = 1;
"""
# DROPS
DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;"
DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;"
# update
UPDATE_SONG_FINGERPRINTED = f"""
UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s;
"""
# DELETE
DELETE_UNFINGERPRINTED = f"""
DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0;
"""
# IN
IN_MATCH = f"UNHEX(%s)"
def __init__(self, **options): def __init__(self, **options):
super().__init__() super().__init__()
self.cursor = cursor_factory(**options) self.cursor = cursor_factory(**options)
@ -20,160 +121,14 @@ class MySQLDatabase(Database):
# the previous process. # the previous process.
Cursor.clear_cache() Cursor.clear_cache()
def setup(self):
"""
Creates any non-existing tables required for dejavu to function.
This also removes all songs that have been added but have no
fingerprints associated with them.
"""
with self.cursor() as cur:
cur.execute(queries.CREATE_SONGS_TABLE)
cur.execute(queries.CREATE_FINGERPRINTS_TABLE)
cur.execute(queries.DELETE_UNFINGERPRINTED)
def empty(self):
"""
Drops tables created by dejavu and then creates them again
by calling `SQLDatabase.setup`.
.. warning:
This will result in a loss of data
"""
with self.cursor() as cur:
cur.execute(queries.DROP_FINGERPRINTS)
cur.execute(queries.DROP_SONGS)
self.setup()
def delete_unfingerprinted_songs(self):
"""
Removes all songs that have no fingerprints associated with them.
"""
with self.cursor() as cur:
cur.execute(queries.DELETE_UNFINGERPRINTED)
def get_num_songs(self):
"""
Returns number of songs the database has fingerprinted.
"""
with self.cursor() as cur:
cur.execute(queries.SELECT_UNIQUE_SONG_IDS)
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
return count
def get_num_fingerprints(self):
"""
Returns number of fingerprints the database has fingerprinted.
"""
with self.cursor() as cur:
cur.execute(queries.SELECT_NUM_FINGERPRINTS)
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
cur.close()
return count
def set_song_fingerprinted(self, sid):
"""
Set the fingerprinted flag to TRUE (1) once a song has been completely
fingerprinted in the database.
"""
with self.cursor() as cur:
cur.execute(queries.UPDATE_SONG_FINGERPRINTED, (sid,))
def get_songs(self):
"""
Return songs that have the fingerprinted flag set TRUE (1).
"""
with self.cursor(dictionary=True) as cur:
cur.execute(queries.SELECT_SONGS)
for row in cur:
yield row
def get_song_by_id(self, sid):
"""
Returns song by its ID.
"""
with self.cursor(dictionary=True) as cur:
cur.execute(queries.SELECT_SONG, (sid,))
return cur.fetchone()
def insert(self, hash, sid, offset):
"""
Insert a (sha1, song_id, offset) row into database.
"""
with self.cursor() as cur:
cur.execute(queries.INSERT_FINGERPRINT, (hash, sid, offset))
def insert_song(self, song_name, file_hash): def insert_song(self, song_name, file_hash):
""" """
Inserts song in the database and returns the ID of the inserted record. Inserts song in the database and returns the ID of the inserted record.
""" """
with self.cursor() as cur: with self.cursor() as cur:
cur.execute(queries.INSERT_SONG, (song_name, file_hash)) cur.execute(self.INSERT_SONG, (song_name, file_hash))
return cur.lastrowid return cur.lastrowid
def query(self, hash):
"""
Return all tuples associated with hash.
If hash is None, returns all entries in the
database (be careful with that one!).
"""
if hash:
with self.cursor() as cur:
cur.execute(queries.SELECT, (hash,))
for sid, offset in cur:
yield (sid, offset)
else: # select all if no key
with self.cursor() as cur:
cur.execute(queries.SELECT_ALL)
for sid, offset in cur:
yield (sid, offset)
def get_iterable_kv_pairs(self):
"""
Returns all tuples in database.
"""
return self.query(None)
def insert_hashes(self, sid, hashes, batch=1000):
"""
Insert series of hash => song_id, offset
values into the database.
"""
values = [(sid, hash, int(offset)) for hash, offset in hashes]
with self.cursor() as cur:
for index in range(0, len(hashes), batch):
cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch])
def return_matches(self, hashes, batch=1000):
"""
Return the (song_id, offset_diff) tuples associated with
a list of (sha1, sample_offset) values.
"""
# Create a dictionary of hash => offset pairs for later lookups
mapper = {}
for hash, offset in hashes:
mapper[hash.upper()] = offset
# Get an iterable of all the hashes we need
values = list(mapper.keys())
with self.cursor() as cur:
for index in range(0, len(values), batch):
# Create our IN part of the query
query = queries.SELECT_MULTIPLE
query = query % ', '.join(['UNHEX(%s)'] * len(values[index: index + batch]))
cur.execute(query, values[index: index + batch])
for hash, sid, offset in cur:
# (sid, db_offset - song_sampled_offset)
yield (sid, offset - mapper[hash])
def __getstate__(self): def __getstate__(self):
return self._options, return self._options,

View file

@ -1,134 +0,0 @@
from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
SONGS_TABLENAME)
"""
Queries:
1) Find duplicates (shouldn't be any, though):
select `hash`, `song_id`, `offset`, count(*) cnt
from fingerprints
group by `hash`, `song_id`, `offset`
having cnt > 1
order by cnt asc;
2) Get number of hashes by song:
select song_id, song_name, count(song_id) as num
from fingerprints
natural join songs
group by song_id
order by count(song_id) desc;
3) get hashes with highest number of collisions
select
hash,
count(distinct song_id) as n
from fingerprints
group by `hash`
order by n DESC;
=> 26 different songs with same fingerprint (392 times):
select songs.song_name, fingerprints.offset
from fingerprints natural join songs
where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73";
"""
# creates
CREATE_SONGS_TABLE = f"""
CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` (
`{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT
, `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL
, `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0
, `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
, CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`)
, CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`)
) ENGINE=INNODB;
"""
CREATE_FINGERPRINTS_TABLE = f"""
CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` (
`{FIELD_HASH}` BINARY(10) NOT NULL
, `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL
, `{FIELD_OFFSET}` INT UNSIGNED NOT NULL
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
, INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`)
, CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}`
UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`)
, CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`)
REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE
) ENGINE=INNODB;
"""
# inserts (ignores duplicates)
INSERT_FINGERPRINT = f"""
INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` (
`{FIELD_SONG_ID}`
, `{FIELD_HASH}`
, `{FIELD_OFFSET}`)
VALUES (%s, UNHEX(%s), %s);
"""
INSERT_SONG = f"""
INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`)
VALUES (%s, UNHEX(%s));
"""
# selects
SELECT = f"""
SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
FROM `{FINGERPRINTS_TABLENAME}`
WHERE `{FIELD_HASH}` = UNHEX(%s);
"""
SELECT_MULTIPLE = f"""
SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
FROM `{FINGERPRINTS_TABLENAME}`
WHERE `{FIELD_HASH}` IN (%s);
"""
SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;"
SELECT_SONG = f"""
SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
FROM `{SONGS_TABLENAME}`
WHERE `{FIELD_SONG_ID}` = %s;
"""
SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;"
SELECT_UNIQUE_SONG_IDS = f"""
SELECT COUNT(`{FIELD_SONG_ID}`) AS n
FROM `{SONGS_TABLENAME}`
WHERE `{FIELD_FINGERPRINTED}` = 1;
"""
SELECT_SONGS = f"""
SELECT
`{FIELD_SONG_ID}`
, `{FIELD_SONGNAME}`
, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
FROM `{SONGS_TABLENAME}`
WHERE `{FIELD_FINGERPRINTED}` = 1;
"""
# drops
DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;"
DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;"
# update
UPDATE_SONG_FINGERPRINTED = f"""
UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s;
"""
# delete
DELETE_UNFINGERPRINTED = f"""
DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0;
"""

View file

@ -3,13 +3,124 @@ import queue
import psycopg2 import psycopg2
from psycopg2.extras import DictCursor from psycopg2.extras import DictCursor
import dejavu.database_handler.postgres_queries as queries from dejavu.base_classes.common_database import CommonDatabase
from dejavu.database import Database from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
SONGS_TABLENAME)
class PostgreSQLDatabase(Database): class PostgreSQLDatabase(CommonDatabase):
type = "postgres" type = "postgres"
# CREATES
CREATE_SONGS_TABLE = f"""
CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" (
"{FIELD_SONG_ID}" SERIAL
, "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL
, "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0
, "{FIELD_FILE_SHA1}" BYTEA
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
, CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}")
, CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}")
);
"""
CREATE_FINGERPRINTS_TABLE = f"""
CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" (
"{FIELD_HASH}" BYTEA NOT NULL
, "{FIELD_SONG_ID}" INT NOT NULL
, "{FIELD_OFFSET}" INT NOT NULL
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
, CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}")
, CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}")
REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}"
USING hash ("{FIELD_HASH}");
"""
CREATE_FINGERPRINTS_TABLE_INDEX = f"""
CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}"
USING hash ("{FIELD_HASH}");
"""
# INSERTS (IGNORES DUPLICATES)
INSERT_FINGERPRINT = f"""
INSERT INTO "{FINGERPRINTS_TABLENAME}" (
"{FIELD_SONG_ID}"
, "{FIELD_HASH}"
, "{FIELD_OFFSET}")
VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING;
"""
INSERT_SONG = f"""
INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}")
VALUES (%s, decode(%s, 'hex'))
RETURNING "{FIELD_SONG_ID}";
"""
# SELECTS
SELECT = f"""
SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
FROM "{FINGERPRINTS_TABLENAME}"
WHERE "{FIELD_HASH}" = decode(%s, 'hex');
"""
SELECT_MULTIPLE = f"""
SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
FROM "{FINGERPRINTS_TABLENAME}"
WHERE "{FIELD_HASH}" IN (%s);
"""
SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";'
SELECT_SONG = f"""
SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
FROM "{SONGS_TABLENAME}"
WHERE "{FIELD_SONG_ID}" = %s;
"""
SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";'
SELECT_UNIQUE_SONG_IDS = f"""
SELECT COUNT("{FIELD_SONG_ID}") AS n
FROM "{SONGS_TABLENAME}"
WHERE "{FIELD_FINGERPRINTED}" = 1;
"""
SELECT_SONGS = f"""
SELECT
"{FIELD_SONG_ID}"
, "{FIELD_SONGNAME}"
, upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
FROM "{SONGS_TABLENAME}"
WHERE "{FIELD_FINGERPRINTED}" = 1;
"""
# DROPS
DROP_FINGERPRINTS = F'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";'
DROP_SONGS = F'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";'
# UPDATE
UPDATE_SONG_FINGERPRINTED = f"""
UPDATE "{SONGS_TABLENAME}" SET
"{FIELD_FINGERPRINTED}" = 1
, "date_modified" = now()
WHERE "{FIELD_SONG_ID}" = %s;
"""
# DELETE
DELETE_UNFINGERPRINTED = f"""
DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0;
"""
# IN
IN_MATCH = f"decode(%s, 'hex')"
def __init__(self, **options): def __init__(self, **options):
super().__init__() super().__init__()
self.cursor = cursor_factory(**options) self.cursor = cursor_factory(**options)
@ -20,160 +131,14 @@ class PostgreSQLDatabase(Database):
# the previous process. # the previous process.
Cursor.clear_cache() Cursor.clear_cache()
def setup(self):
"""
Creates any non-existing tables required for dejavu to function.
This also removes all songs that have been added but have no
fingerprints associated with them.
"""
with self.cursor() as cur:
cur.execute(queries.CREATE_SONGS_TABLE)
cur.execute(queries.CREATE_FINGERPRINTS_TABLE)
cur.execute(queries.DELETE_UNFINGERPRINTED)
def empty(self):
"""
Drops tables created by dejavu and then creates them again
by calling `SQLDatabase.setup`.
.. warning:
This will result in a loss of data
"""
with self.cursor() as cur:
cur.execute(queries.DROP_FINGERPRINTS)
cur.execute(queries.DROP_SONGS)
self.setup()
def delete_unfingerprinted_songs(self):
"""
Removes all songs that have no fingerprints associated with them.
"""
with self.cursor() as cur:
cur.execute(queries.DELETE_UNFINGERPRINTED)
def get_num_songs(self):
"""
Returns number of songs the database has fingerprinted.
"""
with self.cursor() as cur:
cur.execute(queries.SELECT_UNIQUE_SONG_IDS)
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
return count
def get_num_fingerprints(self):
"""
Returns number of fingerprints the database has fingerprinted.
"""
with self.cursor() as cur:
cur.execute(queries.SELECT_NUM_FINGERPRINTS)
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
cur.close()
return count
def set_song_fingerprinted(self, sid):
"""
Set the fingerprinted flag to TRUE (1) once a song has been completely
fingerprinted in the database.
"""
with self.cursor() as cur:
cur.execute(queries.UPDATE_SONG_FINGERPRINTED, (sid,))
def get_songs(self):
"""
Return songs that have the fingerprinted flag set TRUE (1).
"""
with self.cursor(dictionary=True) as cur:
cur.execute(queries.SELECT_SONGS)
for row in cur:
yield row
def get_song_by_id(self, sid):
"""
Returns song by its ID.
"""
with self.cursor(dictionary=True) as cur:
cur.execute(queries.SELECT_SONG, (sid,))
return cur.fetchone()
def insert(self, hash, sid, offset):
"""
Insert a (sha1, song_id, offset) row into database.
"""
with self.cursor() as cur:
cur.execute(queries.INSERT_FINGERPRINT, (hash, sid, offset))
def insert_song(self, song_name, file_hash): def insert_song(self, song_name, file_hash):
""" """
Inserts song in the database and returns the ID of the inserted record. Inserts song in the database and returns the ID of the inserted record.
""" """
with self.cursor() as cur: with self.cursor() as cur:
cur.execute(queries.INSERT_SONG, (song_name, file_hash)) cur.execute(self.INSERT_SONG, (song_name, file_hash))
return cur.fetchone()[0] return cur.fetchone()[0]
def query(self, hash):
"""
Return all tuples associated with hash.
If hash is None, returns all entries in the
database (be careful with that one!).
"""
if hash:
with self.cursor() as cur:
cur.execute(queries.SELECT, (hash,))
for sid, offset in cur:
yield (sid, offset)
else: # select all if no key
with self.cursor() as cur:
cur.execute(queries.SELECT_ALL)
for sid, offset in cur:
yield (sid, offset)
def get_iterable_kv_pairs(self):
"""
Returns all tuples in database.
"""
return self.query(None)
def insert_hashes(self, sid, hashes, batch=1000):
"""
Insert series of hash => song_id, offset
values into the database.
"""
values = [(sid, hash, int(offset)) for hash, offset in hashes]
with self.cursor() as cur:
for index in range(0, len(hashes), batch):
cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch])
def return_matches(self, hashes, batch=1000):
"""
Return the (song_id, offset_diff) tuples associated with
a list of (sha1, sample_offset) values.
"""
# Create a dictionary of hash => offset pairs for later lookups
mapper = {}
for hash, offset in hashes:
mapper[hash.upper()] = offset
# Get an iterable of all the hashes we need
values = list(mapper.keys())
with self.cursor() as cur:
for index in range(0, len(values), batch):
# Create our IN part of the query
query = queries.SELECT_MULTIPLE
query = query % ', '.join(["decode(%s, 'hex')"] * len(values[index: index + batch]))
cur.execute(query, values[index: index + batch])
for hash, sid, offset in cur:
# (sid, db_offset - song_sampled_offset)
yield (sid, offset - mapper[hash.upper()])
def __getstate__(self): def __getstate__(self):
return self._options, return self._options,

View file

@ -1,104 +0,0 @@
from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
SONGS_TABLENAME)
# creates
CREATE_SONGS_TABLE = f"""
CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" (
"{FIELD_SONG_ID}" SERIAL
, "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL
, "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0
, "{FIELD_FILE_SHA1}" BYTEA
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
, CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}")
, CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}")
);
"""
CREATE_FINGERPRINTS_TABLE = f"""
CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" (
"{FIELD_HASH}" BYTEA NOT NULL
, "{FIELD_SONG_ID}" INT NOT NULL
, "{FIELD_OFFSET}" INT NOT NULL
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
, CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}")
, CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}")
REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}");
"""
CREATE_FINGERPRINTS_TABLE_INDEX = f"""
CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}");
"""
# inserts (ignores duplicates)
INSERT_FINGERPRINT = f"""
INSERT INTO "{FINGERPRINTS_TABLENAME}" (
"{FIELD_SONG_ID}"
, "{FIELD_HASH}"
, "{FIELD_OFFSET}")
VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING;
"""
INSERT_SONG = f"""
INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}")
VALUES (%s, decode(%s, 'hex'))
RETURNING "{FIELD_SONG_ID}";
"""
# selects
SELECT = f"""
SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
FROM "{FINGERPRINTS_TABLENAME}"
WHERE "{FIELD_HASH}" = decode(%s, 'hex');
"""
SELECT_MULTIPLE = f"""
SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
FROM "{FINGERPRINTS_TABLENAME}"
WHERE "{FIELD_HASH}" IN (%s);
"""
SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";'
SELECT_SONG = f"""
SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
FROM "{SONGS_TABLENAME}"
WHERE "{FIELD_SONG_ID}" = %s;
"""
SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";'
SELECT_UNIQUE_SONG_IDS = f"""
SELECT COUNT("{FIELD_SONG_ID}") AS n
FROM "{SONGS_TABLENAME}"
WHERE "{FIELD_FINGERPRINTED}" = 1;
"""
SELECT_SONGS = f"""
SELECT
"{FIELD_SONG_ID}"
, "{FIELD_SONGNAME}"
, upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
FROM "{SONGS_TABLENAME}"
WHERE "{FIELD_FINGERPRINTED}" = 1;
"""
# drops
DROP_FINGERPRINTS = f'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";'
DROP_SONGS = f'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";'
# update
UPDATE_SONG_FINGERPRINTED = f"""
UPDATE "{SONGS_TABLENAME}" SET "{FIELD_FINGERPRINTED}" = 1, "date_modified" = now() WHERE "{FIELD_SONG_ID}" = %s;
"""
# delete
DELETE_UNFINGERPRINTED = f"""
DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0;
"""

0
dejavu/logic/__init__.py Normal file
View file

View file

@ -1,10 +1,12 @@
import os
import fnmatch import fnmatch
import os
from hashlib import sha1
import numpy as np import numpy as np
from pydub import AudioSegment from pydub import AudioSegment
from pydub.utils import audioop from pydub.utils import audioop
from . import wavio
from hashlib import sha1 from dejavu.third_party import wavio
def unique_hash(filepath, blocksize=2**20): def unique_hash(filepath, blocksize=2**20):

View file

@ -9,11 +9,11 @@ from scipy.ndimage.morphology import (binary_erosion,
generate_binary_structure, generate_binary_structure,
iterate_structure) iterate_structure)
from dejavu.config.config import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE, from dejavu.config.settings import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE,
DEFAULT_FS, DEFAULT_OVERLAP_RATIO, DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION, DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION,
MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA, MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA,
PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT) PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT)
IDX_FREQ_I = 0 IDX_FREQ_I = 0
IDX_TIME_J = 1 IDX_TIME_J = 1

View file

View file

@ -0,0 +1,24 @@
import time
import dejavu.logic.decoder as decoder
from dejavu.base_classes.base_recognizer import BaseRecognizer
class FileRecognizer(BaseRecognizer):
def __init__(self, dejavu):
super().__init__(dejavu)
def recognize_file(self, filename):
frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit)
t = time.time()
matches = self._recognize(*frames)
t = time.time() - t
for match in matches:
match['match_time'] = t
return matches
def recognize(self, filename):
return self.recognize_file(filename)

View file

@ -1,45 +1,7 @@
import time
import numpy as np import numpy as np
import pyaudio import pyaudio
import dejavu.decoder as decoder from dejavu.base_classes.base_recognizer import BaseRecognizer
from dejavu.config.config import DEFAULT_FS
class BaseRecognizer(object):
def __init__(self, dejavu):
self.dejavu = dejavu
self.Fs = DEFAULT_FS
def _recognize(self, *data):
matches = []
for d in data:
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
return self.dejavu.align_matches(matches)
def recognize(self):
pass # base class does nothing
class FileRecognizer(BaseRecognizer):
def __init__(self, dejavu):
super().__init__(dejavu)
def recognize_file(self, filename):
frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit)
t = time.time()
matches = self._recognize(*frames)
t = time.time() - t
for match in matches:
match['match_time'] = t
return matches
def recognize(self, filename):
return self.recognize_file(filename)
class MicrophoneRecognizer(BaseRecognizer): class MicrophoneRecognizer(BaseRecognizer):

0
dejavu/tests/__init__.py Normal file
View file

View file

@ -1,130 +1,26 @@
import ast
import fnmatch import fnmatch
import json
import logging import logging
import os
import random import random
import re import re
import subprocess import subprocess
import traceback import traceback
from os import listdir, makedirs, walk
from os.path import basename, exists, isfile, join, splitext
import matplotlib.pyplot as plt
import numpy as np
from pydub import AudioSegment from pydub import AudioSegment
from dejavu import Dejavu from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS,
from dejavu.decoder import path_to_songname DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
from dejavu.fingerprint import * MATCH_TIME, OFFSET, SONG_NAME)
from dejavu.logic.decoder import path_to_songname
def set_seed(seed=None): class DejavuTest:
"""
`seed` as None means that the sampling will be random.
Setting your own seed means that you can produce the
same experiment over and over.
"""
if seed != None:
random.seed(seed)
def get_files_recursive(src, fmt):
"""
`src` is the source directory.
`fmt` is the extension, ie ".mp3" or "mp3", etc.
"""
for root, dirnames, filenames in os.walk(src):
for filename in fnmatch.filter(filenames, '*' + fmt):
yield os.path.join(root, filename)
def get_length_audio(audiopath, extension):
"""
Returns length of audio in seconds.
Returns None if format isn't supported or in case of error.
"""
try:
audio = AudioSegment.from_file(audiopath, extension.replace(".", ""))
except:
print(f"Error in get_length_audio(): {traceback.format_exc()}")
return None
return int(len(audio) / 1000.0)
def get_starttime(length, nseconds, padding):
"""
`length` is total audio length in seconds
`nseconds` is amount of time to sample in seconds
`padding` is off-limits seconds at beginning and ending
"""
maximum = length - padding - nseconds
if padding > maximum:
return 0
return random.randint(padding, maximum)
def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10):
"""
Generates a test file for each file recursively in `src` directory
of given format using `nseconds` sampled from the audio file.
Results are written to `dest` directory.
`padding` is the number of off-limit seconds and the beginning and
end of a track that won't be sampled in testing. Often you want to
avoid silence, etc.
"""
# create directories if necessary
for directory in [src, dest]:
try:
os.stat(directory)
except:
os.mkdir(directory)
# find files recursively of a given file format
for fmt in fmts:
testsources = get_files_recursive(src, fmt)
for audiosource in testsources:
print("audiosource:", audiosource)
filename, extension = os.path.splitext(os.path.basename(audiosource))
length = get_length_audio(audiosource, extension)
starttime = get_starttime(length, nseconds, padding)
test_file_name = f"{os.path.join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}"
subprocess.check_output([
"ffmpeg", "-y",
"-ss", f"{starttime}",
'-t', f"{nseconds}",
"-i", audiosource,
test_file_name])
def log_msg(msg, log=True, silent=False):
if log:
logging.debug(msg)
if not silent:
print(msg)
def autolabel(rects, ax):
# attach some text labels
for rect in rects:
height = rect.get_height()
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{int(height)}', ha='center', va='bottom')
def autolabeldoubles(rects, ax):
# attach some text labels
for rect in rects:
height = rect.get_height()
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{round(float(height), 3)}',
ha='center', va='bottom')
class DejavuTest(object):
def __init__(self, folder, seconds): def __init__(self, folder, seconds):
super(DejavuTest, self).__init__() super().__init__()
self.test_folder = folder self.test_folder = folder
self.test_seconds = seconds self.test_seconds = seconds
@ -133,9 +29,10 @@ class DejavuTest(object):
print("test_seconds", self.test_seconds) print("test_seconds", self.test_seconds)
self.test_files = [ self.test_files = [
f for f in os.listdir(self.test_folder) f for f in listdir(self.test_folder)
if os.path.isfile(os.path.join(self.test_folder, f)) if isfile(join(self.test_folder, f))
and re.findall("[0-9]*sec", f)[0] in self.test_seconds] and any([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds])
]
print("test_files", self.test_files) print("test_files", self.test_files)
@ -154,7 +51,7 @@ class DejavuTest(object):
# variable match precision (if matched in the corrected time) # variable match precision (if matched in the corrected time)
self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
# variable mahing time (query time) # variable matching time (query time)
self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)] self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
# variable confidence # variable confidence
@ -162,12 +59,12 @@ class DejavuTest(object):
self.begin() self.begin()
def get_column_id (self, secs): def get_column_id(self, secs):
for i, sec in enumerate(self.test_seconds): for i, sec in enumerate(self.test_seconds):
if secs == sec: if secs == sec:
return i return i
def get_line_id (self, song): def get_line_id(self, song):
for i, s in enumerate(self.test_songs): for i, s in enumerate(self.test_songs):
if song == s: if song == s:
return i return i
@ -176,7 +73,7 @@ class DejavuTest(object):
def create_plots(self, name, results, results_folder): def create_plots(self, name, results, results_folder):
for sec in range(0, len(self.test_seconds)): for sec in range(0, len(self.test_seconds)):
ind = np.arange(self.n_lines) # ind = np.arange(self.n_lines)
width = 0.25 # the width of the bars width = 0.25 # the width of the bars
fig = plt.figure() fig = plt.figure()
@ -206,7 +103,7 @@ class DejavuTest(object):
plt.grid() plt.grid()
fig_name = os.path.join(results_folder, f"{name}_{self.test_seconds[sec]}.png") fig_name = join(results_folder, f"{name}_{self.test_seconds[sec]}.png")
fig.savefig(fig_name) fig.savefig(fig_name)
def begin(self): def begin(self):
@ -215,16 +112,18 @@ class DejavuTest(object):
log_msg(f'file: {f}') log_msg(f'file: {f}')
# get column # get column
col = self.get_column_id(re.findall("[0-9]*sec", f)[0]) col = self.get_column_id([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds][0])
# format: XXXX_offset_length.mp3
song = path_to_songname(f).split("_")[0] # format: XXXX_offset_length.mp3, we also take into account underscores within XXXX
splits = path_to_songname(f).split("_")
song = "_".join(splits[0:len(path_to_songname(f).split("_")) - 2])
line = self.get_line_id(song) line = self.get_line_id(song)
result = subprocess.check_output([ result = subprocess.check_output([
"python", "python",
"dejavu.py", "dejavu.py",
'-r', '-r',
'file', 'file',
self.test_folder + "/" + f]) join(self.test_folder, f)])
if result.strip() == "None": if result.strip() == "None":
log_msg('No match') log_msg('No match')
@ -235,14 +134,12 @@ class DejavuTest(object):
else: else:
result = result.strip() result = result.strip()
result = result.replace(" \'", ' "') # we parse the output song back to a json
result = result.replace("{\'", '{"') result = json.loads(result.decode('utf-8').replace("'", '"').replace(': b"', ':"'))
result = result.replace("\':", '":')
result = result.replace("\',", '",')
# which song did we predict? # which song did we predict? We consider only the first match.
result = ast.literal_eval(result) result = result[0]
song_result = result["song_name"] song_result = result[SONG_NAME]
log_msg(f'song: {song}') log_msg(f'song: {song}')
log_msg(f'song_result: {song_result}') log_msg(f'song_result: {song_result}')
@ -256,21 +153,22 @@ class DejavuTest(object):
log_msg('correct match') log_msg('correct match')
print(self.result_match) print(self.result_match)
self.result_match[line][col] = 'yes' self.result_match[line][col] = 'yes'
self.result_query_duration[line][col] = round(result[Dejavu.MATCH_TIME],3) self.result_query_duration[line][col] = round(result[MATCH_TIME], 3)
self.result_match_confidence[line][col] = result[Dejavu.CONFIDENCE] self.result_match_confidence[line][col] = result[CONFIDENCE]
song_start_time = re.findall("_[^_]+", f) # using replace in f for getting rid of underscores in name
song_start_time = re.findall("_[^_]+", f.replace(song, ""))
song_start_time = song_start_time[0].lstrip("_ ") song_start_time = song_start_time[0].lstrip("_ ")
result_start_time = round((result[Dejavu.OFFSET] * DEFAULT_WINDOW_SIZE * result_start_time = round((result[OFFSET] * DEFAULT_WINDOW_SIZE *
DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0) DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0)
self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time) self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time)
if abs(self.result_matching_times[line][col]) == 1: if abs(self.result_matching_times[line][col]) == 1:
self.result_matching_times[line][col] = 0 self.result_matching_times[line][col] = 0
log_msg(f'query duration: {round(result[Dejavu.MATCH_TIME], 3)}') log_msg(f'query duration: {round(result[MATCH_TIME], 3)}')
log_msg(f'confidence: {result[Dejavu.CONFIDENCE]}') log_msg(f'confidence: {result[CONFIDENCE]}')
log_msg(f'song start_time: {song_start_time}') log_msg(f'song start_time: {song_start_time}')
log_msg(f'result start time: {result_start_time}') log_msg(f'result start time: {result_start_time}')
@ -279,3 +177,110 @@ class DejavuTest(object):
else: else:
log_msg('inaccurate match') log_msg('inaccurate match')
log_msg('--------------------------------------------------\n') log_msg('--------------------------------------------------\n')
def set_seed(seed=None):
"""
`seed` as None means that the sampling will be random.
Setting your own seed means that you can produce the
same experiment over and over.
"""
if seed is not None:
random.seed(seed)
def get_files_recursive(src, fmt):
"""
`src` is the source directory.
`fmt` is the extension, ie ".mp3" or "mp3", etc.
"""
files = []
for root, dirnames, filenames in walk(src):
for filename in fnmatch.filter(filenames, '*' + fmt):
files.append(join(root, filename))
return files
def get_length_audio(audiopath, extension):
"""
Returns length of audio in seconds.
Returns None if format isn't supported or in case of error.
"""
try:
audio = AudioSegment.from_file(audiopath, extension.replace(".", ""))
except Exception:
print(f"Error in get_length_audio(): {traceback.format_exc()}")
return None
return int(len(audio) / 1000.0)
def get_starttime(length, nseconds, padding):
"""
`length` is total audio length in seconds
`nseconds` is amount of time to sample in seconds
`padding` is off-limits seconds at beginning and ending
"""
maximum = length - padding - nseconds
if padding > maximum:
return 0
return random.randint(padding, maximum)
def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10):
"""
Generates a test file for each file recursively in `src` directory
of given format using `nseconds` sampled from the audio file.
Results are written to `dest` directory.
`padding` is the number of off-limit seconds and the beginning and
end of a track that won't be sampled in testing. Often you want to
avoid silence, etc.
"""
# create directories if necessary
if not exists(dest):
makedirs(dest)
# find files recursively of a given file format
for fmt in fmts:
testsources = get_files_recursive(src, fmt)
for audiosource in testsources:
print("audiosource:", audiosource)
filename, extension = splitext(basename(audiosource))
length = get_length_audio(audiosource, extension)
starttime = get_starttime(length, nseconds, padding)
test_file_name = f"{join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}"
subprocess.check_output([
"ffmpeg", "-y",
"-ss", f"{starttime}",
'-t', f"{nseconds}",
"-i", audiosource,
test_file_name])
def log_msg(msg, log=True, silent=False):
if log:
logging.debug(msg)
if not silent:
print(msg)
def autolabel(rects, ax):
# attach some text labels
for rect in rects:
height = rect.get_height()
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{int(height)}', ha='center', va='bottom')
def autolabeldoubles(rects, ax):
# attach some text labels
for rect in rects:
height = rect.get_height()
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{round(float(height), 3)}',
ha='center', va='bottom')

0
dejavu/third_party/__init__.py vendored Normal file
View file

357
dejavu/third_party/wavio.py vendored Normal file
View file

@ -0,0 +1,357 @@
# wavio.py
# Author: Warren Weckesser
# License: BSD 2-Clause (http://opensource.org/licenses/BSD-2-Clause)
# Synopsis: A Python module for reading and writing 24 bit WAV files.
# Github: github.com/WarrenWeckesser/wavio
"""
The wavio module defines the functions:
read(file)
Read a WAV file and return a `wavio.Wav` object, with attributes
`data`, `rate` and `sampwidth`.
write(filename, data, rate, scale=None, sampwidth=None)
Write a numpy array to a WAV file.
-----
Author: Warren Weckesser
License: BSD 2-Clause:
Copyright (c) 2015, Warren Weckesser
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""
import wave as _wave
import numpy as _np
__version__ = "0.0.5.dev1"
def _wav2array(nchannels, sampwidth, data):
"""data must be the string containing the bytes from the wav file."""
num_samples, remainder = divmod(len(data), sampwidth * nchannels)
if remainder > 0:
raise ValueError('The length of data is not a multiple of '
'sampwidth * num_channels.')
if sampwidth > 4:
raise ValueError("sampwidth must not be greater than 4.")
if sampwidth == 3:
a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8)
raw_bytes = _np.frombuffer(data, dtype=_np.uint8)
a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
result = a.view('<i4').reshape(a.shape[:-1])
else:
# 8 bit samples are stored as unsigned ints; others as signed ints.
dt_char = 'u' if sampwidth == 1 else 'i'
a = _np.frombuffer(data, dtype=f'<{dt_char}{sampwidth}')
result = a.reshape(-1, nchannels)
return result
def _array2wav(a, sampwidth):
"""
Convert the input array `a` to a string of WAV data.
a.dtype must be one of uint8, int16 or int32. Allowed sampwidth
values are:
dtype sampwidth
uint8 1
int16 2
int32 3 or 4
When sampwidth is 3, the *low* bytes of `a` are assumed to contain
the values to include in the string.
"""
if sampwidth == 3:
# `a` must have dtype int32
if a.ndim == 1:
# Convert to a 2D array with a single column.
a = a.reshape(-1, 1)
# By shifting first 0 bits, then 8, then 16, the resulting output
# is 24 bit little-endian.
a8 = (a.reshape(a.shape + (1,)) >> _np.array([0, 8, 16])) & 255
wavdata = a8.astype(_np.uint8).tostring()
else:
# Make sure the array is little-endian, and then convert using
# tostring()
a = a.astype('<' + a.dtype.str[1:], copy=False)
wavdata = a.tostring()
return wavdata
class Wav(object):
"""
Object returned by `wavio.read`. Attributes are:
data : numpy array
The array of data read from the WAV file.
rate : float
The sample rate of the WAV file.
sampwidth : int
The sample width (i.e. number of bytes per sample) of the WAV file.
For example, `sampwidth == 3` is a 24 bit WAV file.
"""
def __init__(self, data, rate, sampwidth):
self.data = data
self.rate = rate
self.sampwidth = sampwidth
def __repr__(self):
s = (f"Wav(data.shape={self.data.shape}, data.dtype={self.data.dtype}, "
f"rate={self.rate}, sampwidth={self.sampwidth})")
return s
def read(file):
"""
Read a WAV file.
Parameters
----------
file : string or file object
Either the name of a file or an open file pointer.
Returns
-------
wav : wavio.Wav() instance
The return value is an instance of the class `wavio.Wav`,
with the following attributes:
data : numpy array
The array containing the data. The shape of the array
is (num_samples, num_channels). num_channels is the
number of audio channels (1 for mono, 2 for stereo).
rate : float
The sampling frequency (i.e. frame rate)
sampwidth : float
The sample width, in bytes. E.g. for a 24 bit WAV file,
sampwidth is 3.
Notes
-----
This function uses the `wave` module of the Python standard libary
to read the WAV file, so it has the same limitations as that library.
In particular, the function does not read compressed WAV files, and
it does not read files with floating point data.
The array returned by `wavio.read` is always two-dimensional. If the
WAV data is mono, the array will have shape (num_samples, 1).
`wavio.read()` does not scale or normalize the data. The data in the
array `wav.data` is the data that was in the file. When the file
contains 24 bit samples, the resulting numpy array is 32 bit integers,
with values that have been sign-extended.
"""
wav = _wave.open(file)
rate = wav.getframerate()
nchannels = wav.getnchannels()
sampwidth = wav.getsampwidth()
nframes = wav.getnframes()
data = wav.readframes(nframes)
wav.close()
array = _wav2array(nchannels, sampwidth, data)
w = Wav(data=array, rate=rate, sampwidth=sampwidth)
return w
_sampwidth_dtypes = {1: _np.uint8,
2: _np.int16,
3: _np.int32,
4: _np.int32}
_sampwidth_ranges = {1: (0, 256),
2: (-2**15, 2**15),
3: (-2**23, 2**23),
4: (-2**31, 2**31)}
def _scale_to_sampwidth(data, sampwidth, vmin, vmax):
# Scale and translate the values to fit the range of the data type
# associated with the given sampwidth.
data = data.clip(vmin, vmax)
dt = _sampwidth_dtypes[sampwidth]
if vmax == vmin:
data = _np.zeros(data.shape, dtype=dt)
else:
outmin, outmax = _sampwidth_ranges[sampwidth]
if outmin != vmin or outmax != vmax:
vmin = float(vmin)
vmax = float(vmax)
data = (float(outmax - outmin) * (data - vmin) /
(vmax - vmin)).astype(_np.int64) + outmin
data[data == outmax] = outmax - 1
data = data.astype(dt)
return data
def write(file, data, rate, scale=None, sampwidth=None):
"""
Write the numpy array `data` to a WAV file.
The Python standard library "wave" is used to write the data
to the file, so this function has the same limitations as that
module. In particular, the Python library does not support
floating point data. When given a floating point array, this
function converts the values to integers. See below for the
conversion rules.
Parameters
----------
file : string, or file object open for writing in binary mode
Either the name of a file or an open file pointer.
data : numpy array, 1- or 2-dimensional, integer or floating point
If it is 2-d, the rows are the frames (i.e. samples) and the
columns are the channels.
rate : float
The sampling frequency (i.e. frame rate) of the data.
sampwidth : int, optional
The sample width, in bytes, of the output file.
If `sampwidth` is not given, it is inferred (if possible) from
the data type of `data`, as follows::
data.dtype sampwidth
---------- ---------
uint8, int8 1
uint16, int16 2
uint32, int32 4
For any other data types, or to write a 24 bit file, `sampwidth`
must be given.
scale : tuple or str, optional
By default, the data written to the file is scaled up or down to
occupy the full range of the output data type. So, for example,
the unsigned 8 bit data [0, 1, 2, 15] would be written to the file
as [0, 17, 30, 255]. More generally, the default behavior is
(roughly)::
vmin = data.min()
vmax = data.max()
outmin = <minimum integer of the output dtype>
outmax = <maximum integer of the output dtype>
outdata = (outmax - outmin)*(data - vmin)/(vmax - vmin) + outmin
The `scale` argument allows the scaling of the output data to be
changed. `scale` can be a tuple of the form `(vmin, vmax)`, in which
case the given values override the use of `data.min()` and
`data.max()` for `vmin` and `vmax` shown above. (If either value
is `None`, the value shown above is used.) Data outside the
range (vmin, vmax) is clipped. If `vmin == vmax`, the output is
all zeros.
If `scale` is the string "none", then `vmin` and `vmax` are set to
`outmin` and `outmax`, respectively. This means the data is written
to the file with no scaling. (Note: `scale="none" is not the same
as `scale=None`. The latter means "use the default behavior",
which is to scale by the data minimum and maximum.)
If `scale` is the string "dtype-limits", then `vmin` and `vmax`
are set to the minimum and maximum integers of `data.dtype`.
The string "dtype-limits" is not allowed when the `data` is a
floating point array.
If using `scale` results in values that exceed the limits of the
output sample width, the data is clipped. For example, the
following code::
>> x = np.array([-100, 0, 100, 200, 300, 325])
>> wavio.write('foo.wav', x, 8000, scale='none', sampwidth=1)
will write the values [0, 0, 100, 200, 255, 255] to the file.
Example
-------
Create a 3 second 440 Hz sine wave, and save it in a 24-bit WAV file.
>> import numpy as np
>> import wavio
>> rate = 22050 # samples per second
>> T = 3 # sample duration (seconds)
>> f = 440.0 # sound frequency (Hz)
>> t = np.linspace(0, T, T*rate, endpoint=False)
>> x = np.sin(2*np.pi * f * t)
>> wavio.write("sine24.wav", x, rate, sampwidth=3)
Create a file that contains the 16 bit integer values -10000 and 10000
repeated 100 times. Don't automatically scale the values. Use a sample
rate 8000.
>> x = np.empty(200, dtype=np.int16)
>> x[::2] = -10000
>> x[1::2] = 10000
>> wavio.write("foo.wav", x, 8000, scale='none')
Check that the file contains what we expect.
>> w = wavio.read("foo.wav")
>> np.all(w.data[:, 0] == x)
True
In the following, the values -10000 and 10000 (from within the 16 bit
range [-2**15, 2**15-1]) are mapped to the corresponding values 88 and
168 (in the range [0, 2**8-1]).
>> wavio.write("foo.wav", x, 8000, sampwidth=1, scale='dtype-limits')
>> w = wavio.read("foo.wav")
>> w.data[:4, 0]
array([ 88, 168, 88, 168], dtype=uint8)
"""
if sampwidth is None:
if not _np.issubdtype(data.dtype, _np.integer) or data.itemsize > 4:
raise ValueError('when data.dtype is not an 8-, 16-, or 32-bit integer type, sampwidth must be specified.')
sampwidth = data.itemsize
else:
if sampwidth not in [1, 2, 3, 4]:
raise ValueError('sampwidth must be 1, 2, 3 or 4.')
outdtype = _sampwidth_dtypes[sampwidth]
outmin, outmax = _sampwidth_ranges[sampwidth]
if scale == "none":
data = data.clip(outmin, outmax-1).astype(outdtype)
elif scale == "dtype-limits":
if not _np.issubdtype(data.dtype, _np.integer):
raise ValueError("scale cannot be 'dtype-limits' with non-integer data.")
# Easy transforms that just changed the signedness of the data.
if sampwidth == 1 and data.dtype == _np.int8:
data = (data.astype(_np.int16) + 128).astype(_np.uint8)
elif sampwidth == 2 and data.dtype == _np.uint16:
data = (data.astype(_np.int32) - 32768).astype(_np.int16)
elif sampwidth == 4 and data.dtype == _np.uint32:
data = (data.astype(_np.int64) - 2**31).astype(_np.int32)
elif data.itemsize != sampwidth:
# Integer input, but rescaling is needed to adjust the
# input range to the output sample width.
ii = _np.iinfo(data.dtype)
vmin = ii.min
vmax = ii.max
data = _scale_to_sampwidth(data, sampwidth, vmin, vmax)
else:
if scale is None:
vmin = data.min()
vmax = data.max()
else:
# scale must be a tuple of the form (vmin, vmax)
vmin, vmax = scale
if vmin is None:
vmin = data.min()
if vmax is None:
vmax = data.max()
data = _scale_to_sampwidth(data, sampwidth, vmin, vmax)
# At this point, `data` has been converted to have one of the following:
# sampwidth dtype
# --------- -----
# 1 uint8
# 2 int16
# 3 int32
# 4 int32
# The values in `data` are in the form in which they will be saved;
# no more scaling will take place.
if data.ndim == 1:
data = data.reshape(-1, 1)
wavdata = _array2wav(data, sampwidth)
w = _wave.open(file, 'wb')
w.setnchannels(data.shape[1])
w.setsampwidth(sampwidth)
w.setframerate(rate)
w.writeframes(wavdata)
w.close()

View file

@ -1,122 +0,0 @@
# wavio.py
# Author: Warren Weckesser
# License: BSD 3-Clause (http://opensource.org/licenses/BSD-3-Clause)
# Synopsis: A Python module for reading and writing 24 bit WAV files.
# Github: github.com/WarrenWeckesser/wavio
import wave as _wave
import numpy as _np
def _wav2array(nchannels, sampwidth, data):
"""data must be the string containing the bytes from the wav file."""
num_samples, remainder = divmod(len(data), sampwidth * nchannels)
if remainder > 0:
raise ValueError('The length of data is not a multiple of '
'sampwidth * num_channels.')
if sampwidth > 4:
raise ValueError("sampwidth must not be greater than 4.")
if sampwidth == 3:
a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8)
raw_bytes = _np.fromstring(data, dtype=_np.uint8)
a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
result = a.view('<i4').reshape(a.shape[:-1])
else:
# 8 bit samples are stored as unsigned ints; others as signed ints.
dt_char = 'u' if sampwidth == 1 else 'i'
a = _np.fromstring(data, dtype='<%s%d' % (dt_char, sampwidth))
result = a.reshape(-1, nchannels)
return result
def readwav(file):
"""
Read a WAV file.
Parameters
----------
file : string or file object
Either the name of a file or an open file pointer.
Return Values
-------------
rate : float
The sampling frequency (i.e. frame rate)
sampwidth : float
The sample width, in bytes. E.g. for a 24 bit WAV file,
sampwidth is 3.
data : numpy array
The array containing the data. The shape of the array is
(num_samples, num_channels). num_channels is the number of
audio channels (1 for mono, 2 for stereo).
Notes
-----
This function uses the `wave` module of the Python standard libary
to read the WAV file, so it has the same limitations as that library.
In particular, the function does not read compressed WAV files.
"""
wav = _wave.open(file)
rate = wav.getframerate()
nchannels = wav.getnchannels()
sampwidth = wav.getsampwidth()
nframes = wav.getnframes()
data = wav.readframes(nframes)
wav.close()
array = _wav2array(nchannels, sampwidth, data)
return rate, sampwidth, array
def writewav24(filename, rate, data):
"""
Create a 24 bit wav file.
Parameters
----------
filename : string
Name of the file to create.
rate : float
The sampling frequency (i.e. frame rate) of the data.
data : array-like collection of integer or floating point values
data must be "array-like", either 1- or 2-dimensional. If it
is 2-d, the rows are the frames (i.e. samples) and the columns
are the channels.
Notes
-----
The data is assumed to be signed, and the values are assumed to be
within the range of a 24 bit integer. Floating point values are
converted to integers. The data is not rescaled or normalized before
writing it to the file.
Example
-------
Create a 3 second 440 Hz sine wave.
>>> rate = 22050 # samples per second
>>> T = 3 # sample duration (seconds)
>>> f = 440.0 # sound frequency (Hz)
>>> t = np.linspace(0, T, T*rate, endpoint=False)
>>> x = (2**23 - 1) * np.sin(2 * np.pi * f * t)
>>> writewav24("sine24.wav", rate, x)
"""
a32 = _np.asarray(data, dtype=_np.int32)
if a32.ndim == 1:
# Convert to a 2D array with a single column.
a32.shape = a32.shape + (1,)
# By shifting first 0 bits, then 8, then 16, the resulting output
# is 24 bit little-endian.
a8 = (a32.reshape(a32.shape + (1,)) >> _np.array([0, 8, 16])) & 255
wavdata = a8.astype(_np.uint8).tostring()
w = _wave.open(filename, 'wb')
w.setnchannels(a32.shape[1])
w.setsampwidth(3)
w.setframerate(rate)
w.writeframes(wavdata)
w.close()

View file

@ -1,11 +1,8 @@
import json import json
import warnings
from dejavu import Dejavu from dejavu import Dejavu
from dejavu.recognize import FileRecognizer, MicrophoneRecognizer from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer
from dejavu.logic.recognizer.file_recognizer import FileRecognizer
warnings.filterwarnings("ignore")
# load config from a JSON file (or anything outputting a python dictionary) # load config from a JSON file (or anything outputting a python dictionary)
with open("dejavu.cnf.SAMPLE") as f: with open("dejavu.cnf.SAMPLE") as f:
@ -17,10 +14,10 @@ if __name__ == '__main__':
djv = Dejavu(config) djv = Dejavu(config)
# Fingerprint all the mp3's in the directory we give it # Fingerprint all the mp3's in the directory we give it
djv.fingerprint_directory("mp3", [".mp3"]) djv.fingerprint_directory("test", [".wav"])
# Recognize audio from a file # Recognize audio from a file
song = djv.recognize(FileRecognizer, "mp3/Sean-Fournier--Falling-For-You.mp3") song = djv.recognize(FileRecognizer, "mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
print(f"From file we recognized: {song}\n") print(f"From file we recognized: {song}\n")
# Or recognize audio from your microphone for `secs` seconds # Or recognize audio from your microphone for `secs` seconds
@ -29,7 +26,7 @@ if __name__ == '__main__':
if song is None: if song is None:
print("Nothing recognized -- did you play the song out loud so your mic could hear it? :)") print("Nothing recognized -- did you play the song out loud so your mic could hear it? :)")
else: else:
print(f"From mic with %d seconds we recognized: {(secs, song)}\n") print(f"From mic with {secs} seconds we recognized: {song}\n")
# Or use a recognizer without the shortcut, in anyway you would like # Or use a recognizer without the shortcut, in anyway you would like
recognizer = FileRecognizer(djv) recognizer = FileRecognizer(djv)

BIN
mp3/azan_test.wav Normal file

Binary file not shown.

View file

@ -5,3 +5,4 @@ scipy==1.3.1
matplotlib==3.1.1 matplotlib==3.1.1
mysql-connector-python==8.0.17 mysql-connector-python==8.0.17
psycopg2==2.8.3 psycopg2==2.8.3

View file

@ -1,184 +1,166 @@
from dejavu.testing import * import argparse
from dejavu import Dejavu import logging
from optparse import OptionParser
import matplotlib.pyplot as plt
import time import time
import shutil from os import makedirs
from os.path import exists, join
from shutil import rmtree
usage = "usage: %prog [options] TESTING_AUDIOFOLDER" import matplotlib.pyplot as plt
parser = OptionParser(usage=usage, version="%prog 1.1") import numpy as np
parser.add_option("--secs",
action="store",
dest="secs",
default=5,
type=int,
help='Number of seconds starting from zero to test')
parser.add_option("--results",
action="store",
dest="results_folder",
default="./dejavu_test_results",
help='Sets the path where the results are saved')
parser.add_option("--temp",
action="store",
dest="temp_folder",
default="./dejavu_temp_testing_files",
help='Sets the path where the temp files are saved')
parser.add_option("--log",
action="store_true",
dest="log",
default=True,
help='Enables logging')
parser.add_option("--silent",
action="store_false",
dest="silent",
default=False,
help='Disables printing')
parser.add_option("--log-file",
dest="log_file",
default="results-compare.log",
help='Set the path and filename of the log file')
parser.add_option("--padding",
action="store",
dest="padding",
default=10,
type=int,
help='Number of seconds to pad choice of place to test from')
parser.add_option("--seed",
action="store",
dest="seed",
default=None,
type=int,
help='Random seed')
options, args = parser.parse_args()
test_folder = args[0]
# set random seed if set by user from dejavu.tests.dejavu_test import (DejavuTest, autolabeldoubles,
set_seed(options.seed) generate_test_files, log_msg, set_seed)
# ensure results folder exists
try:
os.stat(options.results_folder)
except:
os.mkdir(options.results_folder)
# set logging def main(seconds: int, results_folder: str, temp_folder: str, log: bool, silent: bool,
if options.log: log_file: str, padding: int, seed: int, src: str):
logging.basicConfig(filename=options.log_file, level=logging.DEBUG)
# set test seconds # set random seed if set by user
test_seconds = ['%dsec' % i for i in range(1, options.secs + 1, 1)] set_seed(seed)
# generate testing files # ensure results folder exists
for i in range(1, options.secs + 1, 1): if not exists(results_folder):
generate_test_files(test_folder, options.temp_folder, makedirs(results_folder)
i, padding=options.padding)
# scan files # set logging
log_msg("Running Dejavu fingerprinter on files in %s..." % test_folder, if log:
log=options.log, silent=options.silent) logging.basicConfig(filename=log_file, level=logging.DEBUG)
tm = time.time() # set test seconds
djv = DejavuTest(options.temp_folder, test_seconds) test_seconds = [f'{i}sec' for i in range(1, seconds + 1, 1)]
log_msg("finished obtaining results from dejavu in %s" % (time.time() - tm),
log=options.log, silent=options.silent)
tests = 1 # djv # generate testing files
n_secs = len(test_seconds) for i in range(1, seconds + 1, 1):
generate_test_files(src, temp_folder, i, padding=padding)
# set result variables -> 4d variables # scan files
all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)] log_msg(f"Running Dejavu fingerprinter on files in {src}...", log=log, silent=silent)
all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)]
all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
# group results by seconds tm = time.time()
for line in range(0, djv.n_lines): djv = DejavuTest(temp_folder, test_seconds)
for col in range(0, djv.n_columns): log_msg(f"finished obtaining results from dejavu in {(time.time() - tm)}", log=log, silent=silent)
# for dejavu
all_query_duration[col][line][0] = djv.result_query_duration[line][col]
all_match_confidence[col][line][0] = djv.result_match_confidence[line][col]
djv_match_result = djv.result_match[line][col] tests = 1 # djv
n_secs = len(test_seconds)
if djv_match_result == 'yes': # set result variables -> 4d variables
all_match_counter[col][0][0] += 1 all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)]
elif djv_match_result == 'no': all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)]
all_match_counter[col][1][0] += 1 all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
else: all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
all_match_counter[col][2][0] += 1
djv_match_acc = djv.result_matching_times[line][col] # group results by seconds
for line in range(0, djv.n_lines):
for col in range(0, djv.n_columns):
# for dejavu
all_query_duration[col][line][0] = djv.result_query_duration[line][col]
all_match_confidence[col][line][0] = djv.result_match_confidence[line][col]
if djv_match_acc == 0 and djv_match_result == 'yes': djv_match_result = djv.result_match[line][col]
all_matching_times_counter[col][0][0] += 1
elif djv_match_acc != 0:
all_matching_times_counter[col][1][0] += 1
# create plots if djv_match_result == 'yes':
djv.create_plots('Confidence', all_match_confidence, options.results_folder) all_match_counter[col][0][0] += 1
djv.create_plots('Query duration', all_query_duration, options.results_folder) elif djv_match_result == 'no':
all_match_counter[col][1][0] += 1
else:
all_match_counter[col][2][0] += 1
for sec in range(0, n_secs): djv_match_acc = djv.result_matching_times[line][col]
ind = np.arange(3) #
width = 0.25 # the width of the bars
fig = plt.figure() if djv_match_acc == 0 and djv_match_result == 'yes':
ax = fig.add_subplot(111) all_matching_times_counter[col][0][0] += 1
ax.set_xlim([-1 * width, 2.75]) elif djv_match_acc != 0:
all_matching_times_counter[col][1][0] += 1
means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]] # create plots
rects1 = ax.bar(ind, means_dvj, width, color='r') djv.create_plots('Confidence', all_match_confidence, results_folder)
djv.create_plots('Query duration', all_query_duration, results_folder)
# add some for sec in range(0, n_secs):
ax.set_ylabel('Matching Percentage') ind = np.arange(3)
ax.set_title('%s Matching Percentage' % test_seconds[sec]) width = 0.25 # the width of the bars
ax.set_xticks(ind + width)
labels = ['yes','no','invalid'] fig = plt.figure()
ax.set_xticklabels( labels ) ax = fig.add_subplot(111)
ax.set_xlim([-1 * width, 2.75])
box = ax.get_position() means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]]
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) rects1 = ax.bar(ind, means_dvj, width, color='r')
#ax.legend((rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
autolabeldoubles(rects1,ax)
plt.grid()
fig_name = os.path.join(options.results_folder, "matching_perc_%s.png" % test_seconds[sec]) # add some
fig.savefig(fig_name) ax.set_ylabel('Matching Percentage')
ax.set_title(f'{test_seconds[sec]} Matching Percentage')
ax.set_xticks(ind + width)
for sec in range(0, n_secs): labels = ['yes', 'no', 'invalid']
ind = np.arange(2) # ax.set_xticklabels(labels)
width = 0.25 # the width of the bars
fig = plt.figure() box = ax.get_position()
ax = fig.add_subplot(111) ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
ax.set_xlim([-1*width, 1.75]) autolabeldoubles(rects1, ax)
plt.grid()
div = all_match_counter[sec][0][0] fig_name = join(results_folder, f"matching_perc_{test_seconds[sec]}.png")
if div == 0 : fig.savefig(fig_name)
div = 1000000
means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]] for sec in range(0, n_secs):
rects1 = ax.bar(ind, means_dvj, width, color='r') ind = np.arange(2)
width = 0.25 # the width of the bars
# add some fig = plt.figure()
ax.set_ylabel('Matching Accuracy') ax = fig.add_subplot(111)
ax.set_title('%s Matching Times Accuracy' % test_seconds[sec]) ax.set_xlim([-1 * width, 1.75])
ax.set_xticks(ind + width)
labels = ['yes','no'] div = all_match_counter[sec][0][0]
ax.set_xticklabels( labels ) if div == 0:
div = 1000000
box = ax.get_position() means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]]
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height]) rects1 = ax.bar(ind, means_dvj, width, color='r')
#ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5)) # add some
autolabeldoubles(rects1,ax) ax.set_ylabel('Matching Accuracy')
ax.set_title(f'{test_seconds[sec]} Matching Times Accuracy')
ax.set_xticks(ind + width)
plt.grid() labels = ['yes', 'no']
ax.set_xticklabels(labels)
fig_name = os.path.join(options.results_folder, "matching_acc_%s.png" % test_seconds[sec]) box = ax.get_position()
fig.savefig(fig_name) ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
autolabeldoubles(rects1, ax)
# remove temporary folder plt.grid()
shutil.rmtree(options.temp_folder)
fig_name = join(results_folder, f"matching_acc_{test_seconds[sec]}.png")
fig.savefig(fig_name)
# remove temporary folder
rmtree(temp_folder)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description=f'Runs a few tests for dejavu to evaluate '
f'its configuration performance. '
f'Usage: %(prog).py [options] TESTING_AUDIOFOLDER'
)
parser.add_argument("-sec", "--seconds", action="store", default=5, type=int,
help='Number of seconds starting from zero to test.')
parser.add_argument("-res", "--results-folder", action="store", default="./dejavu_test_results",
help='Sets the path where the results are saved.')
parser.add_argument("-temp", "--temp-folder", action="store", default="./dejavu_temp_testing_files",
help='Sets the path where the temp files are saved.')
parser.add_argument("-l", "--log", action="store_true", default=False, help='Enables logging.')
parser.add_argument("-sl", "--silent", action="store_false", default=False, help='Disables printing.')
parser.add_argument("-lf", "--log-file", default="results-compare.log",
help='Set the path and filename of the log file.')
parser.add_argument("-pad", "--padding", action="store", default=10, type=int,
help='Number of seconds to pad choice of place to test from.')
parser.add_argument("-sd", "--seed", action="store", default=None, type=int, help='Random seed.')
parser.add_argument("src", type=str, help='Source folder for audios to use as tests.')
args = parser.parse_args()
main(args.seconds, args.results_folder, args.temp_folder, args.log, args.silent, args.log_file, args.padding,
args.seed, args.src)

3
setup.cfg Normal file
View file

@ -0,0 +1,3 @@
[flake8]
max-line-length = 120

View file

@ -1,5 +1,4 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
# import os, sys
def parse_requirements(requirements): def parse_requirements(requirements):
@ -14,6 +13,7 @@ def parse_requirements(requirements):
reqs = list(filter((lambda x: x), nocomments)) reqs = list(filter((lambda x: x), nocomments))
return reqs return reqs
PACKAGE_NAME = "PyDejavu" PACKAGE_NAME = "PyDejavu"
PACKAGE_VERSION = "0.1.3" PACKAGE_VERSION = "0.1.3"
SUMMARY = 'Dejavu: Audio Fingerprinting in Python' SUMMARY = 'Dejavu: Audio Fingerprinting in Python'

BIN
test/sean_secs.wav Normal file

Binary file not shown.

BIN
test/woodward_43s.wav Normal file

Binary file not shown.