mirror of
https://github.com/correl/dejavu.git
synced 2024-12-28 19:19:18 +00:00
Create results_dejavu.py
This commit is contained in:
parent
bfff19b887
commit
87469a0fe7
1 changed files with 355 additions and 0 deletions
355
scripts/results_dejavu.py
Normal file
355
scripts/results_dejavu.py
Normal file
|
@ -0,0 +1,355 @@
|
|||
# result generator for dejavu
|
||||
|
||||
# TODO: Don't work very well with musics with special chars.
|
||||
# use test file on the format below, with no special chars and only one "-" to separate artist from song
|
||||
|
||||
import os, subprocess, json, re, sys
|
||||
import logging, time
|
||||
from os import listdir
|
||||
from os.path import isfile, join
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
from optparse import OptionParser
|
||||
|
||||
#####
|
||||
### Test files are in specific format:
|
||||
### 'artist_name'-'song_name'_'start_time'_'duration'sec.wav
|
||||
#####
|
||||
|
||||
DEFAULT_FS = 44100
|
||||
DEFAULT_WINDOW_SIZE = 4096
|
||||
DEFAULT_OVERLAP_RATIO = 0.5
|
||||
|
||||
#FIELD_SONG_ID = 'song_id'
|
||||
FIELD_SONG_NAME = 'song_name'
|
||||
FIELD_CONFIDENCE = 'confidence'
|
||||
FIELD_QUERY_TIME = 'match_time'
|
||||
FIELD_OFFSET = 'offset'
|
||||
|
||||
# Parse options
|
||||
|
||||
usage = "usage: %prog [options] DEJAVU_PATH TEST_FOLDER"
|
||||
parser = OptionParser(usage=usage, version="%prog 1.1")
|
||||
|
||||
parser.add_option("--no-log",
|
||||
action="store_false",
|
||||
dest="log",
|
||||
default=True,
|
||||
help='Disables logging'
|
||||
)
|
||||
parser.add_option("--log-file",
|
||||
dest="log_file",
|
||||
default="results-compare.log",
|
||||
metavar="LOG_FILE",
|
||||
help='Set the path and filename of the log file'
|
||||
)
|
||||
parser.add_option("--test-seconds",
|
||||
action="append",
|
||||
dest="test_seconds",
|
||||
default=[],
|
||||
metavar="Xsec",
|
||||
help='Appends seconds to test suit'
|
||||
)
|
||||
parser.add_option("--results-folder",
|
||||
action="store",
|
||||
dest="results_folder",
|
||||
metavar="FOLDER",
|
||||
help='Sets the path where the results are saved'
|
||||
)
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
|
||||
if len(args) != 2:
|
||||
parser.error("wrong number of arguments")
|
||||
|
||||
if len(options.test_seconds) == 0:
|
||||
options.test_seconds = ['1sec','2sec','3sec','4sec','5sec','6sec','7sec','8sec','9sec','10sec']
|
||||
|
||||
if options.log == True:
|
||||
logging.basicConfig( filename=options.log_file, level=logging.DEBUG )
|
||||
|
||||
if options.results_folder != "" and options.results_folder[len(options.results_folder)-1] != '/':
|
||||
options.results_folder += "/"
|
||||
|
||||
def log_msg(msg):
|
||||
if options.log == True:
|
||||
logging.debug(msg)
|
||||
|
||||
class DejavuTest (object):
|
||||
def __init__(self, folder, seconds):
|
||||
super(DejavuTest, self).__init__()
|
||||
|
||||
self.test_folder = folder
|
||||
self.test_seconds = seconds
|
||||
self.test_songs = []
|
||||
self.test_files = [ f for f in listdir(self.test_folder) if isfile(join(self.test_folder,f))
|
||||
and re.findall("[0-9]*sec",f)[0] in self.test_seconds ]
|
||||
self.n_columns = len(self.test_seconds)
|
||||
self.n_lines = len(self.test_files) / self.n_columns
|
||||
|
||||
# variable match results (yes, no, invalid)
|
||||
self.result_match = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
|
||||
|
||||
# variable match precision (if matched in the corrected time)
|
||||
self.result_matching_times = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
|
||||
|
||||
# variable mahing time (query time)
|
||||
self.result_query_duration = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
|
||||
|
||||
# variable confidence
|
||||
self.result_match_confidence = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
|
||||
|
||||
self.begin()
|
||||
|
||||
def get_column_id ( self,secs ):
|
||||
for i, sec in enumerate(self.test_seconds):
|
||||
if secs == sec:
|
||||
return i
|
||||
|
||||
def get_line_id ( self,artist, song ):
|
||||
elem = artist + " - " + song
|
||||
|
||||
for i, s in enumerate(self.test_songs):
|
||||
if elem == s:
|
||||
return i
|
||||
|
||||
self.test_songs.append(elem)
|
||||
return len(self.test_songs)-1
|
||||
|
||||
def begin(self):
|
||||
for f in self.test_files:
|
||||
log_msg('--------------------------------------------------')
|
||||
log_msg('file: %s' % f)
|
||||
# get column
|
||||
col = self.get_column_id(re.findall("[0-9]*sec",f)[0])
|
||||
|
||||
# get artist and song
|
||||
artist = re.findall("^[^\-]+",f)
|
||||
artist = artist[0].rstrip()
|
||||
|
||||
song = re.findall("\-[^\_]+",f)
|
||||
song = song[0].lstrip("- ")
|
||||
|
||||
line = self.get_line_id ( artist, song)
|
||||
|
||||
result = subprocess.check_output([args[0], 'recognize', 'file', self.test_folder+"/"+f])
|
||||
log_msg('RESULT: %s' % result.strip() )
|
||||
|
||||
if result.strip() == "None":
|
||||
log_msg('No match')
|
||||
self.result_match[line][col] = 'no'
|
||||
self.result_matching_times[line][col] = 0
|
||||
self.result_query_duration[line][col] = 0
|
||||
self.result_match_confidence[line][col] = 0
|
||||
|
||||
else:
|
||||
result = result.strip()
|
||||
result = result.replace(" \'", ' "')
|
||||
result = result.replace("{\'", '{"')
|
||||
result = result.replace("\':", '":')
|
||||
result = result.replace("\',", '",')
|
||||
|
||||
result = json.loads(result)
|
||||
|
||||
artist_result = re.findall("^[^\-]+",result[FIELD_SONG_NAME])
|
||||
artist_result = artist_result[0].rstrip()
|
||||
|
||||
song_result = re.findall("\-[^\_]+",result[FIELD_SONG_NAME])
|
||||
song_result = song_result[0].lstrip("- ")
|
||||
|
||||
log_msg('artist: %s' % artist)
|
||||
log_msg('artist_result: %s' % artist_result)
|
||||
log_msg('song: %s' % song)
|
||||
log_msg('song_result: %s' % song_result)
|
||||
|
||||
if artist_result != artist or song_result != song:
|
||||
log_msg('invalid match')
|
||||
self.result_match[line][col] = 'invalid'
|
||||
self.result_matching_times[line][col] = 0
|
||||
self.result_query_duration[line][col] = 0
|
||||
self.result_match_confidence[line][col] = 0
|
||||
else:
|
||||
log_msg('correct match')
|
||||
self.result_match[line][col] = 'yes'
|
||||
self.result_query_duration[line][col] = round(result[FIELD_QUERY_TIME],3)
|
||||
self.result_match_confidence[line][col] = result[FIELD_CONFIDENCE]
|
||||
|
||||
song_start_time = re.findall("\_[^\_]+",f)
|
||||
song_start_time = song_start_time[0].lstrip("_ ")
|
||||
|
||||
#result_start_time = round((result[FIELD_SONG_DURATION] * result[FIELD_OFFSET]) / float(result[FIELD_SONG_SPEC_DURATION]), 0)
|
||||
result_start_time = round((result[FIELD_OFFSET] * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO) / (DEFAULT_FS),0)
|
||||
|
||||
self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time)
|
||||
if (abs(self.result_matching_times[line][col]) == 1):
|
||||
self.result_matching_times[line][col] = 0
|
||||
|
||||
log_msg('query duration: %s' % round(result[FIELD_QUERY_TIME],3))
|
||||
log_msg('confidence: %s' % result[FIELD_CONFIDENCE])
|
||||
log_msg('song start_time: %s' % song_start_time)
|
||||
log_msg('result start time: %s' % result_start_time)
|
||||
if (self.result_matching_times[line][col] == 0):
|
||||
log_msg('accurate match')
|
||||
else:
|
||||
log_msg('inaccurate match')
|
||||
log_msg('--------------------------------------------------\n')
|
||||
|
||||
print "obtaining results from dejavu"
|
||||
log_msg('obtaining results from dejavu')
|
||||
tm = time.time()
|
||||
djv = DejavuTest(args[1], options.test_seconds)
|
||||
print "finished obtaining results from dejavu in %s" % (time.time() - tm)
|
||||
log_msg("finished obtaining results from dejavu in %s" % (time.time() - tm))
|
||||
|
||||
tests_n_lines = djv.n_lines
|
||||
tests_n_columns = djv.n_columns # len(options.test_seconds)
|
||||
tests = 1 # djv
|
||||
n_secs = len(options.test_seconds) # = tests.n_columns
|
||||
|
||||
# set result variables -> 4d variables
|
||||
all_match_counter = [[[0 for x in xrange(tests)] for x in xrange(3)] for x in xrange(n_secs)]
|
||||
all_matching_times_counter = [[[0 for x in xrange(tests)] for x in xrange(2)] for x in xrange(n_secs)]
|
||||
all_query_duration = [[[0 for x in xrange(tests)] for x in xrange(tests_n_lines)] for x in xrange(n_secs)]
|
||||
all_match_confidence = [[[0 for x in xrange(tests)] for x in xrange(tests_n_lines)] for x in xrange(n_secs)]
|
||||
|
||||
# agroup results by seconds
|
||||
for line in range(0, tests_n_lines):
|
||||
for col in range(0, tests_n_columns):
|
||||
# for dejavu
|
||||
all_query_duration[col][line][0] = djv.result_query_duration[line][col]
|
||||
all_match_confidence[col][line][0] = djv.result_match_confidence[line][col]
|
||||
|
||||
djv_match_result = djv.result_match[line][col]
|
||||
|
||||
if djv_match_result == 'yes':
|
||||
all_match_counter[col][0][0] += 1
|
||||
elif djv_match_result == 'no':
|
||||
all_match_counter[col][1][0] += 1
|
||||
else:
|
||||
all_match_counter[col][2][0] += 1
|
||||
|
||||
djv_match_acc = djv.result_matching_times[line][col]
|
||||
|
||||
if djv_match_acc == 0 and djv_match_result == 'yes':
|
||||
all_matching_times_counter[col][0][0] += 1
|
||||
elif djv_match_acc != 0:
|
||||
all_matching_times_counter[col][1][0] += 1
|
||||
|
||||
def autolabel(rects,ax):
|
||||
# attach some text labels
|
||||
for rect in rects:
|
||||
height = rect.get_height()
|
||||
ax.text(rect.get_x()+rect.get_width()/2., 1.05*height, '%d'%int(height),
|
||||
ha='center', va='bottom')
|
||||
|
||||
def autolabeldoubles(rects,ax):
|
||||
# attach some text labels
|
||||
for rect in rects:
|
||||
height = rect.get_height()
|
||||
ax.text(rect.get_x()+rect.get_width()/2., 1.05*height, '%s'%round(float(height),3),
|
||||
ha='center', va='bottom')
|
||||
|
||||
def create_plots(name,results):
|
||||
for sec in range(0,n_secs):
|
||||
ind = np.arange(tests_n_lines) #
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1*width, 2*width])
|
||||
|
||||
means_dvj = [x[0] for x in results[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
|
||||
# add some
|
||||
ax.set_ylabel(name)
|
||||
ax.set_title("%s %s Results" % (options.test_seconds[sec],name))
|
||||
ax.set_xticks(ind+width)
|
||||
|
||||
labels = [0 for x in range(0,tests_n_lines)]
|
||||
for x in range(0,tests_n_lines):
|
||||
labels[x] = "song %s" % (x+1)
|
||||
ax.set_xticklabels( labels )
|
||||
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
|
||||
#ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
|
||||
if name == 'Confidence':
|
||||
autolabel(rects1,ax)
|
||||
else:
|
||||
autolabeldoubles(rects1,ax)
|
||||
|
||||
plt.grid()
|
||||
|
||||
fig_name = "%s%s_%s.png" % (options.results_folder,name,options.test_seconds[sec])
|
||||
fig.savefig(fig_name)
|
||||
|
||||
create_plots('Confidence',all_match_confidence)
|
||||
create_plots('Query duration',all_query_duration)
|
||||
|
||||
for sec in range(0,n_secs):
|
||||
ind = np.arange(3) #
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1*width, 2.75])
|
||||
|
||||
means_dvj = [round(x[0]*100/tests_n_lines,1) for x in all_match_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
|
||||
# add some
|
||||
ax.set_ylabel('Matching Percentage')
|
||||
ax.set_title('%s Matching Percentage' % options.test_seconds[sec])
|
||||
ax.set_xticks(ind+width)
|
||||
|
||||
labels = ['yes','no','invalid']
|
||||
ax.set_xticklabels( labels )
|
||||
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
|
||||
#ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
autolabeldoubles(rects1,ax)
|
||||
|
||||
plt.grid()
|
||||
|
||||
fig_name = "%smatching_perc_%s.png" % (options.results_folder,options.test_seconds[sec])
|
||||
fig.savefig(fig_name)
|
||||
|
||||
for sec in range(0,n_secs):
|
||||
ind = np.arange(2) #
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1*width, 1.75])
|
||||
|
||||
div = all_match_counter[sec][0][0]
|
||||
if div == 0 :
|
||||
div = 1000000
|
||||
|
||||
means_dvj = [round(x[0]*100/div,1) for x in all_matching_times_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
|
||||
# add some
|
||||
ax.set_ylabel('Matching Accuracy')
|
||||
ax.set_title('%s Matching Times Accuracy' % options.test_seconds[sec])
|
||||
ax.set_xticks(ind+width)
|
||||
|
||||
labels = ['yes','no']
|
||||
ax.set_xticklabels( labels )
|
||||
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
|
||||
#ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
autolabeldoubles(rects1,ax)
|
||||
|
||||
plt.grid()
|
||||
|
||||
fig_name = "%smatching_acc_%s.png" % (options.results_folder,options.test_seconds[sec])
|
||||
fig.savefig(fig_name)
|
Loading…
Reference in a new issue