From 949c8cdc6fc2b568e3c18f517b7826dad034dbe4 Mon Sep 17 00:00:00 2001 From: HuangFuSL Date: Mon, 16 Dec 2024 17:25:52 +0800 Subject: [PATCH] Update: Gaussian Process --- docs/coding/machine-learning/.pages | 4 +- .../machine-learning/gaussian-process.ipynb | 9989 +++++++++++++++++ docs/coding/machine-learning/index.md | 6 +- 3 files changed, 9997 insertions(+), 2 deletions(-) create mode 100644 docs/coding/machine-learning/gaussian-process.ipynb diff --git a/docs/coding/machine-learning/.pages b/docs/coding/machine-learning/.pages index f3070fe9..5fea42fe 100644 --- a/docs/coding/machine-learning/.pages +++ b/docs/coding/machine-learning/.pages @@ -1,4 +1,6 @@ nav: - 线性模型: linear-models.ipynb - 决策树: decision-tree.ipynb - - 贝叶斯优化: bayesian-optimization.ipynb + - 贝叶斯优化: + - 高斯过程: gaussian-process.ipynb + - 算法: bayesian-optimization.ipynb diff --git a/docs/coding/machine-learning/gaussian-process.ipynb b/docs/coding/machine-learning/gaussian-process.ipynb new file mode 100644 index 00000000..70510420 --- /dev/null +++ b/docs/coding/machine-learning/gaussian-process.ipynb @@ -0,0 +1,9989 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 高斯过程" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.stats import norm\n", + "from matplotlib import pyplot as plt\n", + "import numpy as np\n", + "\n", + "%config InlineBackend.figure_format = 'svg'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "高斯过程(Gaussian Process,GP)是将多变量高斯分布推广到无限维度$\\mathcal X\\subseteq \\mathbb R^n$的概率分布。具体地,对于$X = \\{x_1, \\ldots, x_n\\}\\subseteq \\mathcal X$,随机变量$f(x_1), \\ldots, f(x_n)$服从多元高斯分布$\\mathcal N(\\mu(X), \\Sigma(X))$。其中,$\\mu(X)$为均值函数,$\\Sigma(X)$为协方差函数。因此,高斯过程研究的是函数的概率分布。\n", + "\n", + "多元高斯分布的条件分布依然是高斯分布。设随机变量$X = (X_1, X_2)$服从多元高斯分布$\\mathcal N(\\mu, \\Sigma)$,其中$\\mu = (\\mu_1, \\mu_2)$,$\\Sigma = \\begin{bmatrix} \\Sigma_{11} & \\Sigma_{12} \\\\ 
def y_posterior(x, x_obs, y_obs, kernel_func, jitter=1e-8):
    """Posterior mean and variance of a zero-mean GP at a single query point.

    Applies the Gaussian conditioning formulas with a zero prior mean:
    ``mu = S21 @ inv(S11) @ y_obs`` and ``var = S22 - S21 @ inv(S11) @ S21``,
    where ``S11 = k(x_obs, x_obs)``, ``S21 = k(x, x_obs)``, ``S22 = k(x, x)``.

    Args:
        x: Query point, shape ``(num_features,)``.
        x_obs: Observed inputs, shape ``(num_observations, num_features)``.
        y_obs: Observed targets, shape ``(num_observations,)``.
        kernel_func: Callable ``k(a, b)`` as produced by ``kernel_wrapper``.
        jitter: Value added to the diagonal of ``S11`` before solving, to keep
            the linear system well conditioned.

    Returns:
        ``(mu, sigma)`` where ``sigma`` is the posterior *variance* (not the
        standard deviation). Round-off can make it slightly negative; callers
        that take a square root should clip it at zero first.
    """
    sigma_21 = kernel_func(x, x_obs)      # (num_observations,)
    sigma_22 = kernel_func(x, x)          # scalar
    sigma_11 = kernel_func(x_obs, x_obs)  # (num_observations, num_observations)

    regularized = sigma_11 + jitter * np.eye(sigma_11.shape[0])
    # One factorisation, two right-hand sides: cheaper and numerically more
    # stable than forming the explicit inverse.
    solved = np.linalg.solve(regularized, np.column_stack([y_obs, sigma_21]))
    mu = sigma_21 @ solved[:, 0]
    sigma = sigma_22 - sigma_21 @ solved[:, 1]
    return mu, sigma


