{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "dd2996fd", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/06/s3csl9g94gx2ptsgwz8cpfvh0000gn/T/ipykernel_72302/1955058804.py:2: DeprecationWarning: \n", "Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),\n", "(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)\n", "but was not found to be installed on your system.\n", "If this would cause problems for you,\n", "please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466\n", " \n", " import pandas as pd\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "from sklearn.gaussian_process.kernels import Matern, RBF\n", "from sklearn.cluster import KMeans\n", "\n", "from docs.data_generation import gen_rbf_X, gen_matern_X, create_clus_split, gen_cov_mat\n", "from docs.plotting_utils import gen_model_barplots\n", "\n", "import matplotlib.pyplot as plt\n", "import plotly.express as px\n", "\n", "from tqdm import tqdm" ] }, { "cell_type": "markdown", "id": "4f605875", "metadata": {}, "source": [ "# CV Corrections\n", "Simulations demonstrating sampling settings where CV methods do properly estimate MSE even when data are correlated." 
] },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "43761bf6",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2a0c782c",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc5e0c98",
   "metadata": {},
   "outputs": [],
   "source": [
    "## number of realizations to run\n",
    "niter = 100\n",
    "\n",
    "## data generation parameters\n",
    "gsize = 10  # side length of the square spatial domain [0, gsize]^2\n",
    "n=60**2  # number of grid locations (a perfect square: nx = ny = sqrt(n))\n",
    "p=50  # number of features\n",
    "s=50  # number of nonzero entries of beta\n",
    "delta = 0.75  # weight of the spatially correlated part of the noise covariance\n",
    "snr = 0.4  # target signal-to-noise ratio used to rescale the noise\n",
    "tr_frac = .2  # sampling fraction of the grid, and training fraction of the sample\n",
    "bloo_radius = gsize // 2  # buffer radius around the BLOOCV test point\n",
    "k = 10  # number of KMeans clusters (spatial folds) for SPCV\n",
    "\n",
    "noise_kernel = 'matern'\n",
    "noise_length_scale = 1.\n",
    "noise_nu = .5\n",
    "\n",
    "X_kernel = 'matern'\n",
    "X_length_scale = 5.\n",
    "X_nu = 2.5\n",
    "\n",
    "est_names = [\"KFCV\", \"SPCV\", \"BLOOCV\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "66717172",
   "metadata": {},
   "outputs": [],
   "source": [
    "nx = ny = int(np.sqrt(n))\n",
    "xs = np.linspace(0, gsize, nx)\n",
    "ys = np.linspace(0, gsize, ny)\n",
    "c_x, c_y = np.meshgrid(xs, ys)\n",
    "c_x = c_x.flatten()\n",
    "c_y = c_y.flatten()\n",
    "coord = np.stack([c_x, c_y]).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "28c5cb4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "if noise_kernel == 'rbf':\n",
    "    Sigma_y = gen_cov_mat(c_x, c_y, RBF(length_scale=noise_length_scale))\n",
    "elif noise_kernel == 'matern':\n",
    "    Sigma_y = gen_cov_mat(c_x, c_y, Matern(length_scale=noise_length_scale, nu=noise_nu))\n",
    "else:\n",
    "    Sigma_y = np.eye(n)\n",
    "\n",
    "Cov_y_ystar = delta*Sigma_y\n",
    "Sigma_y = delta*Sigma_y + (1-delta)*np.eye(n)\n",
    "\n",
    "## unit-scale Cholesky factor; it is rescaled to the target SNR once X is generated\n",
    "if noise_kernel == 'rbf' or noise_kernel == 'matern':\n",
    "    Chol_y_unit = np.linalg.cholesky(Sigma_y)\n",
    "else:\n",
    "    Chol_y_unit = np.eye(n)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b38a96e2",
   "metadata": {},
   "source": [
    "## Generate Gaussian X, Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "aad3d707",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_X():\n",
    "    \"\"\"Draw a fresh design matrix X according to X_kernel (shape assumed (n, p)).\"\"\"\n",
    "    if X_kernel == 'rbf':\n",
    "        X = gen_rbf_X(c_x, c_y, p)\n",
    "    elif X_kernel == 'matern':\n",
    "        X = gen_matern_X(c_x, c_y, p, length_scale=X_length_scale, nu=X_nu)\n",
    "    else:\n",
    "        X = np.random.randn(n,p)\n",
    "    return X\n",
    "\n",
    "X = gen_X()\n",
    "\n",
    "beta = np.zeros(p)\n",
    "idx = np.random.choice(p,size=s,replace=False)\n",
    "beta[idx] = np.random.uniform(-1,1,size=s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2ad4c5f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "## rescale the noise to the target SNR; built from Chol_y_unit so re-running this cell does not compound the scaling\n",
    "Chol_y = Chol_y_unit * (np.std(X@beta) / np.sqrt(snr))\n",
    "Sigma_y = Chol_y @ Chol_y.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cc36d0d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = X @ beta + Chol_y @ np.random.randn(n)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "099eda8e",
   "metadata": {},
   "source": [
    "## Compute expected correction for one fold of CV vs expected correction for random sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "12bd27ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "## assuming linear model, and bias approx 0\n",
    "def computeCorrection(\n",
    "    S,\n",
    "    Sigma,\n",
    "    tr_idx,\n",
    "    ts_idx=None,\n",
    "):\n",
    "    \"\"\"Expected correction term for the linear smoother S (linear model, bias ~ 0).\n",
    "\n",
    "    Returns (tr(Sigma[ts, ts]) - 2 * tr(S @ Sigma[tr, ts])) / n_ts.\n",
    "\n",
    "    S's rows must follow the test indices and its columns the training\n",
    "    indices, in the same order used here to slice Sigma. Boolean masks\n",
    "    select in ascending index order, so when masks are passed S must be\n",
    "    built from the same masks, e.g. X[ts_bool] @ pinv(X[tr_bool]).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    S : array, shape (n_ts, n_tr), or (n_tr,) for a single test point\n",
    "    Sigma : array, shape (n, n); noise covariance\n",
    "    tr_idx : boolean mask of length n, or integer index array\n",
    "    ts_idx : boolean mask or integer index array; defaults to the\n",
    "        complement of tr_idx\n",
    "    \"\"\"\n",
    "    def as_index(sel):\n",
    "        ## boolean masks -> ascending integer indices; integer arrays keep their order\n",
    "        sel = np.atleast_1d(np.asarray(sel))\n",
    "        return np.flatnonzero(sel) if sel.dtype == bool else sel\n",
    "\n",
    "    tr = as_index(tr_idx)\n",
    "    if ts_idx is None:\n",
    "        ts = np.setdiff1d(np.arange(Sigma.shape[0]), tr)\n",
    "    else:\n",
    "        ts = as_index(ts_idx)\n",
    "\n",
    "    ## diag(...).sum() is tr(S @ Sigma[tr, ts]); for a 1-D S (one test point) it is that single value\n",
    "    cross_cov = np.diag(S @ Sigma[np.ix_(tr, ts)]).sum()\n",
    "    return (np.diag(Sigma)[ts].sum() - 2*cross_cov) / len(ts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9b1695f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def getDistance(c_x, c_y):\n",
    "    \"\"\"Pairwise Euclidean distance matrix between the points (c_x[i], c_y[i]).\"\"\"\n",
    "    Loc = np.stack([c_x, c_y]).T\n",
    "    m = np.sum(Loc**2, axis=1)\n",
    "    ## squared distances via |a - b|^2 = |a|^2 + |b|^2 - 2 a.b\n",
    "    D = (-2 * Loc.dot(Loc.T) + m).T + m\n",
    "    D = 0.5 * (D + D.T)  ## symmetrize away floating-point asymmetry\n",
    "    D = np.maximum(D, 0)  ## sometimes gets values like -1e-9\n",
    "    D = np.sqrt(D)\n",
    "\n",
    "    return D\n",
    "\n",
    "\n",
    "def getBufferTrain(D, tr_bool, ts_idx, radius=None):\n",
    "    \"\"\"Training mask with every point within `radius` of the test point ts_idx removed.\n",
    "\n",
    "    Keeps only the points of tr_bool whose distance D[ts_idx, :] exceeds\n",
    "    radius; for a positive radius this also drops ts_idx itself.\n",
    "    radius defaults to the global bloo_radius.\n",
    "    \"\"\"\n",
    "    if radius is None:\n",
    "        radius = bloo_radius\n",
    "    buffer_tr_bool = tr_bool & (D[ts_idx,:] > radius)\n",
    "    return buffer_tr_bool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "664752da",
   "metadata": {},
   "outputs": [],
   "source": [
    "D = getDistance(c_x, c_y)"
   ]
  },
  {
"cell_type": "code",
   "execution_count": 12,
   "id": "492499a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:15<00:00, 6.29it/s]\n"
     ]
    }
   ],
   "source": [
    "kfcv_corr = np.zeros(niter)\n",
    "spcv_corr = np.zeros(niter)\n",
    "bloocv_corr = np.zeros(niter)\n",
    "ns_corr = np.zeros(niter)\n",
    "sp_corr = np.zeros(niter)\n",
    "\n",
    "for i in tqdm(range(niter)):\n",
    "    idxs = np.random.choice(n, size=int(tr_frac*n), replace=False)\n",
    "\n",
    "    ## one KFCV fold: the first tr_frac of the sample trains, the rest tests\n",
    "    cv_tr_idx = idxs[:int(tr_frac*tr_frac*n)]\n",
    "    cv_tr_bool = np.zeros(n, dtype=bool)\n",
    "    cv_tr_bool[cv_tr_idx] = True\n",
    "\n",
    "    cv_ts_idx = idxs[int(tr_frac*tr_frac*n):int(tr_frac*n)]\n",
    "    cv_ts_bool = np.zeros(n, dtype=bool)\n",
    "    cv_ts_bool[cv_ts_idx] = True\n",
    "\n",
    "    ## reference random train/test split (same points as the KFCV fold)\n",
    "    tr_idx = idxs[:int(tr_frac*tr_frac*n)]\n",
    "    tr_bool = np.zeros(n, dtype=bool)\n",
    "    tr_bool[tr_idx] = True\n",
    "\n",
    "    ts_idx = idxs[int(tr_frac*tr_frac*n):int(tr_frac*n)]\n",
    "    ts_bool = np.zeros(n, dtype=bool)\n",
    "    ts_bool[ts_idx] = True\n",
    "\n",
    "    ## split the sample into k folds by kmeans; the last cluster is the held-out fold\n",
    "    groups = KMeans(n_init=10, n_clusters=k).fit(coord[idxs]).labels_\n",
    "    spcv_tr_idx = idxs[np.where(groups < k-1)[0]]\n",
    "    spcv_tr_bool = np.zeros(n, dtype=bool)\n",
    "    spcv_tr_bool[spcv_tr_idx] = True\n",
    "    spcv_ts_idx = idxs[np.where(groups == k-1)[0]]\n",
    "    spcv_ts_bool = np.zeros(n, dtype=bool)\n",
    "    spcv_ts_bool[spcv_ts_idx] = True\n",
    "\n",
    "    ## pick one point for ts, tr is all far enough away\n",
    "    bloocv_tr_idx = idxs[np.random.choice(len(idxs),size=len(idxs),replace=False)]\n",
    "    bloocv_ts_idx = bloocv_tr_idx[0]\n",
    "\n",
    "    bloocv_tr_bool = np.zeros(n, dtype=bool)\n",
    "    bloocv_tr_bool[bloocv_tr_idx] = True\n",
    "    bloocv_tr_bool = getBufferTrain(D, bloocv_tr_bool, bloocv_ts_idx)\n",
    "\n",
    "    bloocv_ts_bool = np.zeros(n, dtype=bool)\n",
    "    bloocv_ts_bool[bloocv_ts_idx] = True\n",
    "\n",
    "    X = gen_X()\n",
    "\n",
    "    ## Build every smoother from the same boolean masks passed to computeCorrection.\n",
    "    ## Masks select rows in ascending index order, which is how computeCorrection\n",
    "    ## slices Sigma_y; building S from the shuffled integer index arrays would\n",
    "    ## misalign S with Sigma_y and corrupt the trace term.\n",
    "    S_cv = X[cv_ts_bool,:] @ np.linalg.pinv(X[cv_tr_bool,:])\n",
    "    S_spcv = X[spcv_ts_bool,:] @ np.linalg.pinv(X[spcv_tr_bool,:])\n",
    "    S_bloocv = X[bloocv_ts_idx,:] @ np.linalg.pinv(X[bloocv_tr_bool,:])\n",
    "\n",
    "    pinv_tr = np.linalg.pinv(X[tr_bool,:])\n",
    "    S_tr = X[ts_bool,:] @ pinv_tr\n",
    "\n",
    "    sp_idx = np.random.choice(ts_idx,size=1)\n",
    "    sp_bool = np.zeros(n, dtype=bool)\n",
    "    sp_bool[sp_idx] = True\n",
    "    S_sp = X[sp_bool,:] @ pinv_tr\n",
    "\n",
    "    kfcv_corr[i] = computeCorrection(S_cv, Sigma_y, cv_tr_bool, cv_ts_bool)\n",
    "    spcv_corr[i] = computeCorrection(S_spcv, Sigma_y, spcv_tr_bool, spcv_ts_bool)\n",
    "    bloocv_corr[i] = computeCorrection(S_bloocv, Sigma_y, bloocv_tr_bool, bloocv_ts_bool)\n",
    "    ns_corr[i] = computeCorrection(S_tr, Sigma_y, tr_bool, ts_bool)\n",
    "    sp_corr[i] = computeCorrection(S_sp, Sigma_y, tr_bool, sp_bool)\n",
    "\n",
    "corrs = pd.DataFrame({\n",
    "    'KFCV': kfcv_corr,\n",
    "    'SPCV': spcv_corr,\n",
    "    'BLOOCV': bloocv_corr,\n",
    "})\n"
   ]
  },
  { "cell_type": "code", "execution_count": 13, "id": "a47c9d44", "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "error_y": { "array": [ 0.005747306996259622, 0.15202399723035148, 0.020681685424080918 ], "color": "black", "type": "data" }, "marker": { "color": [ "rgb(17, 165, 121)", "rgb(57, 105, 172)", "rgb(242, 183, 1)" ] }, "text": [ 1, 1.054, 0.98 ], "textposition": "outside", "type": "bar", "x": [ "KFCV", "SPCV", "BLOOCV" ], "xaxis": "x", "y": [ 1, 1.0542110326159992, 0.980116034109643 ], "yaxis": "y" } ], "layout": { "font": { "size": 15 }, "height": 600, "shapes": [ { "line": { "color": "red" }, "type": "line", "x0": 0, "x1": 1, "xref": "x domain", "y0":
1, "y1": 1, "yref": "y" } ], "showlegend": false, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 
0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": 
{ "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 
0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { 
"standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "Train/Test Split: OLS Correction Term Comparisons" }, "width": 600, "xaxis": { "anchor": "y", "domain": [ 0, 1 ], "title": { "text": "Method" } }, "yaxis": { "anchor": "x", "domain": [ 0, 1 ], "title": { "text": "Relative MSE" } } } } }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "fig = gen_model_barplots(\n",
    "    [[ns_corr, kfcv_corr, spcv_corr, bloocv_corr]],\n",
    "    [\"\"],\n",
    "    est_names,\n",
    "    title=\"Train/Test Split: OLS Correction Term Comparisons\",\n",
    "    has_elev_err=False,\n",
    "    err_bars=True,\n",
    "    color_discrete_sequence=[px.colors.qualitative.Bold[i] for i in [1,2,3]],\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d1637a13",
   "metadata": {},
   "source": [
    "## Compute expected correction for one fold of CV vs expected correction for clustered sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "877ec770",
   "metadata": {},
   "outputs": [],
   "source": [
    "clus_kfcv_corr = np.zeros(niter)\n",
    "clus_spcv_corr = np.zeros(niter)\n",
    "clus_bloocv_corr = np.zeros(niter)\n",
    "clus_ns_corr = np.zeros(niter)\n",
    "clus_sp_corr = np.zeros(niter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "56372311",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:15<00:00, 6.48it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(niter)):\n",
    "    idxs, ts_idx = create_clus_split(\n",
    "        nx,\n",
    "        ny,\n",
    "        tr_frac,\n",
    "        ngrid=5,\n",
    "        ts_frac=tr_frac-tr_frac*tr_frac,\n",
    "        sort_grids=False,\n",
    "    )\n",
    "    ## randomize order for CV\n",
    "    idxs = np.random.choice(idxs, size=len(idxs), replace=False)\n",
    "\n",
    "    ## one KFCV fold: the first tr_frac of the clustered sample trains, the rest tests\n",
    "    cv_tr_idx = idxs[:int(tr_frac*len(idxs))]\n",
    "    cv_tr_bool = np.zeros(n, dtype=bool)\n",
    "    cv_tr_bool[cv_tr_idx] = True\n",
    "\n",
    "    cv_ts_idx = idxs[int(tr_frac*len(idxs)):]\n",
    "    cv_ts_bool = np.zeros(n, dtype=bool)\n",
    "    cv_ts_bool[cv_ts_idx] = True\n",
    "\n",
    "    ## reference split: same training points, test points returned by create_clus_split\n",
    "    tr_idx = idxs[:int(tr_frac*len(idxs))]\n",
    "    tr_bool = np.zeros(n, dtype=bool)\n",
    "    tr_bool[tr_idx] = True\n",
    "\n",
    "    ## rebuild the test mask every realization; reusing the old one would keep\n",
    "    ## accumulating test points from earlier iterations (and the previous experiment)\n",
    "    ts_bool = np.zeros(n, dtype=bool)\n",
    "    ts_bool[ts_idx] = True\n",
    "\n",
    "    ## split the sample into k folds by kmeans; the last cluster is the held-out fold\n",
    "    groups = KMeans(n_init=10, n_clusters=k).fit(coord[idxs]).labels_\n",
    "    spcv_tr_idx = idxs[np.where(groups < k-1)[0]]\n",
    "    spcv_tr_bool = np.zeros(n, dtype=bool)\n",
    "    spcv_tr_bool[spcv_tr_idx] = True\n",
    "    spcv_ts_idx = idxs[np.where(groups == k-1)[0]]\n",
    "    spcv_ts_bool = np.zeros(n, dtype=bool)\n",
    "    spcv_ts_bool[spcv_ts_idx] = True\n",
    "\n",
    "    ## pick one point for ts, tr is all far enough away\n",
    "    bloocv_tr_idx = idxs[np.random.choice(len(idxs),size=len(idxs),replace=False)]\n",
    "    bloocv_ts_idx = bloocv_tr_idx[0]\n",
    "\n",
    "    bloocv_tr_bool = np.zeros(n, dtype=bool)\n",
    "    bloocv_tr_bool[bloocv_tr_idx] = True\n",
    "    bloocv_tr_bool = getBufferTrain(D, bloocv_tr_bool, bloocv_ts_idx)\n",
    "\n",
    "    bloocv_ts_bool = np.zeros(n, dtype=bool)\n",
    "    bloocv_ts_bool[bloocv_ts_idx] = True\n",
    "\n",
    "    X = gen_X()\n",
    "\n",
    "    ## Build every smoother from the same boolean masks passed to computeCorrection.\n",
    "    ## Masks select rows in ascending index order, which is how computeCorrection\n",
    "    ## slices Sigma_y; building S from the shuffled integer index arrays would\n",
    "    ## misalign S with Sigma_y and corrupt the trace term.\n",
    "    S_cv = X[cv_ts_bool,:] @ np.linalg.pinv(X[cv_tr_bool,:])\n",
    "    S_spcv = X[spcv_ts_bool,:] @ np.linalg.pinv(X[spcv_tr_bool,:])\n",
    "    S_bloocv = X[bloocv_ts_idx,:] @ np.linalg.pinv(X[bloocv_tr_bool,:])\n",
    "\n",
    "    pinv_tr = np.linalg.pinv(X[tr_bool,:])\n",
    "    S_tr = X[ts_bool,:] @ pinv_tr\n",
    "\n",
    "    sp_idx = np.random.choice(ts_idx,size=1)\n",
    "    sp_bool = np.zeros(n, dtype=bool)\n",
    "    sp_bool[sp_idx] = True\n",
    "    S_sp = X[sp_bool,:] @ pinv_tr\n",
    "\n",
    "    clus_kfcv_corr[i] = computeCorrection(S_cv, Sigma_y, cv_tr_bool, cv_ts_bool)\n",
    "    clus_spcv_corr[i] = computeCorrection(S_spcv, Sigma_y, spcv_tr_bool, spcv_ts_bool)\n",
    "    clus_bloocv_corr[i] = computeCorrection(S_bloocv, Sigma_y, bloocv_tr_bool, bloocv_ts_bool)\n",
    "    clus_ns_corr[i] = computeCorrection(S_tr, Sigma_y, tr_bool, ts_bool)\n",
    "    clus_sp_corr[i] = computeCorrection(S_sp, Sigma_y, tr_bool, sp_bool)\n",
    "\n",
    "clus_corrs = pd.DataFrame({\n",
    "    'KFCV': clus_kfcv_corr,\n",
    "    'SPCV': clus_spcv_corr,\n",
    "    'BLOOCV': clus_bloocv_corr,\n",
    "})\n"
   ]
  },
  { "cell_type": "code", "execution_count": 21, "id": "ae07968d", "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "error_y": { "array": [ 0.012430270915315607, 0.33235778116396403, 0.05121208574787478 ], "color": "black", "type": "data" }, "marker": { "color": [ "rgb(17, 165, 121)", "rgb(57, 105, 172)", "rgb(242, 183, 1)" ] }, "text": [ 0.878, 0.974, 0.86 ], "textposition": "outside", "type": "bar", "x": [ "KFCV", "SPCV", "BLOOCV" ], "xaxis": "x", "y": [ 0.877773585716249, 0.9735403702813207, 0.8600325112128235 ], "yaxis": "y" } ], "layout": { "font": { "size": 15 }, "height": 600, "shapes": [ { "line": { "color": "red" }, "type": "line", "x0": 0, "x1": 1, "xref": "x domain", "y0": 1, "y1": 1, "yref": "y" } ], "showlegend": false, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f",
"gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, 
"#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], 
"surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, 
"subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "Spatial Train/Test Split: OLS Correction Term Comparisons" }, "width": 600, "xaxis": { "anchor": "y", "domain": [ 0, 1 ], "title": { "text": "Method" } }, "yaxis": { "anchor": "x", "domain": [ 0, 1 ], "title": { "text": "Relative MSE" } } } } }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = gen_model_barplots(\n", " [[clus_ns_corr, clus_kfcv_corr, clus_spcv_corr, clus_bloocv_corr]], \n", " [\"\"], \n", " est_names, \n", " 
title=\"Spatial Train/Test Split: OLS Correction Term Comparisons\", \n", " has_elev_err=False,\n", " err_bars=True,\n", " color_discrete_sequence=[px.colors.qualitative.Bold[i] for i in [1,2,3]],\n", ")\n", "fig.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "56d8bac8", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "spe", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }