mirror of
https://github.com/correl/dejavu.git
synced 2024-11-27 11:09:51 +00:00
356 lines
11 KiB
Python
356 lines
11 KiB
Python
|
# result generator for dejavu
|
||
|
|
||
|
# TODO: Doesn't work well with songs whose names contain special characters.
|
||
|
# Use test files in the format below, with no special characters and exactly one "-" separating artist from song.
|
||
|
|
||
|
import os, subprocess, json, re, sys
|
||
|
import logging, time
|
||
|
from os import listdir
|
||
|
from os.path import isfile, join
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
import matplotlib.animation as animation
|
||
|
from optparse import OptionParser
|
||
|
|
||
|
#####
|
||
|
### Test files must be named in the following format:
|
||
|
### 'artist_name'-'song_name'_'start_time'_'duration'sec.wav
|
||
|
#####
|
||
|
|
||
|
# Audio parameters used to convert the fingerprint offset of a match back into
# seconds (see DejavuTest):
#   seconds = offset * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO / DEFAULT_FS
# NOTE(review): these must agree with the settings dejavu fingerprinted with --
# TODO confirm against the dejavu configuration in use.
DEFAULT_FS = 44100
DEFAULT_WINDOW_SIZE = 4096
DEFAULT_OVERLAP_RATIO = 0.5

# Keys read from the match dictionary printed by `dejavu recognize`.
FIELD_SONG_NAME = 'song_name'
FIELD_CONFIDENCE = 'confidence'
FIELD_QUERY_TIME = 'match_time'
FIELD_OFFSET = 'offset'
|
||
|
|
||
|
# Parse command-line options.
usage = "usage: %prog [options] DEJAVU_PATH TEST_FOLDER"
parser = OptionParser(usage=usage, version="%prog 1.1")

parser.add_option("--no-log",
                  action="store_false",
                  dest="log",
                  default=True,
                  help='Disables logging')
parser.add_option("--log-file",
                  dest="log_file",
                  default="results-compare.log",
                  metavar="LOG_FILE",
                  help='Set the path and filename of the log file')
# default=None instead of a [] literal: optparse's "append" action appends to
# the default object in place, so a list default would be shared/mutated.
parser.add_option("--test-seconds",
                  action="append",
                  dest="test_seconds",
                  default=None,
                  metavar="Xsec",
                  help='Appends seconds to test suite')
# Explicit "" default: without it the value is None when the option is omitted,
# and the trailing-slash normalisation below crashed with a TypeError.
parser.add_option("--results-folder",
                  action="store",
                  dest="results_folder",
                  default="",
                  metavar="FOLDER",
                  help='Sets the path where the results are saved')

(options, args) = parser.parse_args()

if len(args) != 2:
    parser.error("wrong number of arguments")

# No --test-seconds given: test every duration from 1 to 10 seconds.
if not options.test_seconds:
    options.test_seconds = ['1sec', '2sec', '3sec', '4sec', '5sec',
                            '6sec', '7sec', '8sec', '9sec', '10sec']

if options.log:
    logging.basicConfig(filename=options.log_file, level=logging.DEBUG)

# Figures are saved as "<results_folder><name>.png", so a non-empty folder
# must end with a path separator.
if options.results_folder and not options.results_folder.endswith('/'):
    options.results_folder += "/"
|
||
|
|
||
|
def log_msg(msg):
    """Write *msg* to the debug log unless logging was disabled (--no-log)."""
    if options.log:
        logging.debug(msg)