def generate_samples(num_samples=5, target_function=None, sigma=1, rng=None):
    """Draw noisy observations on the interval [-5, 5].

    Args:
        num_samples: Number of observations to draw.
        target_function: Optional callable evaluated on ``X`` (shape
            ``(num_samples, 1)``); its flattened output is added to the noise.
        sigma: Standard deviation of the zero-mean Gaussian noise.
        rng: Optional ``numpy.random.Generator`` for reproducible draws. When
            omitted, the global ``np.random`` state is used (original behaviour).

    Returns:
        ``(X, Y)`` with shapes ``(num_samples, 1)`` and ``(num_samples,)``.
    """
    sampler = np.random if rng is None else rng
    X = sampler.uniform(-5, 5, num_samples).reshape(-1, 1)
    Y = sampler.normal(0, sigma, num_samples)
    if target_function is not None:
        # Flatten so a target returning (n, 1) does not break the addition.
        Y = Y + np.asarray(target_function(X)).reshape(-1)
    return X, Y


def plot_posterior(kernel, num_samples=5, X_obs=None, Y_obs=None, ax=None, grid_size=200):
    """Plot the GP posterior mean and 95% band over the range of the observations.

    Args:
        kernel: Wrapped kernel ``k(a, b)`` (see ``kernel_wrapper``).
        num_samples: Number of random observations to generate when no data is given.
        X_obs: Observed inputs, shape ``(num_observations, 1)``; must be given
            together with ``Y_obs``.
        Y_obs: Observed targets, shape ``(num_observations,)``.
        ax: Axes to draw on. A new figure is created (and shown) when omitted.
        grid_size: Number of evenly spaced query points between min and max of ``X_obs``.

    Raises:
        ValueError: If only one of ``X_obs``/``Y_obs`` is given, if neither data
            nor ``num_samples`` is given, or if their lengths differ.
    """
    # Parenthesised on purpose: `a is None != b is None` is a chained comparison
    # and never detects the mismatch.
    if (X_obs is None) != (Y_obs is None):
        raise ValueError('X_obs and Y_obs must be provided together.')
    if X_obs is None and num_samples is None:
        raise ValueError('num_samples must be provided if X_obs is not provided.')
    if X_obs is None:
        X_obs, Y_obs = generate_samples(num_samples)
    X_obs, Y_obs = np.asarray(X_obs), np.asarray(Y_obs)
    if X_obs.shape[0] != Y_obs.shape[0]:
        raise ValueError('The number of observations must be the same.')

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None

    grid = np.linspace(np.min(X_obs), np.max(X_obs), grid_size)
    posterior = [y_posterior(point, X_obs, Y_obs, kernel) for point in grid.reshape(-1, 1)]
    mean = np.array([m for m, _ in posterior], dtype=float)
    variance = np.array([v for _, v in posterior], dtype=float)
    # Round-off can push the variance slightly below zero (e.g. at observed
    # points or with rank-deficient kernels); clip so sqrt does not yield NaN.
    half_width = norm.ppf(0.975) * np.sqrt(np.clip(variance, 0.0, None))

    ax.plot(grid, mean, linewidth=1, label='Function')
    ax.scatter(X_obs, Y_obs, s=20, label='Observations')
    ax.fill_between(grid, mean - half_width, mean + half_width,
                    alpha=0.5, label='95% CI')

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.legend()
    if fig is not None:
        fig.show()


import functools


