Erlo

变分信息瓶颈 (Variational Information Bottleneck) 公式推导

2024-08-21 01:29:41 发布   177 浏览  
页面报错/反馈
收藏 点赞

互信息

互信息用于表示两个随机变量相互依赖的程度。随机变量 (X)(Y) 的互信息定义为

[begin{aligned} I(X, Y) & = mathrm{KL}[p(boldsymbol{x}, boldsymbol{y}) parallel p(boldsymbol{x})p(boldsymbol{y})] \ & = mathbb{E}_{(boldsymbol{x}, boldsymbol{y}) sim p(boldsymbol{x}, boldsymbol{y})} left[logfrac{p(boldsymbol{x}, boldsymbol{y})}{p(boldsymbol{x})p(boldsymbol{y})}right], end{aligned} ]

其中 (p(boldsymbol{x}, boldsymbol{y})) 表示 (X)(Y) 的联合概率密度,(p(boldsymbol{x}))(p(boldsymbol{y})) 分别表示 (X)(Y) 的边缘概率密度。

互信息是一个非负的量,当且仅当 (X)(Y) 相互独立时(此时 (p(boldsymbol{x}, boldsymbol{y}) = p(boldsymbol{x})p(boldsymbol{y})) 恒成立)取到最小值 (0)

在机器学习中,联合分布 (p(boldsymbol{x}, boldsymbol{y})) 通常是难以得到的,因此通常会用贝叶斯公式转换一下,使用以下两种形式的互信息:

[begin{aligned} I(X, Y) & = mathbb{E}_{(boldsymbol{x}, boldsymbol{y}) sim p(boldsymbol{x}, boldsymbol{y})} left[logfrac{p(boldsymbol{x}, boldsymbol{y})}{p(boldsymbol{x})p(boldsymbol{y})}right] \ & = mathbb{E}_{(boldsymbol{x}, boldsymbol{y}) sim p(boldsymbol{x}, boldsymbol{y})} left[logfrac{p(boldsymbol{x}|boldsymbol{y})}{p(boldsymbol{x})}right] \ & = mathbb{E}_{(boldsymbol{x}, boldsymbol{y}) sim p(boldsymbol{x}, boldsymbol{y})} left[logfrac{p(boldsymbol{y}|boldsymbol{x})}{p(boldsymbol{y})}right]. end{aligned} ]

信息瓶颈

令随机变量 (X) 表示输入数据,(Z) 表示编码后的特征,(Y) 表示标签。信息瓶颈 (Information Bottleneck) 理论认为,神经网络的优化存在两阶段性:

  1. 快速拟合阶段:增加 (I(Z, X))
  2. 压缩阶段:减少 (I(Z, X)) 并增加 (I(Z, Y))

information-bottleneck-trajectory

上面这幅插图可视化了神经网络训练过程中互信息的变化轨迹,横轴表示特征与输入的互信息 (I(Z, X)),纵轴表示特征与标签的互信息 (I(Z, Y))(图中用 (T) 表示特征),从紫色到黄色表示从 0 epoch 到 10000 epoch。从图中可见,随着训练的进行,(I(Z, X)) 有一个先增大再减小的过程。

插图出自论文 [1703.00810] Opening the Black Box of Deep Neural Networks via Information。参考阅读:Anatomize Deep Learning with Information Theory | Lil'Log

那么能不能利用这个现象对神经网络的训练进行正则化呢,于是有人提出了变分信息瓶颈 (Variational Information Bottleneck, VIB) 方法,优化的目标为:

[max_{boldsymbol{boldsymbol{theta}}} I(Z, Y; boldsymbol{theta}) - beta I(Z, X; boldsymbol{theta}). ]

我们希望 (Z) 能尽量准确地预测 (Y),同时尽量地遗忘 (X) 中的信息。换句话说,我们希望 (Z) 遗忘 (X) 中的冗余信息,只保留那些对预测 (Y) 有用的信息。这里的最小化 (I(Z, X; boldsymbol{theta})) 起到了正则化的效果

遗憾的是,从高维数据中直接估计互信息是很困难的,变分信息瓶颈的解决思路是通过变分近似实现对互信息的估计。

最小化 I(Z, X)