|
||
|
|
||
|
class DejavuTest(object):
    """Run `dejavu recognize` on every test clip and tabulate the outcomes.

    The whole test run happens in the constructor. Results are kept in four
    ``n_lines x n_columns`` tables indexed ``[line][col]``, where a line is one
    "artist - song" pair (see ``get_line_id``) and a column is one clip
    duration label from *seconds* (see ``get_column_id``):

    * ``result_match``            -- 'yes', 'no' or 'invalid'
    * ``result_matching_times``   -- reported start time minus the real start
                                     time, in seconds; a +/-1 difference is
                                     stored as 0 (counted as accurate)
    * ``result_query_duration``   -- dejavu's 'match_time', rounded to 3 places
    * ``result_match_confidence`` -- dejavu's 'confidence'

    Test files are expected to be named
    ``<artist>-<song>_<start_time>_<duration>sec.wav``.
    """

    # Clip duration tag, e.g. "5sec" (the last such tag in the name is used).
    _SECONDS_RE = re.compile(r"[0-9]+sec")
    # Artist: everything before the first "-".
    _ARTIST_RE = re.compile(r"^[^-]+")
    # Song: from the first "-" up to the first "_".
    _SONG_RE = re.compile(r"-[^_]+")
    # Start time: the first "_"-prefixed field.
    _START_RE = re.compile(r"_[^_]+")

    def __init__(self, folder, seconds):
        """
        :param folder: directory containing the test clips
        :param seconds: duration labels to test, e.g. ['1sec', '5sec']
        """
        super(DejavuTest, self).__init__()

        self.test_folder = folder
        self.test_seconds = seconds
        self.test_songs = []
        # Regular files tagged with one of the requested durations; files with
        # no "<N>sec" tag are skipped instead of raising IndexError.
        self.test_files = [f for f in listdir(self.test_folder)
                           if isfile(join(self.test_folder, f))
                           and self._duration_of(f) in self.test_seconds]
        self.n_columns = len(self.test_seconds)
        # One line per distinct song. Dividing the file count by the number of
        # durations was wrong whenever a song lacked some duration (and gives
        # a float under Python 3).
        self.n_lines = len(set(self._song_key(*self._parse_artist_song(f))
                               for f in self.test_files))

        # match outcome: 'yes', 'no' or 'invalid'
        self.result_match = self._new_table()
        # start-time error of a correct match (0 == accurate)
        self.result_matching_times = self._new_table()
        # matching (query) time reported by dejavu
        self.result_query_duration = self._new_table()
        # confidence reported by dejavu
        self.result_match_confidence = self._new_table()

        self.begin()

    def _new_table(self):
        """Return a fresh n_lines x n_columns table filled with zeros."""
        return [[0] * self.n_columns for _ in range(self.n_lines)]

    @classmethod
    def _duration_of(cls, name):
        """Return the last "<digits>sec" tag in *name*, or None if absent."""
        tags = cls._SECONDS_RE.findall(name)
        return tags[-1] if tags else None

    @classmethod
    def _parse_artist_song(cls, name):
        """Split "<artist>-<song>[_...]" into stripped (artist, song).

        Raises IndexError when *name* has no artist or no "-<song>" part.
        """
        artist = cls._ARTIST_RE.findall(name)[0].rstrip()
        song = cls._SONG_RE.findall(name)[0].lstrip("- ")
        return artist, song

    @staticmethod
    def _song_key(artist, song):
        """Build the "artist - song" key used to identify a line."""
        return artist + " - " + song

    @staticmethod
    def _parse_result(raw):
        """Parse the match dictionary printed by `dejavu recognize`.

        The output appears to be a Python dict repr with single-quoted
        strings, so quotes are first rewritten to JSON; if the rewrite is
        still not valid JSON, ast.literal_eval is applied to the raw text.
        """
        text = raw.strip()
        as_json = (text.replace(" \'", ' "')
                       .replace("{\'", '{"')
                       .replace("\':", '":')
                       .replace("\',", '",'))
        try:
            return json.loads(as_json)
        except ValueError:
            import ast
            return ast.literal_eval(text)

    def get_column_id(self, secs):
        """Return the column index of duration label *secs*, or None."""
        for i, sec in enumerate(self.test_seconds):
            if secs == sec:
                return i
        return None

    def get_line_id(self, artist, song):
        """Return the line index of (artist, song), registering it if new."""
        elem = self._song_key(artist, song)
        try:
            return self.test_songs.index(elem)
        except ValueError:
            self.test_songs.append(elem)
            return len(self.test_songs) - 1

    def _record(self, line, col, match, matching_time=0, duration=0, confidence=0):
        """Store one outcome in the four result tables at [line][col]."""
        self.result_match[line][col] = match
        self.result_matching_times[line][col] = matching_time
        self.result_query_duration[line][col] = duration
        self.result_match_confidence[line][col] = confidence

    def begin(self):
        """Recognize every test file with dejavu and record the outcome."""
        for f in self.test_files:
            log_msg('--------------------------------------------------')
            log_msg('file: %s' % f)

            col = self.get_column_id(self._duration_of(f))
            artist, song = self._parse_artist_song(f)
            line = self.get_line_id(artist, song)

            raw = subprocess.check_output([args[0], 'recognize', 'file',
                                           join(self.test_folder, f)])
            if not isinstance(raw, str):  # bytes under Python 3
                raw = raw.decode('utf-8', 'replace')
            log_msg('RESULT: %s' % raw.strip())

            if raw.strip() == "None":
                log_msg('No match')
                self._record(line, col, 'no')
            else:
                self._check_match(f, line, col, artist, song,
                                  self._parse_result(raw))
            log_msg('--------------------------------------------------\n')

    def _check_match(self, f, line, col, artist, song, result):
        """Record a dejavu match for file *f* as 'invalid' or 'yes'."""
        try:
            artist_result, song_result = self._parse_artist_song(result[FIELD_SONG_NAME])
        except IndexError:
            # Not in "artist - song" form, so it cannot be the expected song.
            artist_result = song_result = None

        log_msg('artist: %s' % artist)
        log_msg('artist_result: %s' % artist_result)
        log_msg('song: %s' % song)
        log_msg('song_result: %s' % song_result)

        if artist_result != artist or song_result != song:
            log_msg('invalid match')
            self._record(line, col, 'invalid')
            return

        log_msg('correct match')
        query_duration = round(result[FIELD_QUERY_TIME], 3)
        confidence = result[FIELD_CONFIDENCE]

        song_start_time = self._START_RE.findall(f)[0].lstrip("_ ")
        # Convert the fingerprint offset into seconds.
        result_start_time = round((result[FIELD_OFFSET] * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO)
                                  / float(DEFAULT_FS), 0)

        offset_error = int(result_start_time) - int(song_start_time)
        if abs(offset_error) == 1:
            # A one-second difference is counted as accurate.
            offset_error = 0

        self._record(line, col, 'yes', offset_error, query_duration, confidence)

        log_msg('query duration: %s' % query_duration)
        log_msg('confidence: %s' % confidence)
        log_msg('song start_time: %s' % song_start_time)
        log_msg('result start time: %s' % result_start_time)
        if offset_error == 0:
            log_msg('accurate match')
        else:
            log_msg('inaccurate match')
