import numpy as np
from scipy import io
from sklearn.metrics import roc_auc_score, average_precision_score
import pmf
def _read_stripped_lines(path):
    """Return every line of *path*, read in binary mode, with surrounding whitespace stripped.

    Because the file is opened with 'rb', entries are ``str`` on Python 2 and
    ``bytes`` on Python 3.
    """
    with open(path, 'rb') as fh:
        return [raw.strip() for raw in fh]


# Track identifiers for the train/test splits and the tag vocabulary,
# one entry per line of each text file.
train_tracks = _read_stripped_lines('train_tracks.txt')
test_tracks = _read_stripped_lines('test_tracks.txt')
tags = _read_stripped_lines('voc.txt')
# compute evaluation metrics
def construct_pred_mask(tags_predicted, predictat):
    """Mark the ``predictat`` highest-scoring tags of every sample.

    Parameters
    ----------
    tags_predicted : 2-D array, shape (n_samples, n_tags)
        Real-valued tag scores; larger means more likely.
    predictat : int
        Number of top-ranked tags to switch on per row. If it exceeds the
        number of tags, every tag is selected.

    Returns
    -------
    Boolean array of the same shape as ``tags_predicted``. It is True exactly
    at each row's ``predictat`` highest scores.
    """
    n_samples = tags_predicted.shape[0]
    # Column indices of the top-`predictat` scores per row (descending order).
    top = np.argsort(-tags_predicted, axis=1)[:, :predictat]
    mask = np.zeros(tags_predicted.shape, dtype=bool)
    # Advanced indexing: the (n_samples, 1) row index broadcasts against the
    # (n_samples, predictat) column indices, replacing the per-row loop.
    mask[np.arange(n_samples)[:, np.newaxis], top] = True
    return mask
def per_tag_prec_recall(tags_predicted_binary, tags_true_binary):
    """Compute per-tag precision and recall from boolean prediction/truth masks.

    Both arguments are (n_samples, n_tags) indicator arrays.

    Returns ``(prec, recall)``:

    * ``prec`` has one entry per tag: hits divided by the number of times the
      tag was predicted. A tiny epsilon is added to the denominator, so a
      never-predicted tag gets 0 instead of nan.
    * ``recall`` has one entry only for tags with at least one true positive
      sample. Tags absent from the ground truth are dropped, so ``recall`` may
      be shorter than ``prec``.
    """
    hits = np.logical_and(tags_predicted_binary, tags_true_binary).sum(axis=0)
    n_predicted = tags_predicted_binary.sum(axis=0)
    prec = hits / (n_predicted + np.spacing(1))

    n_true = tags_true_binary.sum(axis=0).astype(float)
    has_positives = n_true > 0
    recall = hits[has_positives] / n_true[has_positives]
    return prec, recall
def aroc_ap(tags_true_binary, tags_predicted):
    """Per-tag area under the ROC curve and average precision.

    Parameters
    ----------
    tags_true_binary : 2-D array, shape (n_samples, n_tags)
        Ground-truth tag indicators (nonzero = positive).
    tags_predicted : 2-D array, shape (n_samples, n_tags)
        Real-valued tag scores.

    Returns
    -------
    (auc, aprec) : two lists of equal length.
        One entry per tag that has both positive and negative samples. Such
        tags are the only ones for which ROC AUC is defined.
    """
    n_samples, n_tags = tags_true_binary.shape
    auc = list()
    aprec = list()
    for i in range(n_tags):
        y_true = tags_true_binary[:, i]
        n_pos = np.count_nonzero(y_true)
        # roc_auc_score raises ValueError when only one class is present, so
        # skip tags that are all-negative (as before) *or* all-positive.
        # Both metrics skip the same tags, keeping the lists aligned.
        if n_pos == 0 or n_pos == n_samples:
            continue
        y_score = tags_predicted[:, i]
        auc.append(roc_auc_score(y_true, y_score))
        aprec.append(average_precision_score(y_true, y_score))
    return auc, aprec
def print_out_metrics(tags_true_binary, tags_predicted, predictat):
    """Print tagging metrics for top-``predictat`` tag predictions.

    Prints the mean per-tag precision, recall, F-score (harmonic mean of the
    mean precision and mean recall), ROC AUC and average precision. Values in
    parentheses are standard errors across tags: std / sqrt(count).

    Parameters
    ----------
    tags_true_binary : 2-D boolean array, shape (n_samples, n_tags)
    tags_predicted : 2-D array of real-valued scores, same shape
    predictat : int, number of top-ranked tags treated as predicted per sample
    """
    tags_predicted_binary = construct_pred_mask(tags_predicted, predictat)
    prec, recall = per_tag_prec_recall(tags_predicted_binary, tags_true_binary)
    mprec, mrecall = np.mean(prec), np.mean(recall)
    # np.sqrt instead of the bare `sqrt`, which only existed under %pylab.
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print('Precision = %.3f (%.3f)' % (mprec, np.std(prec) / np.sqrt(prec.size)))
    print('Recall = %.3f (%.3f)' % (mrecall, np.std(recall) / np.sqrt(recall.size)))
    denom = mprec + mrecall
    # Guard 0/0: an F-score is conventionally 0 when both P and R are 0.
    fscore = 2 * mprec * mrecall / denom if denom > 0 else 0.0
    print('F-score = %.3f' % fscore)
    auc, aprec = aroc_ap(tags_true_binary, tags_predicted)
    print('AROC = %.3f (%.3f)' % (np.mean(auc), np.std(auc) / np.sqrt(len(auc))))
    print('AP = %.3f (%.3f)' % (np.mean(aprec), np.std(aprec) / np.sqrt(len(aprec))))
