In [1]:
import numpy as np
from sklearn.linear_model import LassoCV, RidgeCV
from sklearn.gaussian_process.kernels import Matern, RBF

import plotly
import plotly.express as px

from docs.mse_estimator import ErrorComparer
from docs.data_generation import gen_rbf_X, gen_matern_X, gen_cov_mat
from docs.plotting_utils import gen_model_barplots
from spe.estimators import new_y_est, cp_arbitrary

# Trace Correction vs Random Correction

Simulations comparing the deterministic trace correction $\mathrm{tr}(\Theta_p (\Sigma_{Y^*} - \Sigma_{W^\perp}))$ with the random correction $\mathrm{tr}(\Theta_p (\Sigma_{Y^*} - \Sigma_Y)) - \|\Sigma_Y\Sigma_\omega^{-1}\omega\|_2^2$.

In [2]:
# Seed NumPy's global RNG so the np.random draws made later in the notebook are repeatable.
np.random.seed(1)

In [3]:
## Monte Carlo settings
niter = 100  # number of realizations to run

## data generation parameters
gsize = 10      # upper end of the [0, gsize] coordinate range along each axis
n = 20 ** 2     # total number of grid locations (a 20 x 20 lattice)
p = 200         # number of columns requested for X / length of beta
s = 5           # number of nonzero entries placed in beta
delta = 0.75    # weight on the kernel covariance; (1 - delta) goes to the identity
snr = 0.4       # forwarded to ErrorComparer.compare (presumably signal-to-noise ratio -- confirm)
tr_frac = 0.25  # NOTE(review): not referenced by any visible cell -- confirm it is still needed

## noise covariance kernel: 'rbf', 'matern', anything else -> identity covariance
noise_kernel = 'matern'
noise_length_scale = 1.0
noise_nu = 0.5

## design-matrix kernel: 'rbf', 'matern', anything else -> np.random.randn draws
X_kernel = 'matern'
X_length_scale = 5.0
X_nu = 2.5

## ErrorComparer parameters
alpha = 0.05
nboot = 100
# np.logspace takes base-10 exponents, so this is 5 values from 10**0.01 to 10**10.
# NOTE(review): confirm this penalty range is the intended one.
lambdas = np.logspace(0.01, 10, num=5)
models = [
    LassoCV(alphas=lambdas),
    RidgeCV(alphas=lambdas),
]

# Estimators and their keyword arguments; both lists are handed to
# ErrorComparer.compare (presumably matched entry-by-entry -- confirm).
ests = [new_y_est] * 2 + [cp_arbitrary] * 2
est_kwargs = [
    dict(alpha=None, full_refit=False),
    dict(alpha=alpha, full_refit=False),
] + [
    # The two cp_arbitrary runs differ only in whether the trace correction is used.
    dict(alpha=alpha, use_trace_corr=flag, nboot=nboot)
    for flag in (False, True)
]

## plot parameters
model_names = ["Lasso CV", "Ridge CV"]
# NOTE(review): two labels for four estimators -- presumably the two new_y_est
# entries act as reference errors rather than plotted bars; confirm against
# gen_model_barplots.
est_names = ["Rand Corr", "Trace Corr"]

In [4]:
# Single ErrorComparer instance; its compare() method is called once per model in the experiment loops.
err_cmp = ErrorComparer()

In [5]:
# Build a regular nx-by-ny lattice of 2-D locations on [0, gsize] x [0, gsize].
nx = ny = int(np.sqrt(n))
# Fail early if n is not a perfect square: otherwise the lattice would hold
# fewer than n points and later n-sized arrays (e.g. np.eye(n)) would not match.
assert nx * ny == n, f"n={n} must be a perfect square to form an nx-by-ny grid"

xs = np.linspace(0, gsize, nx)
ys = np.linspace(0, gsize, ny)

# meshgrid yields 2-D coordinate arrays; flatten each so every location is one entry.
c_x, c_y = (grid.flatten() for grid in np.meshgrid(xs, ys))

# coord has one row per location: column 0 is the x value, column 1 the y value.
coord = np.stack([c_x, c_y]).T

In [6]:
# Only 'rbf' and 'matern' request a kernel-based noise covariance; any other
# setting falls back to the identity.
spatial_noise = noise_kernel in ('rbf', 'matern')

if not spatial_noise:
    Sigma_t = np.eye(n)
elif noise_kernel == 'rbf':
    Sigma_t = gen_cov_mat(c_x, c_y, RBF(length_scale=noise_length_scale))