|
||
|
|
||
|
# Run the whole test suite and report how long it took. The elapsed time is
# measured once so stdout and the log report the same value.
print("obtaining results from dejavu")
log_msg('obtaining results from dejavu')
tm = time.time()
djv = DejavuTest(args[1], options.test_seconds)
elapsed = time.time() - tm
print("finished obtaining results from dejavu in %s" % elapsed)
log_msg("finished obtaining results from dejavu in %s" % elapsed)
|
||
|
|
||
|
# Regroup dejavu's per-song results by clip duration.
tests_n_lines = djv.n_lines
tests_n_columns = djv.n_columns
tests = 1  # size of the innermost axis; only index 0 (dejavu) is filled
n_secs = len(options.test_seconds)

# 3-level tables indexed [duration][category-or-song][system]:
#   all_match_counter[sec][k]          k: 0 'yes', 1 'no', 2 anything else
#   all_matching_times_counter[sec][k] k: 0 accurate 'yes', 1 non-zero time error
#   all_query_duration[sec][song] / all_match_confidence[sec][song]
all_match_counter = [[[0] * tests for _ in range(3)] for _ in range(n_secs)]
all_matching_times_counter = [[[0] * tests for _ in range(2)] for _ in range(n_secs)]
all_query_duration = [[[0] * tests for _ in range(tests_n_lines)] for _ in range(n_secs)]
all_match_confidence = [[[0] * tests for _ in range(tests_n_lines)] for _ in range(n_secs)]

for song_idx in range(tests_n_lines):
    for dur_idx in range(tests_n_columns):
        outcome = djv.result_match[song_idx][dur_idx]
        time_error = djv.result_matching_times[song_idx][dur_idx]

        all_query_duration[dur_idx][song_idx][0] = djv.result_query_duration[song_idx][dur_idx]
        all_match_confidence[dur_idx][song_idx][0] = djv.result_match_confidence[song_idx][dur_idx]

        all_match_counter[dur_idx][{'yes': 0, 'no': 1}.get(outcome, 2)][0] += 1

        if time_error != 0:
            all_matching_times_counter[dur_idx][1][0] += 1
        elif outcome == 'yes':
            all_matching_times_counter[dur_idx][0][0] += 1
|
||
|
|
||
|
def autolabel(rects, ax):
    """Annotate each bar in *rects* with its height, printed as an integer."""
    for bar in rects:
        h = bar.get_height()
        centre_x = bar.get_x() + bar.get_width() / 2.
        ax.text(centre_x, 1.05 * h, '%d' % int(h), ha='center', va='bottom')