# Codebook size K (kept modest so the whole data matrix fits in memory).
K = 512
# Load the pre-saved training matrix and the held-out test data.
data_mat = io.loadmat('data_K%d.mat' % K)
X = data_mat['X']
X_test = data_mat['X_test']
y_test = data_mat['y_test']
# Columns K onward of X hold the tags (D = K + len(tags) below).
# Clip every positive tag entry to 1 and write the block back into X.
tmp = X[:, K:]
tmp[tmp > 0] = 1
X[:, K:] = tmp
# Distribution of the number of positive tags per test track.
# NOTE(review): `hist` is not imported in this file; presumably it comes from
# the notebook's pylab namespace (`%pylab inline`).
hist(np.sum((y_test > 0), axis=1), bins=50)
D = K + len(tags)
# Take a look at the first training song's feature vector.
# NOTE(review): `bar` presumably comes from the notebook's pylab namespace.
bar(np.arange(D), X[0])
# Out: <Container object of 1073 artists>
n_components = 100
coder = pmf.PoissonMF(n_components=n_components, random_state=98765, verbose=True)
coder.fit(X)
# Out (last iteration):
# After ITERATION: 39 Objective: 10746172.92 Old objective: 10741000.95 Improvement: 0.00048
# pmf.py:164: RuntimeWarning: invalid value encountered in double_scalars improvement = (bound - old_bd) / abs(old_bd)
# PoissonMF(max_iter=100, n_components=100, random_state=98765, smoothness=100, tol=0.0005, verbose=True)
# Build a tagger from only the first K (non-tag) columns of the learned
# components, then infer activations for the test songs from X_test.
tagger = pmf.PoissonMF(n_components=n_components, random_state=98765, verbose=True)
# Slice rho_b to the first K columns too, as the online-model cells do.
# NOTE(review): if the batch model stores rho_b as a single broadcast column,
# this slice is a no-op; if it is a full matrix, the slice is required.
tagger.set_components(coder.gamma_b[:, :K], coder.rho_b[:, :K])
# Out: PoissonMF(max_iter=100, n_components=100, random_state=98765, smoothness=100, tol=0.0005, verbose=True)
Et = tagger.transform(X_test)
# Out: After ITERATION: 18 Objective: 3308237.31 Old objective: 3306803.62 Improvement: 0.00043
# Normalize each test song's activations to sum to 1.
Et /= Et.sum(axis=1, keepdims=True)
# Tag scores: activations times the tag columns of the learned Eb.
tags_predicted = Et.dot(coder.Eb[:, K:])
print('%s %s' % (tags_predicted.min(), tags_predicted.max()))
# Out: 0.00094841825642 1.35247354242
# Subtract a multiple of each tag's mean score over test songs (presumably to
# down-weight tags that score high for every song).
div_factor = 3
tags_predicted = tags_predicted - div_factor * np.mean(tags_predicted, axis=0)
predictat = 20
tags_true_binary = (y_test > 0)
print_out_metrics(tags_true_binary, tags_predicted, predictat)
# Out: Precision = 0.111 (0.007) Recall = 0.106 (0.006) F-score = 0.108 AROC = 0.640 (0.005) AP = 0.097 (0.005)
# Stochastic (online) Poisson MF on the same training matrix: one pass over
# minibatches of 500 songs. est_total receives len(train_tracks), presumably
# the total number of training songs used to scale minibatch statistics.
n_components = 100
online_coder = pmf.OnlinePoissonMF(
    n_components=n_components,
    batch_size=500,
    n_pass=1,
    random_state=98765,
    verbose=True,
)
online_coder.fit(X, est_total=len(train_tracks))
Iteration 0: passing through the data... Minibatch 1: After ITERATION: 99 Objective: -478725.05 Old objective: -478998.31 Improvement: 0.00057 Minibatch 2: After ITERATION: 69 Objective: 83380.31 Old objective: 83338.70 Improvement: 0.00050 Minibatch 3: After ITERATION: 41 Objective: 217748.29 Old objective: 217640.87 Improvement: 0.00049 Minibatch 4: After ITERATION: 33 Objective: 316526.52 Old objective: 316374.09 Improvement: 0.00048 Minibatch 5: After ITERATION: 30 Objective: 322068.90 Old objective: 321908.73 Improvement: 0.00050 Minibatch 6: After ITERATION: 30 Objective: 311596.59 Old objective: 311442.51 Improvement: 0.00049 Minibatch 7: After ITERATION: 28 Objective: 335036.23 Old objective: 334869.13 Improvement: 0.00050 Minibatch 8: After ITERATION: 27 Objective: 335628.81 Old objective: 335470.08 Improvement: 0.00047 Minibatch 9: After ITERATION: 26 Objective: 345758.47 Old objective: 345589.10 Improvement: 0.00049 Minibatch 10: After ITERATION: 26 Objective: 336569.17 Old objective: 336411.84 Improvement: 0.00047 Minibatch 11: After ITERATION: 24 Objective: 401768.07 Old objective: 401568.16 Improvement: 0.00050 Minibatch 12: After ITERATION: 24 Objective: 418374.06 Old objective: 418176.19 Improvement: 0.00047 Minibatch 13: After ITERATION: 24 Objective: 391494.32 Old objective: 391309.49 Improvement: 0.00047 Minibatch 14: After ITERATION: 24 Objective: 398904.84 Old objective: 398724.56 Improvement: 0.00045 Minibatch 15: After ITERATION: 24 Objective: 389183.27 Old objective: 389004.84 Improvement: 0.00046 Minibatch 16: After ITERATION: 22 Objective: 446196.69 Old objective: 445983.59 Improvement: 0.00048 Minibatch 17: After ITERATION: 23 Objective: 420530.37 Old objective: 420340.72 Improvement: 0.00045 Minibatch 18: After ITERATION: 21 Objective: 475206.58 Old objective: 474976.67 Improvement: 0.00048 Minibatch 19: After ITERATION: 21 Objective: 449503.35 Old objective: 449283.21 Improvement: 0.00049 Minibatch 20: After ITERATION: 21 Objective: 
473768.83 Old objective: 473546.76 Improvement: 0.00047
OnlinePoissonMF(batch_size=500, max_iter=100, n_components=100, n_pass=1, random_state=98765, shuffle=True, smoothness=100, tol=0.0005, verbose=True)
# Objective recorded during the online fit (presumably one value per minibatch).
# NOTE(review): `plot` presumably comes from the notebook's pylab namespace.
plot(online_coder.bound)
tagger = pmf.PoissonMF(n_components=n_components, random_state=98765, verbose=True)
tagger.set_components(online_coder.gamma_b[:, :K], online_coder.rho_b[:, :K])
# Out: PoissonMF(max_iter=100, n_components=100, random_state=98765, smoothness=100, tol=0.0005, verbose=True)
Et = tagger.transform(X_test)
# Out: After ITERATION: 20 Objective: 3068245.32 Old objective: 3066785.25 Improvement: 0.00048
# Normalize each test song's activations to sum to 1.
Et /= Et.sum(axis=1, keepdims=True)
tags_predicted = Et.dot(online_coder.Eb[:, K:])
n_samples, n_tags = tags_predicted.shape
print('%s %s' % (tags_predicted.min(), tags_predicted.max()))
# Out: 5.46251938084e-05 1.09877408196
div_factor = 3
tags_predicted = tags_predicted - div_factor * np.mean(tags_predicted, axis=0)
predictat = 20
tags_true_binary = (y_test > 0)
print_out_metrics(tags_true_binary, tags_predicted, predictat)
# Out: Precision = 0.112 (0.007) Recall = 0.128 (0.007) F-score = 0.120 AROC = 0.684 (0.005) AP = 0.110 (0.006)
# Full training set: the K feature columns followed by the tag matrix.
data_mat = io.loadmat('X_train_K%d.mat' % K)
tag_mat = io.loadmat('y_train.mat')
y_train = tag_mat['y_train']
# Binarize the tags exactly as the tag columns of X were binarized in the
# first experiment (positive -> 1). This is a no-op if y_train is already 0/1.
# Casting back to y_train's dtype keeps hstack's result dtype unchanged.
X = np.hstack((data_mat['X'], (y_train > 0).astype(y_train.dtype)))
n_components = 100
batch_size = 1000
online_coder_full = pmf.OnlinePoissonMF(n_components=n_components, batch_size=batch_size, n_pass=1,
                                        random_state=98765, verbose=True)