def kernel_wrapper(*args, **kwargs):
    """Adapt an element-wise kernel into a pairwise kernel ``k(a, b)``.

    Usage:
        ``@kernel_wrapper``                       - kernels with no hyperparameters
        ``@kernel_wrapper(hyperparameter=value)`` - kernels with hyperparameters
        ``kernel_wrapper(hyperparameter=value)(kernel)`` - same, as a plain call

    The raw kernel receives two arrays of shape ``(num_a, num_b, num_hidden)``
    and must reduce the last axis. The wrapped kernel accepts ``a`` and ``b`` of
    shape ``(num_hidden,)`` or ``(num, num_hidden)``; a 1-D input drops its axis
    from the result, so the output is a scalar, a vector or a matrix.

    Raises (from the wrapped kernel):
        ValueError: If an input is not 1-D or 2-D, or the feature sizes differ.
    """
    if not args:
        # Called with hyperparameters only: return a decorator awaiting the kernel.
        return functools.partial(kernel_wrapper, **kwargs)

    _kernel = args[0]
    if kwargs:
        _kernel = functools.partial(_kernel, **kwargs)

    @functools.wraps(_kernel)
    def new_kernel(a, b):
        a, b = np.asarray(a), np.asarray(b)
        if a.ndim not in (1, 2) or b.ndim not in (1, 2):
            raise ValueError('a and b must be 1-D or 2-D arrays.')
        if a.shape[-1] != b.shape[-1]:
            raise ValueError('The last dimension of a and b must be the same.')

        num_a = 1 if a.ndim == 1 else a.shape[0]
        num_b = 1 if b.ndim == 1 else b.shape[0]
        num_hidden = a.shape[-1]
        target_shape = (num_a, num_b, num_hidden)

        # Broadcast to every (row of a, row of b) pair without copying data.
        x_a = np.broadcast_to(a.reshape((num_a, 1, num_hidden)), target_shape)
        x_b = np.broadcast_to(b.reshape((1, num_b, num_hidden)), target_shape)

        result = _kernel(x_a, x_b)
        # Drop the axis that belongs to a 1-D input.
        result_shape = [
            *([] if a.ndim == 1 else [num_a]),
            *([] if b.ndim == 1 else [num_b])
        ]
        return result.reshape(result_shape)

    return new_kernel
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
def linear_kernel(a, b, nu=1):
    """Linear (dot-product) kernel: ``nu * <a, b>`` over the last axis.

    Args:
        a, b: Arrays of shape ``(num_a, num_b, num_hidden)``.
        nu: Scale hyperparameter multiplying the inner product.

    Returns:
        Array of shape ``(num_a, num_b)``.
    """
    inner_product = np.einsum('pqd,pqd->pq', a, b)
    return inner_product * nu

plot_posterior(kernel_wrapper(nu=1)(linear_kernel), grid_size=10)
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
def rbf_kernel(a, b, sigma, l):
    """Squared-exponential kernel: ``sigma^2 * exp(-||a - b||^2 / (2 l^2))``.

    Args:
        a, b: Arrays of shape ``(num_a, num_b, num_hidden)``.
        sigma: Amplitude hyperparameter.
        l: Length-scale hyperparameter.

    Returns:
        Array of shape ``(num_a, num_b)``.
    """
    distance = np.linalg.norm(a - b, axis=-1)
    exponent = -0.5 * distance ** 2 / l ** 2
    return sigma ** 2 * np.exp(exponent)

plot_posterior(kernel_wrapper(sigma=1, l=1)(rbf_kernel), grid_size=80)
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
from scipy.special import gamma, kn, kv

def matern_kernel(a, b, l, nu, sigma=1):
    """Matern kernel.

    ``k = sigma^2 * 2^(1-nu) / Gamma(nu) * z^nu * K_nu(z)`` with
    ``z = sqrt(2 nu) * ||a - b|| / l``, where ``K_nu`` is the modified Bessel
    function of the second kind.

    Args:
        a, b: Arrays of shape ``(num_a, num_b, num_hidden)``.
        l: Length-scale hyperparameter.
        nu: Smoothness hyperparameter; may be non-integer (e.g. 1.5, 2.5).
        sigma: Amplitude hyperparameter (defaults to 1, i.e. unit variance).

    Returns:
        Array of shape ``(num_a, num_b)``.
    """
    z = np.sqrt(2 * nu) * np.linalg.norm(a - b, axis=-1) / l
    # At z -> 0 the expression is 0 * inf although its limit is 1. Evaluate on a
    # harmless placeholder there and patch the limit in afterwards, so no NaN
    # (and no RuntimeWarning) is produced.
    at_origin = z < 1e-6
    z_safe = np.where(at_origin, 1.0, z)
    # kv supports real order; kn is restricted to integer order.
    values = 2 ** (1 - nu) / gamma(nu) * z_safe ** nu * kv(nu, z_safe)
    return sigma ** 2 * np.where(at_origin, 1.0, values)

plot_posterior(kernel_wrapper(l=1, nu=2)(matern_kernel), 4, grid_size=80)
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
def periodic_kernel(a, b, sigma, l, p):
    """Periodic kernel: ``sigma^2 * exp(-2 sin^2(pi ||a - b|| / p) / l^2)``.

    Args:
        a, b: Arrays of shape ``(num_a, num_b, num_hidden)``.
        sigma: Amplitude hyperparameter.
        l: Length-scale hyperparameter.
        p: Period hyperparameter.

    Returns:
        Array of shape ``(num_a, num_b)``.
    """
    separation = np.linalg.norm(a - b, axis=-1)
    squared_sine = np.sin(np.pi * separation / p) ** 2
    return sigma ** 2 * np.exp(-2 * squared_sine / l ** 2)