使用如下形式的互信息 (I(Z, X))

[I(Z, X; boldsymbol{theta}) = mathbb{E}_{(boldsymbol{x}, boldsymbol{z}) sim p(boldsymbol{x}, boldsymbol{z})}left[logfrac{p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})}{p(boldsymbol{z})}right] \ ]

注意到这里需要 (p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})),一种比较方便的处理方法是像 VAE 那样使用概率编码器 (probabilistic encoder),而不是传统的确定性编码器 (deterministic encoder),即 (X mapsto Z) 是一个随机函数而不是传统的确定性函数。参考 VAE 中的做法,我们将 (p(boldsymbol{z}|boldsymbol{x})) 预定义为参数化的高斯分布,并用神经网络输出这个高斯分布的参数:

[p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta}) := N(boldsymbol{z}; boldsymbol{mu}(boldsymbol{x}; boldsymbol{theta}), boldsymbol{sigma}^2(boldsymbol{x}; boldsymbol{theta})boldsymbol{I}). ]

解决了 (p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})),接下来的问题是如何求解 (p(boldsymbol{z}))。可能会想到采样估计的办法,即蒙特卡洛 (Monte Carlo, MC) 估计:

[begin{aligned} p(boldsymbol{z}) & = int_{boldsymbol{x}} p(boldsymbol{x})p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})mathrm{d}boldsymbol{x} \ & = mathbb{E}_{boldsymbol{x} sim p(boldsymbol{x})}[p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})] \ & approx frac{1}{N}sum_{i=1}^N p(boldsymbol{z}|boldsymbol{x}_i; boldsymbol{theta}), quad boldsymbol{x}_i sim p(boldsymbol{x}). end{aligned} ]

但是论文作者并没有使用这种方法,可能是认为在这里用 MC 估计的方差太大了,需要大量采样才能估得准,效率太低。为了估计期望 (mathbb{E}_{(boldsymbol{x}, boldsymbol{z}) sim p(boldsymbol{x}, boldsymbol{z})}left[logfrac{p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})}{p(boldsymbol{z})}right]),就先要从 (p(boldsymbol{x})) 中采样 (boldsymbol{x}),然后从 (p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})) 中采样 (boldsymbol{z})。更麻烦的是方括号内的函数值也无法直接解析求解,需要先采样估计出 (p(boldsymbol{z})) 才能计算。采样估计的过程太多,估计的方差自然会大。

变分信息瓶颈,顾名思义,就是通过变分近似的方法来解决无法获得 (p(boldsymbol{z})) 的问题。假如有一个形式已知的无参分布 (q(boldsymbol{z})),它跟 (p(boldsymbol{z})) 非常接近,那我们用这个 (q(boldsymbol{z})) 替换掉公式里的 (p(boldsymbol{z})),不就能近似地计算互信息 (I(Z, X)) 吗?这里不妨将 (q(boldsymbol{z})) 定义为标准高斯分布,即 (q(boldsymbol{z}) := N(boldsymbol{z}, boldsymbol{0}, boldsymbol{I}))

接下来需要证明这种替换是有道理的,参考 VAE 中推导的经验,我们尝试用 (q(boldsymbol{z})) 替换 (p(boldsymbol{z})),并尝试把额外的部分凑出一个 KL:

[begin{aligned} I(Z, X; boldsymbol{theta}) & = mathbb{E}_{(boldsymbol{x}, boldsymbol{z}) sim p(boldsymbol{x}, boldsymbol{z})}left[logfrac{p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})}{p(boldsymbol{z})}right] \ & = mathbb{E}_{(boldsymbol{x}, boldsymbol{z}) sim p(boldsymbol{x}, boldsymbol{z})}left[logfrac{p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})}{q(boldsymbol{z})}frac{q(boldsymbol{z})}{p(boldsymbol{z})}right] \ & = mathbb{E}_{(boldsymbol{x}, boldsymbol{z}) sim p(boldsymbol{x}, boldsymbol{z})}left[logfrac{p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})}{q(boldsymbol{z})}right] + mathbb{E}_{(boldsymbol{x}, boldsymbol{z}) sim p(boldsymbol{x}, boldsymbol{z})}left[logfrac{q(boldsymbol{z})}{p(boldsymbol{z})}right] end{aligned} ]