online_coder_full.fit(X)
Iteration 0: passing through the data... Minibatch 1: After ITERATION: 99 Objective: -838968.48 Old objective: -839520.23 Improvement: 0.00066 Minibatch 2: After ITERATION: 55 Objective: 330451.59 Old objective: 330293.21 Improvement: 0.00048 Minibatch 3: After ITERATION: 38 Objective: 553125.94 Old objective: 552867.14 Improvement: 0.00047 Minibatch 4: After ITERATION: 33 Objective: 635871.16 Old objective: 635556.16 Improvement: 0.00050 Minibatch 5: After ITERATION: 30 Objective: 695878.24 Old objective: 695546.55 Improvement: 0.00048 Minibatch 6: After ITERATION: 28 Objective: 768466.09 Old objective: 768087.58 Improvement: 0.00049 Minibatch 7: After ITERATION: 28 Objective: 739469.91 Old objective: 739118.06 Improvement: 0.00048 Minibatch 8: After ITERATION: 26 Objective: 836343.78 Old objective: 835954.64 Improvement: 0.00047 Minibatch 9: After ITERATION: 26 Objective: 852619.10 Old objective: 852227.48 Improvement: 0.00046 Minibatch 10: After ITERATION: 25 Objective: 819482.01 Old objective: 819082.19 Improvement: 0.00049 Minibatch 11: After ITERATION: 24 Objective: 845208.45 Old objective: 844814.10 Improvement: 0.00047 Minibatch 12: After ITERATION: 24 Objective: 845044.81 Old objective: 844642.22 Improvement: 0.00048 Minibatch 13: After ITERATION: 24 Objective: 816059.30 Old objective: 815666.91 Improvement: 0.00048 Minibatch 14: After ITERATION: 24 Objective: 794065.94 Old objective: 793680.55 Improvement: 0.00049 Minibatch 15: After ITERATION: 22 Objective: 974278.44 Old objective: 973796.11 Improvement: 0.00050 Minibatch 16: After ITERATION: 23 Objective: 896549.23 Old objective: 896143.46 Improvement: 0.00045 Minibatch 17: After ITERATION: 22 Objective: 863382.90 Old objective: 862961.57 Improvement: 0.00049 Minibatch 18: After ITERATION: 23 Objective: 882706.00 Old objective: 882306.14 Improvement: 0.00045 Minibatch 19: After ITERATION: 22 Objective: 891975.51 Old objective: 891550.33 Improvement: 0.00048 Minibatch 20: After ITERATION: 21 Objective: 
924968.90 Old objective: 924517.78 Improvement: 0.00049 Minibatch 21: After ITERATION: 21 Objective: 948423.94 Old objective: 947972.97 Improvement: 0.00048 Minibatch 22: After ITERATION: 21 Objective: 855681.70 Old objective: 855254.13 Improvement: 0.00050 Minibatch 23: After ITERATION: 20 Objective: 979569.87 Old objective: 979086.64 Improvement: 0.00049 Minibatch 24: After ITERATION: 21 Objective: 959566.25 Old objective: 959114.87 Improvement: 0.00047 Minibatch 25: After ITERATION: 20 Objective: 966727.63 Old objective: 966257.28 Improvement: 0.00049 Minibatch 26: After ITERATION: 21 Objective: 924176.66 Old objective: 923751.42 Improvement: 0.00046 Minibatch 27: After ITERATION: 20 Objective: 955808.82 Old objective: 955350.94 Improvement: 0.00048 Minibatch 28: After ITERATION: 20 Objective: 933417.55 Old objective: 932984.50 Improvement: 0.00046 Minibatch 29: After ITERATION: 20 Objective: 934965.92 Old objective: 934510.41 Improvement: 0.00049 Minibatch 30: After ITERATION: 20 Objective: 923400.92 Old objective: 922950.94 Improvement: 0.00049 Minibatch 31: After ITERATION: 19 Objective: 1011637.22 Old objective: 1011153.42 Improvement: 0.00048 Minibatch 32: After ITERATION: 20 Objective: 919927.35 Old objective: 919490.85 Improvement: 0.00047 Minibatch 33: After ITERATION: 19 Objective: 983848.50 Old objective: 983359.24 Improvement: 0.00050 Minibatch 34: After ITERATION: 19 Objective: 1011851.79 Old objective: 1011374.01 Improvement: 0.00047 Minibatch 35: After ITERATION: 19 Objective: 943447.83 Old objective: 942977.35 Improvement: 0.00050 Minibatch 36: After ITERATION: 19 Objective: 950197.67 Old objective: 949724.08 Improvement: 0.00050 Minibatch 37: After ITERATION: 19 Objective: 992837.84 Old objective: 992395.29 Improvement: 0.00045 Minibatch 38: After ITERATION: 19 Objective: 981298.24 Old objective: 980827.19 Improvement: 0.00048 Minibatch 39: After ITERATION: 19 Objective: 994155.85 Old objective: 993681.70 Improvement: 0.00048 Minibatch 40: After 
ITERATION: 19 Objective: 944250.43 Old objective: 943814.25 Improvement: 0.00046 Minibatch 41: After ITERATION: 19 Objective: 952229.46 Old objective: 951816.84 Improvement: 0.00043 Minibatch 42: After ITERATION: 18 Objective: 991395.34 Old objective: 990905.65 Improvement: 0.00049 Minibatch 43: After ITERATION: 19 Objective: 979112.73 Old objective: 978678.59 Improvement: 0.00044 Minibatch 44: After ITERATION: 18 Objective: 1021684.05 Old objective: 1021176.08 Improvement: 0.00050 Minibatch 45: After ITERATION: 19 Objective: 964037.28 Old objective: 963590.83 Improvement: 0.00046 Minibatch 46: After ITERATION: 19 Objective: 996080.71 Old objective: 995649.74 Improvement: 0.00043 Minibatch 47: After ITERATION: 18 Objective: 1007201.75 Old objective: 1006727.54 Improvement: 0.00047 Minibatch 48: After ITERATION: 19 Objective: 963821.48 Old objective: 963378.24 Improvement: 0.00046 Minibatch 49: After ITERATION: 18 Objective: 977100.85 Old objective: 976616.29 Improvement: 0.00050 Minibatch 50: After ITERATION: 18 Objective: 1037832.34 Old objective: 1037336.13 Improvement: 0.00048 Minibatch 51: After ITERATION: 18 Objective: 1019284.29 Old objective: 1018816.05 Improvement: 0.00046 Minibatch 52: After ITERATION: 19 Objective: 957454.85 Old objective: 957031.01 Improvement: 0.00044 Minibatch 53: After ITERATION: 19 Objective: 932207.71 Old objective: 931787.92 Improvement: 0.00045 Minibatch 54: After ITERATION: 18 Objective: 1043985.60 Old objective: 1043509.42 Improvement: 0.00046 Minibatch 55: After ITERATION: 17 Objective: 1054681.95 Old objective: 1054169.84 Improvement: 0.00049 Minibatch 56: After ITERATION: 18 Objective: 947185.90 Old objective: 946717.33 Improvement: 0.00049 Minibatch 57: After ITERATION: 18 Objective: 962078.55 Old objective: 961625.79 Improvement: 0.00047 Minibatch 58: After ITERATION: 18 Objective: 954449.56 Old objective: 953988.60 Improvement: 0.00048 Minibatch 59: After ITERATION: 18 Objective: 1033015.34 Old objective: 1032560.56 
Improvement: 0.00044 Minibatch 60: After ITERATION: 18 Objective: 950030.86 Old objective: 949565.39 Improvement: 0.00049 Minibatch 61: After ITERATION: 18 Objective: 1006097.61 Old objective: 1005648.49 Improvement: 0.00045 Minibatch 62: After ITERATION: 18 Objective: 1010524.60 Old objective: 1010063.96 Improvement: 0.00046 Minibatch 63: After ITERATION: 18 Objective: 960181.94 Old objective: 959740.15 Improvement: 0.00046 Minibatch 64: After ITERATION: 18 Objective: 1005279.89 Old objective: 1004814.24 Improvement: 0.00046 Minibatch 65: After ITERATION: 18 Objective: 1006458.92 Old objective: 1006024.37 Improvement: 0.00043 Minibatch 66: After ITERATION: 18 Objective: 1025903.07 Old objective: 1025452.39 Improvement: 0.00044 Minibatch 67: After ITERATION: 18 Objective: 984433.06 Old objective: 983998.63 Improvement: 0.00044 Minibatch 68: After ITERATION: 17 Objective: 1024722.72 Old objective: 1024232.46 Improvement: 0.00048 Minibatch 69: After ITERATION: 17 Objective: 1049168.83 Old objective: 1048676.28 Improvement: 0.00047 Minibatch 70: After ITERATION: 18 Objective: 939387.34 Old objective: 938955.50 Improvement: 0.00046 Minibatch 71: After ITERATION: 18 Objective: 1017890.59 Old objective: 1017460.27 Improvement: 0.00042 Minibatch 72: After ITERATION: 17 Objective: 1083556.84 Old objective: 1083076.20 Improvement: 0.00044 Minibatch 73: After ITERATION: 17 Objective: 1003922.28 Old objective: 1003433.21 Improvement: 0.00049 Minibatch 74: After ITERATION: 17 Objective: 1043622.76 Old objective: 1043121.90 Improvement: 0.00048 Minibatch 75: After ITERATION: 18 Objective: 974777.54 Old objective: 974338.72 Improvement: 0.00045 Minibatch 76: After ITERATION: 17 Objective: 1018929.83 Old objective: 1018472.21 Improvement: 0.00045 Minibatch 77: After ITERATION: 17 Objective: 977949.65 Old objective: 977465.20 Improvement: 0.00050 Minibatch 78: After ITERATION: 17 Objective: 1066141.66 Old objective: 1065660.94 Improvement: 0.00045 Minibatch 79: After ITERATION: 18 
Objective: 969007.97 Old objective: 968573.76 Improvement: 0.00045 Minibatch 80: After ITERATION: 18 Objective: 990091.70 Old objective: 989666.39 Improvement: 0.00043 Minibatch 81: After ITERATION: 17 Objective: 1012554.69 Old objective: 1012060.69 Improvement: 0.00049 Minibatch 82: After ITERATION: 17 Objective: 1075548.13 Old objective: 1075070.57 Improvement: 0.00044 Minibatch 83: After ITERATION: 17 Objective: 1013357.05 Old objective: 1012917.77 Improvement: 0.00043 Minibatch 84: After ITERATION: 17 Objective: 951811.85 Old objective: 951341.39 Improvement: 0.00049 Minibatch 85: After ITERATION: 17 Objective: 1039373.30 Old objective: 1038916.34 Improvement: 0.00044 Minibatch 86: After ITERATION: 17 Objective: 1041026.72 Old objective: 1040543.29 Improvement: 0.00046 Minibatch 87: After ITERATION: 17 Objective: 981806.83 Old objective: 981337.39 Improvement: 0.00048 Minibatch 88: After ITERATION: 17 Objective: 984409.75 Old objective: 983922.09 Improvement: 0.00050 Minibatch 89: After ITERATION: 17 Objective: 974591.89 Old objective: 974136.58 Improvement: 0.00047 Minibatch 90: After ITERATION: 17 Objective: 1042582.16 Old objective: 1042111.95 Improvement: 0.00045 Minibatch 91: After ITERATION: 17 Objective: 1072026.12 Old objective: 1071567.77 Improvement: 0.00043 Minibatch 92: After ITERATION: 16 Objective: 1040577.82 Old objective: 1040063.57 Improvement: 0.00049 Minibatch 93: After ITERATION: 17 Objective: 974010.32 Old objective: 973526.75 Improvement: 0.00050 Minibatch 94: After ITERATION: 17 Objective: 1070298.98 Old objective: 1069809.17 Improvement: 0.00046 Minibatch 95: After ITERATION: 17 Objective: 919223.29 Old objective: 918769.84 Improvement: 0.00049 Minibatch 96: After ITERATION: 17 Objective: 1029070.82 Old objective: 1028594.58 Improvement: 0.00046 Minibatch 97: After ITERATION: 17 Objective: 1017395.00 Old objective: 1016947.57 Improvement: 0.00044 Minibatch 98: After ITERATION: 17 Objective: 1017731.10 Old objective: 1017278.11 
Improvement: 0.00045 Minibatch 99: After ITERATION: 17 Objective: 972302.05 Old objective: 971851.16 Improvement: 0.00046 Minibatch 100: After ITERATION: 17 Objective: 1032231.04 Old objective: 1031776.26 Improvement: 0.00044 Minibatch 101: After ITERATION: 17 Objective: 992875.16 Old objective: 992428.54 Improvement: 0.00045 Minibatch 102: After ITERATION: 16 Objective: 1071424.25 Old objective: 1070903.82 Improvement: 0.00049 Minibatch 103: After ITERATION: 16 Objective: 1068643.98 Old objective: 1068111.96 Improvement: 0.00050 Minibatch 104: After ITERATION: 16 Objective: 1076202.58 Old objective: 1075678.17 Improvement: 0.00049 Minibatch 105: After ITERATION: 17 Objective: 1052727.79 Old objective: 1052275.29 Improvement: 0.00043 Minibatch 106: After ITERATION: 16 Objective: 1077327.28 Old objective: 1076794.34 Improvement: 0.00049 Minibatch 107: After ITERATION: 17 Objective: 1035360.77 Old objective: 1034917.38 Improvement: 0.00043 Minibatch 108: After ITERATION: 16 Objective: 1051049.96 Old objective: 1050535.16 Improvement: 0.00049 Minibatch 109: After ITERATION: 16 Objective: 1089714.15 Old objective: 1089189.53 Improvement: 0.00048 Minibatch 110: After ITERATION: 16 Objective: 1056248.04 Old objective: 1055756.68 Improvement: 0.00047 Minibatch 111: After ITERATION: 16 Objective: 1120971.59 Old objective: 1120436.46 Improvement: 0.00048 Minibatch 112: After ITERATION: 16 Objective: 1063079.05 Old objective: 1062564.31 Improvement: 0.00048 Minibatch 113: After ITERATION: 17 Objective: 988163.42 Old objective: 987743.95 Improvement: 0.00042 Minibatch 114: After ITERATION: 17 Objective: 1013004.93 Old objective: 1012561.58 Improvement: 0.00044 Minibatch 115: After ITERATION: 16 Objective: 1082110.48 Old objective: 1081606.30 Improvement: 0.00047 Minibatch 116: After ITERATION: 17 Objective: 1005101.06 Old objective: 1004662.91 Improvement: 0.00044 Minibatch 117: After ITERATION: 17 Objective: 1003218.58 Old objective: 1002777.58 Improvement: 0.00044 Minibatch 
118: After ITERATION: 16 Objective: 1115720.47 Old objective: 1115202.88 Improvement: 0.00046 Minibatch 119: After ITERATION: 16 Objective: 1051355.74 Old objective: 1050854.45 Improvement: 0.00048 Minibatch 120: After ITERATION: 16 Objective: 1086837.69 Old objective: 1086318.34 Improvement: 0.00048 Minibatch 121: After ITERATION: 16 Objective: 1061516.71 Old objective: 1061009.79 Improvement: 0.00048 Minibatch 122: After ITERATION: 17 Objective: 988589.83 Old objective: 988164.22 Improvement: 0.00043 Minibatch 123: After ITERATION: 16 Objective: 1025563.35 Old objective: 1025072.77 Improvement: 0.00048 Minibatch 124: After ITERATION: 17 Objective: 1067431.06 Old objective: 1066969.37 Improvement: 0.00043 Minibatch 125: After ITERATION: 16 Objective: 1050025.33 Old objective: 1049537.93 Improvement: 0.00046 Minibatch 126: After ITERATION: 16 Objective: 1019439.84 Old objective: 1018956.99 Improvement: 0.00047 Minibatch 127: After ITERATION: 16 Objective: 1045361.87 Old objective: 1044845.38 Improvement: 0.00049 Minibatch 128: After ITERATION: 16 Objective: 1056190.13 Old objective: 1055689.20 Improvement: 0.00047 Minibatch 129: After ITERATION: 16 Objective: 1102463.37 Old objective: 1101964.14 Improvement: 0.00045 Minibatch 130: After ITERATION: 16 Objective: 1051361.05 Old objective: 1050865.13 Improvement: 0.00047 Minibatch 131: After ITERATION: 16 Objective: 1060032.77 Old objective: 1059545.76 Improvement: 0.00046 Minibatch 132: After ITERATION: 16 Objective: 1043248.71 Old objective: 1042758.84 Improvement: 0.00047 Minibatch 133: After ITERATION: 16 Objective: 1012186.66 Old objective: 1011731.13 Improvement: 0.00045 Minibatch 134: After ITERATION: 16 Objective: 1023425.71 Old objective: 1022977.10 Improvement: 0.00044 Minibatch 135: After ITERATION: 17 Objective: 1012973.09 Old objective: 1012536.65 Improvement: 0.00043 Minibatch 136: After ITERATION: 17 Objective: 933590.39 Old objective: 933189.01 Improvement: 0.00043 Minibatch 137: After ITERATION: 16 
Objective: 1060170.38 Old objective: 1059658.45 Improvement: 0.00048 Minibatch 138: After ITERATION: 17 Objective: 1000358.72 Old objective: 999924.72 Improvement: 0.00043 Minibatch 139: After ITERATION: 17 Objective: 1010732.43 Old objective: 1010312.39 Improvement: 0.00042 Minibatch 140: After ITERATION: 16 Objective: 1025983.25 Old objective: 1025508.92 Improvement: 0.00046 Minibatch 141: After ITERATION: 16 Objective: 1090344.03 Old objective: 1089878.39 Improvement: 0.00043 Minibatch 142: After ITERATION: 16 Objective: 1041233.29 Old objective: 1040722.70 Improvement: 0.00049 Minibatch 143: After ITERATION: 16 Objective: 1020799.80 Old objective: 1020322.90 Improvement: 0.00047 Minibatch 144: After ITERATION: 16 Objective: 1093435.88 Old objective: 1092948.65 Improvement: 0.00045 Minibatch 145: After ITERATION: 16 Objective: 1012263.74 Old objective: 1011792.12 Improvement: 0.00047 Minibatch 146: After ITERATION: 16 Objective: 1055632.66 Old objective: 1055154.68 Improvement: 0.00045 Minibatch 147: After ITERATION: 15 Objective: 1147956.81 Old objective: 1147393.11 Improvement: 0.00049 Minibatch 148: After ITERATION: 16 Objective: 1092930.99 Old objective: 1092449.68 Improvement: 0.00044 Minibatch 149: After ITERATION: 16 Objective: 1055100.46 Old objective: 1054621.33 Improvement: 0.00045 Minibatch 150: After ITERATION: 16 Objective: 1065273.87 Old objective: 1064767.40 Improvement: 0.00048 Minibatch 151: After ITERATION: 16 Objective: 1066387.52 Old objective: 1065905.20 Improvement: 0.00045 Minibatch 152: After ITERATION: 16 Objective: 1009737.48 Old objective: 1009272.22 Improvement: 0.00046 Minibatch 153: After ITERATION: 16 Objective: 1096544.91 Old objective: 1096056.93 Improvement: 0.00045 Minibatch 154: After ITERATION: 16 Objective: 967781.76 Old objective: 967302.45 Improvement: 0.00050 Minibatch 155: After ITERATION: 16 Objective: 1123083.12 Old objective: 1122600.07 Improvement: 0.00043 Minibatch 156: After ITERATION: 15 Objective: 1077953.25 Old 
objective: 1077473.43 Improvement: 0.00045 Minibatch 289: After ITERATION: 15 Objective: 1089217.47 Old objective: 1088718.46 Improvement: 0.00046 Minibatch 290: After ITERATION: 14 Objective: 1165825.84 Old objective: 1165249.45 Improvement: 0.00049 Minibatch 291: After ITERATION: 15 Objective: 1106836.86 Old objective: 1106340.41 Improvement: 0.00045 Minibatch 292: After ITERATION: 15 Objective: 1032560.01 Old objective: 1032077.71 Improvement: 0.00047 Minibatch 293: After ITERATION: 15 Objective: 1027049.26 Old objective: 1026549.71 Improvement: 0.00049 Minibatch 294: After ITERATION: 14 Objective: 1228125.42 Old objective: 1227538.27 Improvement: 0.00048 Minibatch 295: After ITERATION: 15 Objective: 1097882.46 Old objective: 1097374.40 Improvement: 0.00046 Minibatch 296: After ITERATION: 15 Objective: 1056043.15 Old objective: 1055546.46 Improvement: 0.00047 Minibatch 297: After ITERATION: 15 Objective: 1076393.49 Old objective: 1075893.55 Improvement: 0.00046 Minibatch 298: After ITERATION: 15 Objective: 1093885.09 Old objective: 1093368.91 Improvement: 0.00047 Minibatch 299: After ITERATION: 15 Objective: 1017895.30 Old objective: 1017392.30 Improvement: 0.00049 Minibatch 300: After ITERATION: 15 Objective: 1093496.13 Old objective: 1093015.35 Improvement: 0.00044 Minibatch 301: After ITERATION: 15 Objective: 1038203.66 Old objective: 1037734.18 Improvement: 0.00045 Minibatch 302: After ITERATION: 15 Objective: 1091164.84 Old objective: 1090694.54 Improvement: 0.00043 Minibatch 303: After ITERATION: 15 Objective: 1099725.70 Old objective: 1099207.36 Improvement: 0.00047 Minibatch 304: After ITERATION: 15 Objective: 1177362.43 Old objective: 1176845.71 Improvement: 0.00044 Minibatch 305: After ITERATION: 15 Objective: 1139684.47 Old objective: 1139149.42 Improvement: 0.00047 Minibatch 306: After ITERATION: 15 Objective: 1070106.28 Old objective: 1069603.90 Improvement: 0.00047 Minibatch 307: After ITERATION: 15 Objective: 1127686.79 Old objective: 1127213.48 
Improvement: 0.00042 Minibatch 308: After ITERATION: 15 Objective: 1107884.48 Old objective: 1107392.22 Improvement: 0.00044 Minibatch 309: After ITERATION: 15 Objective: 1094298.45 Old objective: 1093776.53 Improvement: 0.00048 Minibatch 310: After ITERATION: 15 Objective: 1063033.31 Old objective: 1062522.29 Improvement: 0.00048 Minibatch 311: After ITERATION: 15 Objective: 1105122.11 Old objective: 1104636.89 Improvement: 0.00044 Minibatch 312: After ITERATION: 15 Objective: 1142106.67 Old objective: 1141616.15 Improvement: 0.00043 Minibatch 313: After ITERATION: 15 Objective: 1108809.55 Old objective: 1108329.21 Improvement: 0.00043 Minibatch 314: After ITERATION: 15 Objective: 1069988.46 Old objective: 1069468.85 Improvement: 0.00049 Minibatch 315: After ITERATION: 15 Objective: 1097799.84 Old objective: 1097298.06 Improvement: 0.00046 Minibatch 316: After ITERATION: 15 Objective: 1162605.14 Old objective: 1162114.87 Improvement: 0.00042 Minibatch 317: After ITERATION: 15 Objective: 1109930.94 Old objective: 1109448.15 Improvement: 0.00044 Minibatch 318: After ITERATION: 15 Objective: 1038812.48 Old objective: 1038346.80 Improvement: 0.00045 Minibatch 319: After ITERATION: 15 Objective: 1067610.58 Old objective: 1067128.58 Improvement: 0.00045 Minibatch 320: After ITERATION: 15 Objective: 1069347.33 Old objective: 1068856.14 Improvement: 0.00046 Minibatch 321: After ITERATION: 15 Objective: 1094636.35 Old objective: 1094137.08 Improvement: 0.00046 Minibatch 322: After ITERATION: 15 Objective: 1096606.83 Old objective: 1096113.41 Improvement: 0.00045 Minibatch 323: After ITERATION: 15 Objective: 1038113.38 Old objective: 1037628.07 Improvement: 0.00047 Minibatch 324: After ITERATION: 14 Objective: 1173746.46 Old objective: 1173168.28 Improvement: 0.00049 Minibatch 325: After ITERATION: 15 Objective: 1104455.22 Old objective: 1103955.33 Improvement: 0.00045 Minibatch 326: After ITERATION: 15 Objective: 1031683.65 Old objective: 1031186.31 Improvement: 0.00048 
Minibatch 327: After ITERATION: 15 Objective: 1081303.79 Old objective: 1080816.41 Improvement: 0.00045 Minibatch 328: After ITERATION: 15 Objective: 1113955.30 Old objective: 1113462.86 Improvement: 0.00044 Minibatch 329: After ITERATION: 15 Objective: 1113190.81 Old objective: 1112713.83 Improvement: 0.00043 Minibatch 330: After ITERATION: 15 Objective: 1093214.67 Old objective: 1092716.61 Improvement: 0.00046 Minibatch 331: After ITERATION: 15 Objective: 1037706.20 Old objective: 1037235.50 Improvement: 0.00045 Minibatch 332: After ITERATION: 15 Objective: 1023487.36 Old objective: 1022989.60 Improvement: 0.00049 Minibatch 333: After ITERATION: 15 Objective: 1013706.55 Old objective: 1013220.04 Improvement: 0.00048 Minibatch 334: After ITERATION: 15 Objective: 1048025.35 Old objective: 1047560.67 Improvement: 0.00044 Minibatch 335: After ITERATION: 15 Objective: 1026867.44 Old objective: 1026385.18 Improvement: 0.00047 Minibatch 336: After ITERATION: 15 Objective: 1039473.65 Old objective: 1038973.57 Improvement: 0.00048 Minibatch 337: After ITERATION: 15 Objective: 1073516.73 Old objective: 1073039.06 Improvement: 0.00045 Minibatch 338: After ITERATION: 15 Objective: 1131881.75 Old objective: 1131397.01 Improvement: 0.00043 Minibatch 339: After ITERATION: 15 Objective: 1081600.08 Old objective: 1081118.21 Improvement: 0.00045 Minibatch 340: After ITERATION: 15 Objective: 1060541.36 Old objective: 1060060.66 Improvement: 0.00045 Minibatch 341: After ITERATION: 15 Objective: 1085819.39 Old objective: 1085344.10 Improvement: 0.00044 Minibatch 342: After ITERATION: 15 Objective: 1120657.77 Old objective: 1120163.03 Improvement: 0.00044 Minibatch 343: After ITERATION: 15 Objective: 1086096.47 Old objective: 1085602.53 Improvement: 0.00045 Minibatch 344: After ITERATION: 15 Objective: 1016054.40 Old objective: 1015585.74 Improvement: 0.00046 Minibatch 345: After ITERATION: 15 Objective: 1024968.83 Old objective: 1024483.74 Improvement: 0.00047 Minibatch 346: After 
ITERATION: 15 Objective: 1024766.33 Old objective: 1024270.37 Improvement: 0.00048 Minibatch 347: After ITERATION: 15 Objective: 1071465.65 Old objective: 1070975.24 Improvement: 0.00046 Minibatch 348: After ITERATION: 15 Objective: 1085319.86 Old objective: 1084823.87 Improvement: 0.00046 Minibatch 349: After ITERATION: 15 Objective: 1069978.64 Old objective: 1069490.76 Improvement: 0.00046 Minibatch 350: After ITERATION: 15 Objective: 1125202.56 Old objective: 1124713.97 Improvement: 0.00043 Minibatch 351: After ITERATION: 15 Objective: 1109112.28 Old objective: 1108605.04 Improvement: 0.00046 Minibatch 352: After ITERATION: 15 Objective: 1080803.82 Old objective: 1080348.42 Improvement: 0.00042 Minibatch 353: After ITERATION: 15 Objective: 1112990.55 Old objective: 1112501.74 Improvement: 0.00044 Minibatch 354: After ITERATION: 15 Objective: 1136617.48 Old objective: 1136130.21 Improvement: 0.00043 Minibatch 355: After ITERATION: 15 Objective: 1130881.53 Old objective: 1130410.07 Improvement: 0.00042 Minibatch 356: After ITERATION: 15 Objective: 1079733.41 Old objective: 1079255.95 Improvement: 0.00044 Minibatch 357: After ITERATION: 15 Objective: 1079419.66 Old objective: 1078932.18 Improvement: 0.00045 Minibatch 358: After ITERATION: 15 Objective: 1092988.90 Old objective: 1092499.83 Improvement: 0.00045 Minibatch 359: After ITERATION: 15 Objective: 1088301.20 Old objective: 1087839.03 Improvement: 0.00042 Minibatch 360: After ITERATION: 15 Objective: 1116411.89 Old objective: 1115935.17 Improvement: 0.00043 Minibatch 361: After ITERATION: 15 Objective: 1120961.60 Old objective: 1120464.83 Improvement: 0.00044 Minibatch 362: After ITERATION: 15 Objective: 1101373.62 Old objective: 1100895.62 Improvement: 0.00043 Minibatch 363: After ITERATION: 15 Objective: 1052165.65 Old objective: 1051686.17 Improvement: 0.00046 Minibatch 364: After ITERATION: 15 Objective: 1045082.61 Old objective: 1044591.19 Improvement: 0.00047 Minibatch 365: After ITERATION: 14 
Objective: 1106699.89 Old objective: 1106155.34 Improvement: 0.00049 Minibatch 366: After ITERATION: 15 Objective: 1110061.80 Old objective: 1109568.83 Improvement: 0.00044 Minibatch 367: After ITERATION: 15 Objective: 1049802.42 Old objective: 1049321.96 Improvement: 0.00046 Minibatch 368: After ITERATION: 15 Objective: 1081693.68 Old objective: 1081217.30 Improvement: 0.00044 Minibatch 369: After ITERATION: 15 Objective: 1074035.69 Old objective: 1073549.99 Improvement: 0.00045 Minibatch 370: After ITERATION: 14 Objective: 1188333.55 Old objective: 1187756.43 Improvement: 0.00049 Minibatch 371: After ITERATION: 14 Objective: 1190537.93 Old objective: 1189976.66 Improvement: 0.00047 Minibatch 372: After ITERATION: 14 Objective: 265251.94 Old objective: 265135.87 Improvement: 0.00044
OnlinePoissonMF(batch_size=1000, max_iter=100, n_components=100, n_pass=1, random_state=98765, shuffle=True, smoothness=100, tol=0.0005, verbose=True)
# The last minibatch is not full, so its objective is left out of the plot.
# NOTE(review): `plot` presumably comes from the notebook's pylab namespace.
plot(online_coder_full.bound[:-1])
tagger = pmf.PoissonMF(n_components=n_components, random_state=98765, verbose=True)
tagger.set_components(online_coder_full.gamma_b[:, :K], online_coder_full.rho_b[:, :K])
# Out: PoissonMF(max_iter=100, n_components=100, random_state=98765, smoothness=100, tol=0.0005, verbose=True)
Et = tagger.transform(X_test)
# Out: After ITERATION: 15 Objective: 3303796.93 Old objective: 3302344.12 Improvement: 0.00044
# Normalize each test song's activations to sum to 1.
Et /= Et.sum(axis=1, keepdims=True)
tags_predicted = Et.dot(online_coder_full.Eb[:, K:])
print('%s %s' % (tags_predicted.min(), tags_predicted.max()))
# Out: 2.70206653139e-05 1.07069406765
div_factor = 3
tags_predicted = tags_predicted - div_factor * np.mean(tags_predicted, axis=0)
predictat = 20
tags_true_binary = (y_test > 0)
print_out_metrics(tags_true_binary, tags_predicted, predictat)
# Out: Precision = 0.131 (0.008) Recall = 0.154 (0.008) F-score = 0.141 AROC = 0.718 (0.005) AP = 0.122 (0.006)