|
||
|
|
||
|
def autolabeldoubles(rects, ax):
    """Annotate each bar in *rects* with its height rounded to 3 decimals."""
    for bar in rects:
        h = bar.get_height()
        centre_x = bar.get_x() + bar.get_width() / 2.
        caption = '%s' % round(float(h), 3)
        ax.text(centre_x, 1.05 * h, caption, ha='center', va='bottom')
|
||
|
|
||
|
def create_plots(name, results):
    """Save one bar chart per clip duration showing a per-song metric.

    :param name: metric name; used as the y label, in the title and in the
                 file name "<results_folder><name>_<duration>.png".
                 'Confidence' bars are labelled as integers, any other metric
                 with 3 decimals.
    :param results: nested list indexed [duration][song][system]; only
                    system 0 (dejavu) is plotted.
    """
    width = 0.25  # the width of the bars
    ind = np.arange(tests_n_lines)
    labels = ["song %s" % (i + 1) for i in range(tests_n_lines)]
    label_bars = autolabel if name == 'Confidence' else autolabeldoubles

    for sec in range(0, n_secs):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        # Leave room for every song's bar; the old fixed limit of 2*width
        # cut off all but the first one.
        ax.set_xlim([-1 * width, max(2 * width, tests_n_lines - width)])

        means_dvj = [x[0] for x in results[sec]]
        rects1 = ax.bar(ind, means_dvj, width, color='r')

        ax.set_ylabel(name)
        ax.set_title("%s %s Results" % (options.test_seconds[sec], name))
        ax.set_xticks(ind + width)
        ax.set_xticklabels(labels)

        # Narrow the axes, presumably to leave space for a legend on the right.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])

        label_bars(rects1, ax)
        plt.grid()

        fig_name = "%s%s_%s.png" % (options.results_folder, name, options.test_seconds[sec])
        fig.savefig(fig_name)
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
|
||
|
|
||
|
# Per-song bar charts for each metric, one figure per clip duration.
for _metric_name, _metric_values in (('Confidence', all_match_confidence),
                                     ('Query duration', all_query_duration)):
    create_plots(_metric_name, _metric_values)
|
||
|
|
||
|
# Matching percentage ('yes' / 'no' / 'invalid') per clip duration.
for sec in range(0, n_secs):
    ind = np.arange(3)
    width = 0.25  # the width of the bars

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_xlim([-1 * width, 2.75])

    # 100.0 forces float division (Python 2 truncated the integer quotient
    # before rounding); max(..., 1) avoids dividing by zero with no songs.
    means_dvj = [round(x[0] * 100.0 / max(tests_n_lines, 1), 1)
                 for x in all_match_counter[sec]]
    rects1 = ax.bar(ind, means_dvj, width, color='r')

    ax.set_ylabel('Matching Percentage')
    ax.set_title('%s Matching Percentage' % options.test_seconds[sec])
    ax.set_xticks(ind + width)
    ax.set_xticklabels(['yes', 'no', 'invalid'])

    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])

    autolabeldoubles(rects1, ax)
    plt.grid()

    fig_name = "%smatching_perc_%s.png" % (options.results_folder, options.test_seconds[sec])
    fig.savefig(fig_name)
    # Release the figure so they do not accumulate across iterations.
    plt.close(fig)
|
||
|
|
||
|
# Start-time accuracy per clip duration, as a percentage of correct matches.
for sec in range(0, n_secs):
    ind = np.arange(2)
    width = 0.25  # the width of the bars

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_xlim([-1 * width, 1.75])

    # Only correct ('yes') matches carry a start-time error, so they are the
    # denominator. With no correct match both counters are 0 and any non-zero
    # divisor gives 0%. 100.0 forces float division under Python 2.
    div = all_match_counter[sec][0][0] or 1
    means_dvj = [round(x[0] * 100.0 / div, 1) for x in all_matching_times_counter[sec]]
    rects1 = ax.bar(ind, means_dvj, width, color='r')

    ax.set_ylabel('Matching Accuracy')
    ax.set_title('%s Matching Times Accuracy' % options.test_seconds[sec])
    ax.set_xticks(ind + width)
    ax.set_xticklabels(['yes', 'no'])

    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])

    autolabeldoubles(rects1, ax)
    plt.grid()

    fig_name = "%smatching_acc_%s.png" % (options.results_folder, options.test_seconds[sec])
    fig.savefig(fig_name)
    # Release the figure so they do not accumulate across iterations.
    plt.close(fig)
|