对于第一项,(p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta}))(q(boldsymbol{z})) 都有解析式,因此方括号内的函数可以算出解析解。利用 (p(boldsymbol{x}, boldsymbol{z}) = p(boldsymbol{x})p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})),可以把第一项写得好看些:

[begin{aligned} mathbb{E}_{(boldsymbol{x}, boldsymbol{z}) sim p(boldsymbol{x}, boldsymbol{z})}left[logfrac{p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})}{q(boldsymbol{z})}right] & = iint p(boldsymbol{x})p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})logfrac{p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})}{q(boldsymbol{z})} mathrm{d}boldsymbol{z}mathrm{d}boldsymbol{x} \ & = int_x p(boldsymbol{x}) int_z p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})logfrac{p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})}{q(boldsymbol{z})} mathrm{d}boldsymbol{z}mathrm{d}boldsymbol{x} \ & = mathbb{E}_{boldsymbol{x} sim p(boldsymbol{x})}[mathrm{KL}[p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta}) parallel q(boldsymbol{z})]] overset{text{def}}{=} R(Z, X; boldsymbol{theta}) \ & approx frac{1}{N} cdot mathrm{KL}[p(boldsymbol{z}|x_i; boldsymbol{theta}) parallel q(boldsymbol{z})], quad x_i sim p(boldsymbol{x}). end{aligned} ]

这个 (R(Z, X; boldsymbol{theta}) := mathbb{E}_{boldsymbol{x} sim p(boldsymbol{x})}[mathrm{KL}[p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta}) parallel q(boldsymbol{z})]]) 常常被称为 rate,也就是率失真理论里的率。Rate 这一项是可以用 mini-batch 梯度下降来优化的,具体来说,从训练集中采样一批样本 (boldsymbol{x}_1, ldots, boldsymbol{x}_N),最小化每个 (boldsymbol{x}_i)(mathrm{KL}[p(boldsymbol{z}|boldsymbol{x}_i; boldsymbol{theta}) parallel q(boldsymbol{z})]) 即可。由于两个分布都是高斯分布,因此这里的 KL 有解析解:

[begin{aligned} & mathrm{KL}[p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta}) parallel q(boldsymbol{z})] \ & = mathrm{KL}[N(boldsymbol{mu}(boldsymbol{x}), boldsymbol{sigma}^2(boldsymbol{x})boldsymbol{I}), N(boldsymbol{0}, boldsymbol{I})] \ & = sum_{j=1}^J mathrm{KL}[N(mu_j, sigma^2_j) parallel N(0, 1)] \ & = sum_{j=1}^J frac{1}{2}(-logsigma^2_j - 1 + mu^2_j + sigma^2_j). end{aligned} ]

详细的推导过程可参考从极大似然估计到变分自编码器 - VAE 公式推导中“KL 散度的解析解”这一节。相比原来的形式,“写得好看”之后的好处在于:函数对 (boldsymbol{z}) 的积分可以解析地求解,这样一来,用 MC 估计 (R(Z, X; boldsymbol{theta})) 时,只需要从 (p(boldsymbol{x})) 中采样 (boldsymbol{x}),无需再从 (p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})) 中采样 (boldsymbol{z}),减少了采样带来的误差。

对于第二项,注意到期望方括号中的函数跟 (boldsymbol{x}) 没关系,因此:

[begin{aligned} mathbb{E}_{(boldsymbol{x}, boldsymbol{z}) sim p(boldsymbol{x}, boldsymbol{z})}left[logfrac{q(boldsymbol{z})}{p(boldsymbol{z})}right] & = mathbb{E}_{boldsymbol{z} sim p(boldsymbol{z})}left[logfrac{q(boldsymbol{z})}{p(boldsymbol{z})}right] \ & = -mathbb{E}_{boldsymbol{z} sim p(boldsymbol{z})}left[logfrac{p(boldsymbol{z})}{q(boldsymbol{z})}right] \ & = -mathrm{KL}[p(boldsymbol{z}) parallel q(boldsymbol{z})], end{aligned} ]

如果要详细证明一下的话就是:

