from pylab import *
from shogun.Mathematics import LogDetEstimator
from shogun.Mathematics import ProbingSampler
from shogun.Mathematics import NormalSampler
from shogun.Library import SerialComputationEngine
from shogun.Mathematics import LogRationalApproximationCGM
from shogun.Mathematics import RealSparseMatrixOperator
from shogun.Mathematics import LanczosEigenSolver
from shogun.Mathematics import CCGMShiftedFamilySolver
from shogun.Mathematics import Statistics
from scipy.sparse import csc_matrix
from scipy.sparse import spdiags
# Construct a diagonal test matrix with a deliberately challenging
# eigenspectrum: eigenvalues are |N(0,1)|**difficulty_level, pushed away
# from zero by min_eig so the matrix stays positive definite.
n = 100
difficulty_level = 2  # raise for a more skewed spectrum; should be in [1, infty)
min_eig = 1e-10
data = abs(randn(n) ** difficulty_level) + min_eig
# Visualise the (sorted) spectrum we are about to work with.
plot(sort(data))
title("Eigenspectrum")
# Sparse diagonal representation of the matrix; for a diagonal matrix the
# exact log-determinant is simply the sum of the logs of the diagonal.
matrix = csc_matrix(spdiags(data, 0, n, n))
true_log_det = sum(log(data))
# create shogun representation and prepare things for log-det sampler
op=RealSparseMatrixOperator(csc_matrix(matrix))
engine=SerialComputationEngine()
linear_solver=CCGMShiftedFamilySolver()
accuracy=1e-5
eigen_solver=LanczosEigenSolver(op)
eigen_solver.compute()
op_func=LogRationalApproximationCGM(op, engine, eigen_solver, linear_solver, accuracy)
print "computed eigenvalues:", eigen_solver.get_min_eigenvalue(), eigen_solver.get_max_eigenvalue()
print "true eigenvalues:", min(data), max(data)
computed eigenvalues: 0.00188810553797 6.09129291465 true eigenvalues: 0.00188810553799 6.09129291465
def plot_estimate_convergence(estimates, true_value):
    """Plot the running mean of *estimates* against the known *true_value*.

    Parameters:
        estimates  -- 1-d sequence of independent log-det estimates
        true_value -- the exact log-determinant, drawn as a horizontal line

    Previously this relied on the module-level global ``n_estimates`` for
    the sample count; it now derives it from ``len(estimates)`` so the
    function is self-contained (behavior at both call sites is unchanged,
    since the global always equalled ``len(estimates)`` there).
    """
    m = len(estimates)
    # Running mean: cumulative sum divided by the number of samples so far.
    plot(cumsum(estimates) / (arange(m) + 1))
    plot([0, m], [true_value, true_value])
    legend(["Estimates", "True"])
# log det with probing sampler
trace_sampler=ProbingSampler(op)
log_det_estimator=LogDetEstimator(trace_sampler, op_func, engine)
n_estimates=10
estimates=log_det_estimator.sample(n_estimates)
plot_estimate_convergence(estimates, true_log_det)
title("Probe sampler")
print "Probe sampler:", mean(estimates)
print "True Value", true_log_det
print estimates
Probe sampler: -116.782753434 True Value -116.782629602 [-116.78275343 -116.78275343 -116.78275343 -116.78275343 -116.78275343 -116.78275343 -116.78275343 -116.78275343 -116.78275343 -116.78275343]
# log det with normal sampler
trace_sampler=NormalSampler(n)
log_det_estimator=LogDetEstimator(trace_sampler, op_func, engine)
n_estimates=100
estimates=log_det_estimator.sample(n_estimates)
plot_estimate_convergence(estimates, true_log_det)
title("Normal sampler")
print "Normal sampler:", mean(estimates)
print "True Value", true_log_det
Normal sampler: -114.82351731 True Value -116.782629602
normal_sampler=NormalSampler(n)
normal_sampler.precompute()
print normal_sampler.sample(0)
probing_sampler=ProbingSampler(op)
probing_sampler.precompute()
print probing_sampler.sample(0)
[ 1.22653986 1.3883058 -0.47655529 -1.21699776 -1.15778274 1.84666631 2.494643 -0.1874382 -0.82103147 0.85768052 -1.50545777 0.97528306 -0.19671847 0.08502129 -0.48344791 0.41546624 0.18628257 0.42903447 -0.7815423 -0.30006225 1.19479564 1.67644109 -0.73432638 0.34468957 -1.08222905 -1.4007216 0.73667239 1.3147536 0.21813536 -0.48862077 -0.35285779 -0.1624966 1.57846492 0.94328045 -0.82385468 -2.04247316 -0.49095529 0.53338902 -2.08100983 0.70900668 -0.80084716 1.33329597 -0.37854009 -0.12754887 2.30423816 -1.19082578 0.75991217 0.39801085 -2.33360796 0.07726854 -1.11620565 -1.04109785 0.30610265 0.12297408 1.80658378 -0.19238408 -1.3095859 -1.04766519 -0.41542999 -1.46094381 -1.61996456 -1.03015577 1.21373168 -0.34655053 0.55497302 -1.26072115 -0.94741222 -0.52560867 0.22969606 -0.81701341 -0.58335553 -0.54548283 -0.23018482 -0.29006255 1.8405689 0.64690527 -0.0191345 0.15867294 -0.36197828 0.00632797 -0.35797949 0.74199818 -0.69998587 -1.62932089 1.1716233 -0.05668058 0.93409355 1.63175137 0.47041927 1.0323755 -1.3066249 1.26229966 -0.59269603 -0.89441963 1.81748294 -0.05738524 -1.33727461 0.86395343 -0.43426314 -0.05963288] [ 1. -1. -1. 1. -1. 1. -1. -1. 1. -1. 1. 1. -1. -1. -1. -1. -1. 1. -1. 1. -1. 1. 1. 1. 1. 1. 1. 1. -1. -1. -1. -1. 1. -1. 1. -1. 1. 1. 1. 1. 1. 1. 1. -1. 1. 1. -1. 1. 1. 1. -1. -1. -1. 1. 1. 1. -1. -1. -1. 1. 1. -1. -1. -1. -1. 1. 1. 1. 1. -1. -1. 1. -1. 1. -1. 1. -1. 1. -1. -1. 1. -1. -1. 1. 1. -1. 1. 1. -1. -1. -1. 1. -1. -1. 1. 1. 1. 1. -1. -1.]