else:
    Sigma_t = gen_cov_mat(
        c_x,
        c_y,
        Matern(length_scale=noise_length_scale, nu=noise_nu),
    )

# Cov_y_ystar keeps only the delta-weighted kernel part; Sigma_t is then that
# part plus (1 - delta) times the identity.
Cov_y_ystar = delta * Sigma_t
Sigma_t = Cov_y_ystar + (1 - delta) * np.eye(n)

# In the fallback case Sigma_t is exactly the identity, so the factorisation
# is skipped and the identity is used directly.
Chol_y = np.linalg.cholesky(Sigma_t) if spatial_noise else np.eye(n)

In [7]:
# Design matrix: produced by the project helpers from the grid coordinates for
# 'matern' / 'rbf', otherwise an (n, p) matrix of standard-normal draws.
if X_kernel == 'matern':
    X = gen_matern_X(c_x, c_y, p, length_scale=X_length_scale, nu=X_nu)
elif X_kernel == 'rbf':
    X = gen_rbf_X(c_x, c_y, p)
else:
    X = np.random.randn(n, p)

# Sparse coefficient vector: s distinct positions (sampled without replacement)
# receive Uniform(-1, 1) values, every other entry stays zero.
idx = np.random.choice(p, size=s, replace=False)
beta = np.zeros(p)
beta[idx] = np.random.uniform(-1, 1, size=s)

In [8]:
# Boolean mask of length n with every entry True; passed to compare() as tr_idx.
tr_idx = np.ones(n, dtype=bool)

In [9]:
# Keyword arguments shared by every compare() call in this experiment
# (Chol_ystar is left as None here).
compare_kwargs = dict(
    niter=niter,
    n=n,
    p=p,
    s=s,
    snr=snr,
    X=X,
    beta=beta,
    coord=coord,
    Chol_y=Chol_y,
    Chol_ystar=None,
    tr_idx=tr_idx,
    fair=False,
)

# One compare() result per model, in the same order as `models`.
model_errs = []
for model in models:
    errs = err_cmp.compare(model, ests, est_kwargs, **compare_kwargs)
    model_errs.append(errs)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [06:22<00:00,  3.82s/it]
100%|██████████| 100/100 [25:56<00:00, 15.56s/it]


In [13]:
# Enable plotly's offline notebook rendering before the first figure is shown.
plotly.offline.init_notebook_mode()

# Entries 0 and 4 of plotly's qualitative "Bold" palette, one per plotted estimator label.
bold_palette = px.colors.qualitative.Bold
bar_colors = [bold_palette[0], bold_palette[4]]

fig = gen_model_barplots(
    model_errs,
    model_names,
    est_names,
    title="Var Reduction: Rand vs Trace Correction, NSN",
    has_elev_err=True,
    err_bars=True,
    color_discrete_sequence=bar_colors,
)
fig.show()

In [11]:
# Keyword arguments shared by every compare() call in this second experiment.
# NOTE(review): the keyword here is `Chol_y_ystar`, but the value passed is the
# covariance `Cov_y_ystar` (not a Cholesky factor), and the earlier experiment
# uses the keyword `Chol_ystar` -- confirm the name against
# ErrorComparer.compare's signature.
corr_compare_kwargs = dict(
    niter=niter,
    n=n,
    p=p,
    s=s,
    snr=snr,
    X=X,
    beta=beta,
    coord=coord,
    Chol_y=Chol_y,
    Chol_y_ystar=Cov_y_ystar,
    tr_idx=tr_idx,
    fair=False,
)

# One compare() result per model, in the same order as `models`.
corr_model_errs = []
for model in models:
    errs = err_cmp.compare(model, ests, est_kwargs, **corr_compare_kwargs)
    corr_model_errs.append(errs)

100%|██████████| 100/100 [03:38<00:00,  2.18s/it]
100%|██████████| 100/100 [17:03<00:00, 10.24s/it]


In [12]:
# Same bar-plot layout and colours as the first figure, fed with the results
# of the run that passes Cov_y_ystar.
corr_palette = px.colors.qualitative.Bold
corr_bar_colors = [corr_palette[0], corr_palette[4]]

corr_fig = gen_model_barplots(
    corr_model_errs,
    model_names,
    est_names,
    title="Variance Reduction: Random vs Trace Correction, SSN",
    has_elev_err=True,
    err_bars=True,
    color_discrete_sequence=corr_bar_colors,
)
corr_fig.show()