[begin{aligned} mathbb{E}_{(boldsymbol{x}, boldsymbol{z}) sim p(boldsymbol{x}, boldsymbol{z})}left[logfrac{q(boldsymbol{z})}{p(boldsymbol{z})}right] & = iint p(boldsymbol{x}, boldsymbol{z})logfrac{q(boldsymbol{z})}{p(boldsymbol{z})} mathrm{d}boldsymbol{z}mathrm{d}boldsymbol{x} \ & = int_{boldsymbol{z}}logfrac{q(boldsymbol{z})}{p(boldsymbol{z})}left(int_{boldsymbol{x}} p(boldsymbol{z}, boldsymbol{x})mathrm{d}boldsymbol{x}right)mathrm{d}boldsymbol{z} \ & = int_{boldsymbol{z}}logfrac{q(boldsymbol{z})}{p(boldsymbol{z})}p(boldsymbol{z})mathrm{d}boldsymbol{z} \ & = mathbb{E}_{boldsymbol{z} sim p(boldsymbol{z})}left[logfrac{q(boldsymbol{z})}{p(boldsymbol{z})}right] = -mathrm{KL}[p(boldsymbol{z}) parallel q(boldsymbol{z})]. end{aligned} ]

因此这一项就是要凑的那个 KL 散度。由于得不到 (p(boldsymbol{z})) 的解析式,KL 散度这一项无法被直接优化,它放在这里只是为了证明变分近似的合理性,详见下文。

综上所述,互信息 (I(Z, X)) 可以拆成两部分:

[begin{aligned} I(Z, X; boldsymbol{theta}) & = mathbb{E}_{boldsymbol{x} sim p(boldsymbol{x})}[mathrm{KL}[p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta}) parallel q(boldsymbol{z})]] - mathrm{KL}[p(boldsymbol{z}) parallel q(boldsymbol{z})] \ & = R(Z, X; boldsymbol{theta}) - mathrm{KL}[p(boldsymbol{z}) parallel q(boldsymbol{z})]. end{aligned} ]

由 KL 散度的非负性可知,rate (R) 是互信息 (I(Z, X; boldsymbol{theta})) 的上界:

[R(Z, X; boldsymbol{theta}) = I(Z, X; boldsymbol{theta}) + mathrm{KL}[p(boldsymbol{z}) parallel q(boldsymbol{z})] geq I(Z, X; boldsymbol{theta}), ]

这正合我们意愿,因为我们想要最小化互信息 (I(Z, X; boldsymbol{theta})),所以我们可以通过最小化它的上界 (R(Z, X; boldsymbol{theta})) 来间接地实现互信息的最小化,实现“曲线救国”。

最大化 I(Z, Y)

使用如下形式的互信息 (I(Z, X))

[I(Z, Y; boldsymbol{theta}) = mathbb{E}_{(boldsymbol{y}, boldsymbol{z}) sim p(boldsymbol{y}, boldsymbol{z})}left[logfrac{p(boldsymbol{y}|boldsymbol{z}; boldsymbol{theta})}{p(boldsymbol{y})}right] \ ]

标签的分布 (p(boldsymbol{y})) 可能是无法知道的:如果 (boldsymbol{y}) 是类别标签,那么离散型分布 (p(boldsymbol{y})) 是比较容易求的;但如果 (boldsymbol{y}) 是数值,连续型分布 (p(boldsymbol{y})) 是比较难求的。不过难求的 (p(boldsymbol{y})) 并不影响优化过程,因为

[mathbb{E}_{(boldsymbol{y}, boldsymbol{z}) sim p(boldsymbol{y}, boldsymbol{z})}[-log p(boldsymbol{y})] = -mathbb{E}_{boldsymbol{y} sim p(boldsymbol{y})}[log p(boldsymbol{y})] overset{text{def}}{=} mathrm{H}(Y), ]

其中 (mathrm{H}(Y)) 表示随机变量 (Y) 的信息熵 (entropy)。由于标签 (Y) 来自于数据集,不属于优化变量,因此 (mathrm{H}(Y)) 是一个定值,不影响优化过程。