plot_posterior(kernel_wrapper(sigma=1, l=1, p=2)(periodic_kernel), 4, grid_size=150)
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
def noise_kernel(a, b, sigma):
    """White-noise kernel: ``sigma^2`` where the two points coincide, else 0.

    Args:
        a, b: Arrays whose last axis holds the features; points are compared
            element-wise and must match on every feature to count as equal.
        sigma: Noise standard deviation.

    Returns:
        Array with the last axis reduced away.
    """
    identical = np.all(a == b, axis=-1)
    return sigma ** 2 * identical

plot_posterior(kernel_wrapper(sigma=1)(noise_kernel), 5)
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
def target_function(X):
    """Noise-free target ``f(x) = sin(pi * x) + 0.3 * x^2``.

    Args:
        X: Array of inputs; it is flattened, so ``(n, 1)`` and ``(n,)`` both work.

    Returns:
        1-D array of function values, one per input.
    """
    x = X.reshape(-1)
    # Use the exact constant: the truncated 3.14 drifts out of phase with the
    # documented sin(pi * x) as |x| grows.
    return np.sin(np.pi * x) + 0.3 * x ** 2


def plot_function(func, ax, X_min=-5, X_max=5, grid_size=200, **kwargs):
    """Plot ``func`` on an evenly spaced grid.

    Args:
        func: Callable taking an array of shape ``(grid_size, 1)``.
        ax: Axes to draw on.
        X_min, X_max: Range of the grid.
        grid_size: Number of grid points.
        **kwargs: Forwarded to ``ax.plot``.
    """
    X = np.linspace(X_min, X_max, grid_size)
    Y = func(X.reshape(-1, 1))
    ax.plot(X, Y, **kwargs)

fig, ax = plt.subplots(figsize=(4, 4))
plot_function(target_function, ax, grid_size=100)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Target Function')
fig.show()
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "periodic = kernel_wrapper(sigma=1, l=1, p=2)(periodic_kernel)\n", + "rbf = kernel_wrapper(sigma=1, l=1)(rbf_kernel)\n", + "noise = kernel_wrapper(sigma=0.1)(noise_kernel)\n", + "\n", + "def added(a, b):\n", + " return periodic(a, b) + rbf(a, b) + noise(a, b)\n", + "\n", + "kernels = {\n", + " 'Periodic': periodic,\n", + " 'RBF': rbf,\n", + " 'Hybrid': added\n", + "}\n", + "\n", + "X_obs, Y_obs = generate_samples(10, target_function, sigma=0.5)\n", + "X_min, X_max = np.min(X_obs), np.max(X_obs)\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(10, 3))\n", + "for ax, (name, kernel) in zip(axes, kernels.items()):\n", + " plot_posterior(kernel, X_obs=X_obs, Y_obs=Y_obs, ax=ax, grid_size=150)\n", + " plot_function(target_function, ax, X_min, X_max, color='red', label='Target Function', grid_size=100)\n", + " ax.set_title(f'{name} Kernel')\n", + " ax.get_legend().remove()\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "高斯过程不适用于以下场景:\n", + "\n", + "* 数据量较大时,由于计算$m$维矩阵逆的复杂度为$O(m^3)$,计算复杂度较高。\n", + "* 数据维度较高时,维度灾难会导致核函数容易退化,无法捕获数据之间的相似性。\n", + "* 高斯过程的预测结果是连续的,无法直接处理离散数据。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/coding/machine-learning/index.md b/docs/coding/machine-learning/index.md index 8bfd8ba2..b1ecdff4 100644 --- a/docs/coding/machine-learning/index.md +++ b/docs/coding/machine-learning/index.md @@ -4,4 +4,8 @@ * [线性模型](linear-models.ipynb) * [决策树](decision-tree.ipynb) -* [贝叶斯优化](bayesian-optimization.ipynb) +* 贝叶斯优化 + * 
[高斯过程](gaussian-process.ipynb) + * [算法流程](bayesian-optimization.ipynb) +* 主题模型 + * [VAE-NTM](vae-ntm.ipynb)