%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import regreg.api as rr
n, p, q = 1000, 5000, 10
Let's make some data
X = np.random.standard_normal((n,p))
beta = np.zeros((p,q))
beta[:3] = np.random.standard_normal((3,q)) * 2
Y = np.random.standard_normal((n,q)) + np.dot(X, beta)
The loss is squared error. I fixed the primal shape (`shape`) to (p, q) explicitly, because the code does not guess it immediately.
loss = rr.squared_error(X, Y)
loss.shape = (p, q)
The penalty is the $\ell_{1,2}$ norm of the matrix of coefficients $$ \sum_{i=1}^p \|A[i]\|_2. $$
Its generic solutions have entire rows set to 0 -- it acts like a row-wise soft-thresholding operation in which an entire row is shrunk towards 0.
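Concretely, the proximal map of this penalty is row-wise (group) soft-thresholding. A minimal numpy sketch of that operation (the helper name `group_soft_threshold` is mine, not part of regreg):

```python
import numpy as np

def group_soft_threshold(A, lam):
    """Prox of lam * sum_i ||A[i]||_2: shrink each row's l2 norm by lam."""
    norms = np.sqrt((A ** 2).sum(1))
    # rows with norm <= lam are zeroed; guard against dividing by zero
    scale = np.maximum(1 - lam / np.maximum(norms, 1e-12), 0)
    return A * scale[:, None]

A = np.array([[3.0, 4.0],    # row norm 5: shrunk to norm 3
              [0.1, 0.2]])   # row norm < 2: zeroed out entirely
print(group_soft_threshold(A, 2.0))
```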
penalty = rr.l1_l2((p,q), lagrange=2)
penalty
The smallest $\lambda$ such that $\hat{\beta}=0$ is `Lmax` below. It can be found from the conjugate of the penalty, the $\ell_{\infty,2}$ constraint. Since we want the norm itself rather than the constraint, we create a new instance in Lagrange form.
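Concretely, the $\ell_{\infty,2}$ norm is the largest row-wise $\ell_2$ norm; `Lmax` is this norm evaluated at the gradient of the loss at zero. A plain numpy version, independent of regreg, for checking (the function name `linf_l2_norm` is mine):

```python
import numpy as np

def linf_l2_norm(G):
    """l_{inf,2} norm: maximum over rows of the row-wise l2 norm of G."""
    return np.sqrt((G ** 2).sum(1)).max()

G = np.array([[3.0, 4.0],   # row norm 5
              [1.0, 0.0]])  # row norm 1
print(linf_l2_norm(G))      # → 5.0
```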
dual_constraint = penalty.conjugate
dual_constraint
dual_norm = type(dual_constraint)((p,q), lagrange=1)
dual_norm
Lmax = dual_norm.objective(loss.smooth_objective(np.zeros((p,q)), mode='grad'))
Lmax
6816.196847463144
Let's form our problem
problem = rr.simple_problem(loss, penalty)
Now, we solve the problem over a sequence of $\lambda$ values.
solns = []
Lvals = np.exp(np.linspace(np.log(0.01), 0, 50))[::-1] * Lmax
for L in Lvals:
    penalty.lagrange = L
    soln = problem.solve()
    solns.append(soln.copy())
solns = np.array(solns)
solns.shape
(50, 5000, 10)
Let's plot the $\ell_2$ norm of each group of coefficients.
norms = np.sqrt((solns**2).sum(2))
lines = [plt.plot(Lvals, norms[:,i]) for i in range(p) if np.linalg.norm(norms[:,i]) > 1.e-3]
plt.show()