接下来要解决的是 (p(boldsymbol{y}|boldsymbol{z}; boldsymbol{theta})) 难求的问题。这里需要与前一节 (p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})) 的情况相区分,(X) 是数据集中的数据,(Z) 是可优化的特征,因此对于 (X mapsto Z) 这个过程,我们可以任意指定 (p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})) 的形式,(p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})) 不是难求的。而 (Y) 是数据集中的数据,对于 (Z mapsto Y) 这个过程,(p(boldsymbol{y}|boldsymbol{z}; boldsymbol{theta})) 的形式是客观上确定的,我们不能随意指定,(p(boldsymbol{y}|boldsymbol{z})) 是难求的。

可以用一个形式已知的分布 (q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})) 来近似 (p(boldsymbol{y}|boldsymbol{z}; boldsymbol{theta}))

[begin{aligned} I(Z, Y; boldsymbol{theta}) & = mathbb{E}_{(boldsymbol{y}, boldsymbol{z}) sim p(boldsymbol{y}, boldsymbol{z})}[log p(boldsymbol{y}|boldsymbol{z}; boldsymbol{theta})] + mathrm{H}(Y) \ & = mathbb{E}_{(boldsymbol{y}, boldsymbol{z}) sim p(boldsymbol{y}, boldsymbol{z})}[log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})] + mathbb{E}_{(boldsymbol{y}, boldsymbol{z}) sim p(boldsymbol{y}, boldsymbol{z})}left[log frac{p(boldsymbol{y}|boldsymbol{z}; boldsymbol{theta})}{q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})}right] + mathrm{H}(Y) \ & = mathbb{E}_{(boldsymbol{y}, boldsymbol{z}) sim p(boldsymbol{y}, boldsymbol{z})}[log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})] + iint p(boldsymbol{z})p(boldsymbol{y}|boldsymbol{z}; boldsymbol{theta})logfrac{p(boldsymbol{y}|boldsymbol{z}; boldsymbol{theta})}{q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})} mathrm{d}boldsymbol{y}mathrm{d}boldsymbol{z} + mathrm{H}(Y) \ & = mathbb{E}_{(boldsymbol{y}, boldsymbol{z}) sim p(boldsymbol{y}, boldsymbol{z})}[log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})] + mathbb{E}_{boldsymbol{z} sim p(boldsymbol{z})}[mathrm{KL}[p(boldsymbol{y}|boldsymbol{z}; boldsymbol{theta}) parallel q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})]] + mathrm{H}(Y) \ & geq mathbb{E}_{(boldsymbol{y}, boldsymbol{z}) sim p(boldsymbol{y}, boldsymbol{z})}[log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})] + mathrm{H}(Y) overset{text{def}}{=}I_{text{BA}}. end{aligned} ]

利用 KL 散度的非负性,可以得到互信息 (I(Z, Y; boldsymbol{theta})) 的一个下界 (I_{text{BA}}),它被称为互信息的 Barber & Agakov 下界。

(p(boldsymbol{y}, boldsymbol{z}) = int_x p(boldsymbol{x}, boldsymbol{y}, boldsymbol{z}) mathrm{d}boldsymbol{x}) 可得

[begin{aligned} mathbb{E}_{(boldsymbol{y}, boldsymbol{z}) sim p(boldsymbol{y}, boldsymbol{z})}[log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})] & = iint left(int_x p(boldsymbol{x}, boldsymbol{y}, boldsymbol{z}) mathrm{d}boldsymbol{x}right) log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})mathrm{d}boldsymbol{y}mathrm{d}boldsymbol{z} \ & = iiint p(boldsymbol{x}, boldsymbol{y})p(boldsymbol{z}|boldsymbol{x}, boldsymbol{y}; boldsymbol{theta})log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi}) mathrm{d}boldsymbol{x}mathrm{d}boldsymbol{y}mathrm{d}boldsymbol{z} \ & = iiint p(boldsymbol{x}, boldsymbol{y})p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi}) mathrm{d}boldsymbol{x}mathrm{d}boldsymbol{y}mathrm{d}boldsymbol{z} \ & = mathbb{E}_{(boldsymbol{x}, boldsymbol{y}) sim p(boldsymbol{x}, boldsymbol{y})}[mathbb{E}_{boldsymbol{z} sim p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})}[log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})]] \ & approx frac{1}{NM}sum_{i=1}^Nsum_{j=1}^M log q(boldsymbol{y}_i|boldsymbol{z}_j; boldsymbol{theta}), quad (x_i, boldsymbol{y}_i) sim p(boldsymbol{x}, boldsymbol{y}), boldsymbol{z}_j sim p(boldsymbol{z}|x_i; boldsymbol{theta}). end{aligned} ]

