{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.gaussian_process.kernels import Matern, RBF\n", "from sklearn.tree import DecisionTreeRegressor\n", "\n", "import plotly\n", "import plotly.express as px\n", "\n", "from docs.mse_estimator import ErrorComparer\n", "from docs.data_generation import gen_rbf_X, gen_matern_X, gen_cov_mat\n", "from docs.plotting_utils import gen_model_barplots\n", "from spe.estimators import new_y_est, cp_arbitrary, by_spatial" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generalized BY Comparison\n", "Here we compare ```spe.estimators.cp_arbitrary``` to a generalized version of the BY estimator, ```spe.estimators.by_spatial```, to estimate MSE on simulated data." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "## number of realizations to run\n", "niter = 100\n", "\n", "## data generation parameters\n", "gsize=10\n", "n=10**2\n", "p=200\n", "s=200\n", "delta = 0.75\n", "snr = 2.\n", "tr_frac = .25\n", "\n", "noise_kernel = 'matern'\n", "noise_length_scale = 1.\n", "noise_nu = .5\n", "\n", "X_kernel = 'matern'\n", "X_length_scale = 5.\n", "X_nu = 2.5\n", "\n", "## ErrorComparer parameters\n", "alpha = .05\n", "nboot = 100\n", "k = 5\n", "\n", "models = [DecisionTreeRegressor(max_depth=5, max_features='sqrt')]\n", "\n", "ests = [\n", " new_y_est,\n", " new_y_est,\n", " cp_arbitrary,\n", " by_spatial,\n", " by_spatial,\n", " by_spatial,\n", "]\n", "est_kwargs = [\n", " {'alpha': None,\n", " 'full_refit': False},\n", " {'alpha': .05,\n", " 'full_refit': False},\n", " {'alpha': alpha, \n", " 'use_trace_corr': False, \n", " 'nboot': nboot},\n", " {'alpha': .05, \n", " 'nboot': nboot},\n", " {'alpha': 1., \n", " 'nboot': nboot},\n", " {'alpha': 5., \n", " 'nboot': nboot},\n", "]\n", "\n", "## plot parameters\n", "est_names = [\"GenCp .05\", \"BY .05\", \"BY 1.\", \"BY 5.\"]\n", "model_names = [\"Decision Tree\"]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "err_cmp = ErrorComparer()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "nx = ny = int(np.sqrt(n))\n", "xs = np.linspace(0, gsize, nx)\n", "ys = np.linspace(0, gsize, ny)\n", "c_x, c_y = np.meshgrid(xs, ys)\n", "c_x = c_x.flatten()\n", "c_y = c_y.flatten()\n", "coord = np.stack([c_x, c_y]).T" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "if noise_kernel == 'rbf':\n", " Sigma_t = gen_cov_mat(c_x, c_y, RBF(length_scale=noise_length_scale))\n", "elif noise_kernel == 'matern':\n", " Sigma_t = gen_cov_mat(c_x, c_y, Matern(length_scale=noise_length_scale, nu=noise_nu))\n", "else:\n", " Sigma_t = np.eye(n)\n", " \n", "Cov_y_ystar = delta*Sigma_t\n", "Sigma_t = delta*Sigma_t + (1-delta)*np.eye(n)\n", "\n", "if noise_kernel == 'rbf' or noise_kernel == 'matern':\n", " Chol_y = np.linalg.cholesky(Sigma_t)\n", "else:\n", " Chol_y = np.eye(n)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "if X_kernel == 'rbf':\n", " Sigma_X = gen_cov_mat(c_x, c_y, RBF(length_scale=X_length_scale))\n", "elif X_kernel == 'matern':\n", " Sigma_X = gen_cov_mat(c_x, c_y, Matern(length_scale=X_length_scale, nu=X_nu))\n", "else:\n", " Sigma_X = np.eye(n)\n", "\n", "if X_kernel == 'rbf' or X_kernel == 'matern':\n", " Chol_X = np.linalg.cholesky(Sigma_X)\n", "else:\n", " Chol_X = np.eye(n)\n", "\n", "X = Chol_X @ np.random.randn(n,p)\n", "\n", "if X_kernel == 'rbf':\n", " X_spikes = gen_rbf_X(c_x, c_y, p)\n", "elif X_kernel == 'matern':\n", " X_spikes = gen_matern_X(c_x, c_y, p, length_scale=X_length_scale, nu=X_nu)\n", "else:\n", " X_spikes = np.random.randn(n,p)\n", "\n", "X_iso = np.random.randn(n,p)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "nx = ny = int(np.sqrt(n))\n", "xs = np.linspace(0, 30, nx)\n", "ys = np.linspace(0, 30, ny)\n", "c_x, c_y = np.meshgrid(xs, ys)\n", "c_x = c_x.flatten()\n", "c_y = c_y.flatten()\n", "coord = np.stack([c_x, c_y]).T\n", "\n", "if X_kernel == 'rbf':\n", " Sigma_X_less = gen_cov_mat(c_x, c_y, RBF(length_scale=X_length_scale))\n", "elif X_kernel == 'matern':\n", " Sigma_X_less = gen_cov_mat(c_x, c_y, Matern(length_scale=X_length_scale, nu=X_nu))\n", "else:\n", " Sigma_X_less = np.eye(n)\n", "\n", "if X_kernel == 'rbf' or X_kernel == 'matern':\n", " Chol_X_less = np.linalg.cholesky(Sigma_X_less)\n", "else:\n", " Chol_X_less = np.eye(n)\n", "\n", "X_less = Chol_X_less @ np.random.randn(n,p)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "beta = np.zeros(p)\n", "idx = np.random.choice(p,size=s,replace=False)\n", "beta[idx] = np.random.uniform(-1,1,size=s)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "tr_idx = np.ones(n, dtype=bool)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Simulate $\\begin{pmatrix} Y \\\\ Y^* \\end{pmatrix} \\sim \\mathcal{N}\\left(\\begin{pmatrix} \\mu \\\\ \\mu \\end{pmatrix}, \\begin{pmatrix}\\Sigma_Y & \\Sigma_{Y, Y^*} \\\\ \\Sigma_{Y^*, Y} & \\Sigma_{Y} \\end{pmatrix}\\right)$" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### $X_{\\cdot,i}$ independently generated by uniform spikes at locations, then interpolate based on cov matrix" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/100 [00:00, ?it/s]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:47<00:00, 2.12it/s]\n" ] } ], "source": [ "spike_model_errs = []\n", "\n", "for model in models:\n", " errs = err_cmp.compare(\n", " model,\n", " ests,\n", " est_kwargs,\n", " niter=niter,\n", " n=n,\n", " p=p,\n", " s=s,\n", " snr=snr, \n", " X=X_spikes,\n", " beta=beta,\n", " coord=coord,\n", " Chol_y=Chol_y,\n", " Chol_ystar=None,\n", " Cov_y_ystar=Cov_y_ystar,\n", " tr_idx=tr_idx,\n", " fair=False,\n", " est_sigma=False,\n", " )\n", " spike_model_errs.append(errs)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### $X_{i,\\cdot} \\sim \\mathcal{N}(0, I\\sigma^2)$" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:46<00:00, 2.15it/s]\n" ] } ], "source": [ "iid_model_errs = []\n", "\n", "for model in models:\n", " errs = err_cmp.compare(\n", " model,\n", " ests,\n", " est_kwargs,\n", " niter=niter,\n", " n=n,\n", " p=p,\n", " s=s,\n", " snr=snr, \n", " X=X_iso,\n", " beta=beta,\n", " coord=coord,\n", " Chol_y=Chol_y,\n", " Chol_ystar=None,\n", " Cov_y_ystar=Cov_y_ystar,\n", " tr_idx=tr_idx,\n", " fair=False,\n", " est_sigma=False,\n", " )\n", " iid_model_errs.append(errs)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "error_y": { "array": [ 0.08350345486789641, 0.1268875108559484, 0.11223427882164304, 0.10747919834475006 ], "color": "black", "type": "data" }, "marker": { "color": [ "rgb(127, 60, 141)", "rgb(230, 131, 16)", "rgb(0, 134, 149)", "rgb(207, 28, 144)" ] }, "text": [ 1.066, 1.003, 0.988, 0.954 ], "textposition": "outside", "type": "bar", "x": [ "GenCp .05", "BY .05", "BY 1.", "BY 5." ], "xaxis": "x", "y": [ 1.0655896396810973, 1.0033213633383096, 0.9879701738826541, 0.9544918574063418 ], "yaxis": "y" }, { "error_y": { "array": [ 0.08271142483515297, 0.2233512085280012, 0.26838324374953865, 0.24290523160060365 ], "color": "black", "type": "data" }, "marker": { "color": [ "rgb(127, 60, 141)", "rgb(230, 131, 16)", "rgb(0, 134, 149)", "rgb(207, 28, 144)" ] }, "text": [ 1.043, 0.995, 0.983, 1.02 ], "textposition": "outside", "type": "bar", "x": [ "GenCp .05", "BY .05", "BY 1.", "BY 5." ], "xaxis": "x2", "y": [ 1.04320514199614, 0.9950397672496221, 0.982960321243187, 1.0197349201261652 ], "yaxis": "y2" } ], "layout": { "annotations": [ { "font": { "size": 16 }, "showarrow": false, "text": "Corr X", "x": 0.225, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "bottom", "yref": "paper" }, { "font": { "size": 16 }, "showarrow": false, "text": "IID X", "x": 0.775, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "bottom", "yref": "paper" } ], "font": { "size": 15 }, "height": 600, "shapes": [ { "line": { "color": "red" }, "type": "line", "x0": 0, "x1": 1, "xref": "x domain", "y0": 1, "y1": 1, "yref": "y" }, { "line": { "color": "grey", "dash": "dash" }, "type": "line", "x0": 0, "x1": 1, "xref": "x domain", "y0": 1.047751214982157, "y1": 1.047751214982157, "yref": "y" }, { "line": { "color": "red" }, "type": "line", "x0": 0, "x1": 1, "xref": "x2 domain", "y0": 1, "y1": 1, "yref": "y2" }, { "line": { "color": "grey", "dash": "dash" }, "type": "line", "x0": 0, "x1": 1, "xref": "x2 domain", "y0": 1.0515060198974568, "y1": 1.0515060198974568, "yref": "y2" } ], "showlegend": false, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "BY Comparisons: Depth 5 Decision Tree, SSN" }, "width": 900, "xaxis": { "anchor": "y", "domain": [ 0, 0.45 ], "title": { "text": "Method" } }, "xaxis2": { "anchor": "y2", "domain": [ 0.55, 1 ], "title": { "text": "Method" } }, "yaxis": { "anchor": "x", "domain": [ 0, 1 ], "title": { "text": "Relative MSE" } }, "yaxis2": { "anchor": "x2", "domain": [ 0, 1 ] } } }, "text/html": [ "