(Y) 是连续型数据(回归问题),则选择高斯分布模型作为近似分布 (q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})),最大化 (log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})) 对应最小化 MSE 损失。若 (Y) 是离散型数据(分类问题),则选择伯努利分布(二分类模型)或类别分布(多分类模型)模型作为近似分布 (q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})),最大化 (log q(boldsymbol{y}|boldsymbol{z}; boldsymbol{phi})) 对应最小化交叉熵损失。详细的推导过程可参考从极大似然估计到变分自编码器 - VAE 公式推导中“重构损失”这一节。

(N) 的意思是从数据集中采样 (N) 个训练数据 ((boldsymbol{x}_1, boldsymbol{y}_1), ldots, (boldsymbol{x}_N, boldsymbol{y}_N))(M) 的意思是对于每个样本 (boldsymbol{x}_i),从分布 (p(boldsymbol{z}|boldsymbol{x}_i; boldsymbol{theta})) 中采样 (M) 个特征 (boldsymbol{z}) 来计算 (M) 次 MSE/交叉熵损失。

一些理解

总的来说,最大化 (I(Z, Y)) 对应最小化交叉熵损失,最小化 (I(Z, X)) 对应最小化 KL 散度正则项(即 rate (R))。

变分信息瓶颈与普通判别模型的区别:

  1. 将普通判别模型中的确定性编码器 (deterministic encoder)改成了概率编码器 (probabilistic encoder),给定 (boldsymbol{x}),普通判别模型会给出唯一的 (boldsymbol{z}),而 VIB 的 (boldsymbol{z}) 是从某个分布中采样得到的,是一个随机变量。
  2. 加入了一个 KL 散度正则项(即 rate (R)),希望特征的后验分布 (p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})) 尽量接近标准高斯分布。

从这两点改进来看,变分信息瓶颈与 VAE 非常相似。

为什么最小化 KL 散度能作为正则项?为什么鼓励接近标准高斯分布是一种正则化效果?如果 KL 正则项为 0,则 (p(boldsymbol{z}|boldsymbol{x}; boldsymbol{theta})) 完全就是标准高斯分布,不包含任何关于样本 (boldsymbol{x}) 的信息,即完全遗忘了 (boldsymbol{x}) 的信息。当然了,这样的特征是不具备任何判别能力的,所以需要通过调节权重系数 (beta) 以在遗忘和预测能力之间取得平衡。

此外,注意到

[R(Z, X; boldsymbol{theta}) = I(Z, X; boldsymbol{theta}) + mathrm{KL}[p(boldsymbol{z}) parallel N(boldsymbol{0}, boldsymbol{I})], ]

因此在最小化正则项 (R(Z, X; boldsymbol{theta})) 时,不仅是在最小化互信息 (I(Z, X; boldsymbol{theta})),而且在最小化 (mathrm{KL}[p(boldsymbol{z}) parallel N(boldsymbol{0}, boldsymbol{I})]),使得特征 (Z) 的分布 (p(boldsymbol{z})) 逐渐趋近于标准高斯分布。标准高斯分布有很多优良的性质,例如,它的各个维度是相互独立的,这就是在鼓励特征 (Z) 的各维度解耦。

参考资料

论文原文:Deep Variational Information Bottleneck

从变分编码、信息瓶颈到正态分布:论遗忘的重要性 - 科学空间

变分信息瓶颈(Variational Information Bottleneck) - Sphinx Garden

迁移学习:互信息的变分上下界 - orion-orion - 博客园; 迁移学习:互信息的变分上下界 - 猎户座的文章 - 知乎

登录查看全部

参与评论

评论留言

还没有评论留言,赶紧来抢楼吧~~

手机查看

返回顶部

给这篇文章打个标签吧~

棒极了 糟糕透顶 好文章 PHP JAVA JS 小程序 Python SEO MySql 确认