In machine learning we frequently encounter objective functions of the form
\begin{equation}L_{\theta}=\mathbb{E}_{z\sim p_{\theta}(z)}[f(z)]\end{equation}To minimize L_{\theta} we need its gradient with respect to \theta:
\begin{equation}\nabla_{\theta}L_{\theta}=\nabla_{\theta}\mathbb{E}_{z\sim p_{\theta}(z)}[f(z)]=\nabla_{\theta} \int p_{\theta}(z) f(z)\, dz\end{equation}Estimating this gradient requires sampling from p_{\theta}(z), but sampling directly discards the gradient information with respect to \theta, because the drawn samples are no longer differentiable functions of \theta. The reparameterization trick uses a function transformation of a random variable (a change of variables) to draw samples without losing the gradient.
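To make the problem concrete, here is a minimal PyTorch sketch (assuming PyTorch is available; the Gaussian family and the toy loss f(z)=z^2 are illustrative choices, not part of the derivation): a plain sample() call returns values with no gradient path back to the parameters, while the reparameterized rsample() keeps that path.

```python
import torch

# theta = (mu, log_sigma) parameterizes a Gaussian p_theta(z)
mu = torch.tensor(0.5, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)
dist = torch.distributions.Normal(mu, log_sigma.exp())

f = lambda z: (z ** 2).mean()   # an arbitrary toy f(z)

# direct sampling: the samples carry no gradient path back to mu / log_sigma
z_plain = dist.sample((1000,))
print(z_plain.requires_grad)    # False -> the gradient w.r.t. theta is lost

# reparameterized sampling: internally z = mu + sigma * eps with eps ~ N(0, 1)
z_rep = dist.rsample((1000,))
loss = f(z_rep)
loss.backward()
print(mu.grad, log_sigma.grad)  # gradients now reach the parameters
```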
Function transformations of random variables
Theorem: Let the random variable X have distribution function F_{X}(x) and density function p_{X}(x), and let Y=g(X), where g(\cdot) is strictly monotone with derivative g^{\prime}(\cdot). Writing h(\cdot) for the inverse of g(\cdot), the density of Y is
\begin{equation}p_{Y}(y)=p_{X}(h(y))\left|h^{\prime}(y)\right|\end{equation}Proof: Since Y=g(X) is strictly monotone (strictly increasing or strictly decreasing), the inverse function X=h(Y) exists, and because g is differentiable, so is h. First suppose g(X) is strictly increasing; then
\begin{equation}\begin{aligned}F_{Y}(y)&=P(Y \leqslant y)=P(g(X) \leqslant y) \\&=P(X \leqslant h(y))=F_{X}(h(y)) \\p_{Y}(y) &=p_{X}(h(y)) \cdot h^{\prime}(y)\end{aligned}\end{equation}If g(X) is strictly decreasing, the event g(X) \leqslant y is equivalent to X \geqslant h(y), so in that case we have
\begin{equation}\begin{aligned}F_{Y}(y) &=P(Y \leqslant y)=P(g(X) \leqslant y) \\&=P(X \geqslant h(y))=1-F_{X}(h(y)) \\p_{Y}(y) &=-p_{X}(h(y)) \cdot h^{\prime}(y)\end{aligned}\end{equation}When g is strictly decreasing its inverse h is strictly decreasing as well, so h^{\prime}(y)<0 and p_{Y}(y) remains non-negative. Combining the two cases gives the stated theorem.
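The theorem is easy to check numerically. The sketch below is an illustrative example (the choice X \sim N(0,1) with the strictly increasing map g(x)=e^{x}, so that h(y)=\ln y and h^{\prime}(y)=1/y, is arbitrary): it compares a Monte Carlo estimate of P(a<Y<b) against the integral of the density given by the formula.

```python
import numpy as np

# X ~ N(0, 1), Y = g(X) = exp(X) is strictly increasing, with h(y) = ln(y), h'(y) = 1/y
x = np.random.randn(1_000_000)
y = np.exp(x)

p_X = lambda t: np.exp(-t ** 2 / 2) / np.sqrt(2 * np.pi)
p_Y = lambda t: p_X(np.log(t)) * np.abs(1.0 / t)   # p_Y(y) = p_X(h(y)) |h'(y)|

# Monte Carlo estimate of P(a < Y < b) versus the integral of the density formula
a, b = 1.0, 2.0
mc_estimate = np.mean((y > a) & (y < b))
grid = np.linspace(a, b, 10_000)
integral = np.trapz(p_Y(grid), grid)
print(mc_estimate, integral)   # the two values agree up to Monte Carlo error
```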
Now let \epsilon be a random variable with a fixed density p(\epsilon) that does not depend on \theta, and suppose z=g(\epsilon;\theta), where g is strictly monotone and differentiable in \epsilon. By the theorem above, p_{\theta}(z)=p(\epsilon)\left|d\epsilon/dz\right|. Substituting into the objective gives
\begin{equation}\begin{aligned}\mathbb{E}_{z\sim p_{\theta}(z)}[f(z)]&=\int p_\theta(z)f(z)\,dz\\&=\int p(\epsilon)f(g(\epsilon;\theta))\,d\epsilon\\&=\mathbb{E}_{\epsilon \sim p(\epsilon)}[f(g(\epsilon;\theta))]\end{aligned}\end{equation}Optimizing the equivalent objective \mathbb{E}_{\epsilon \sim p(\epsilon)}[f(g(\epsilon;\theta))] achieves sampling without losing the gradient: the expectation is now taken over a distribution that does not depend on \theta, and \theta enters only through the deterministic, differentiable function g, so we can sample \epsilon and differentiate f(g(\epsilon;\theta)) with respect to \theta directly.
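As a sanity check that the reparameterized objective yields the correct gradients, here is a small numpy sketch (the Gaussian family and f(z)=z^2 are arbitrary illustrative choices). For z \sim N(\mu,\sigma^2) we have \mathbb{E}[z^2]=\mu^2+\sigma^2, so \partial L/\partial\mu=2\mu and \partial L/\partial\sigma=2\sigma; averaging the gradients of f(g(\epsilon;\theta)) over samples of \epsilon should reproduce these values.

```python
import numpy as np

mu, sigma = 1.5, 0.7                 # theta = (mu, sigma)
f = lambda z: z ** 2                 # E[f(z)] = mu^2 + sigma^2

eps = np.random.randn(1_000_000)     # eps ~ N(0, 1), independent of theta
z = mu + sigma * eps                 # z = g(eps; theta)

# gradients of f(g(eps; theta)) w.r.t. theta, averaged over eps
grad_mu = np.mean(2 * z)             # df/dmu    = 2 z * dz/dmu    = 2 z
grad_sigma = np.mean(2 * z * eps)    # df/dsigma = 2 z * dz/dsigma = 2 z eps

print(grad_mu, 2 * mu)               # both ~ 3.0
print(grad_sigma, 2 * sigma)         # both ~ 1.4
```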
Applications
Reparameterizing sampling from a normal distribution:
When z follows a normal distribution, z \sim N(\mu,\sigma^2), let \epsilon \sim N(0,1); then z=g(\epsilon)=\epsilon \cdot \sigma+\mu.
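A quick numpy check of this transformation (the values of \mu and \sigma below are arbitrary): samples built as \epsilon\cdot\sigma+\mu from standard normal \epsilon have the same mean and standard deviation as direct draws from N(\mu,\sigma^2).

```python
import numpy as np

mu, sigma = 2.0, 3.0
eps = np.random.randn(1_000_000)            # eps ~ N(0, 1)

z_reparam = eps * sigma + mu                # z = g(eps) = eps * sigma + mu
z_direct = np.random.normal(mu, sigma, 1_000_000)

print(z_reparam.mean(), z_reparam.std())    # ~ 2.0, ~ 3.0
print(z_direct.mean(), z_direct.std())      # ~ 2.0, ~ 3.0
```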
Reparameterizing sampling from a categorical distribution:
When z follows a categorical distribution, z \sim Categorical(p), where p is a probability vector, let \epsilon \sim U(0,1) be a vector of independent uniforms (one per category); then z=g(\epsilon)=\mathrm{argmax}(\log(p)-\log(-\log(\epsilon))). The term -\log(-\log(\epsilon)) is a sample from the standard Gumbel distribution, which is why this construction is called the Gumbel-Max trick. Experimental code:
```python
import numpy as np
from collections import Counter

def softmax(x):
    x = np.exp(x)
    return x / x.sum()

logits = np.array([0.2, 0.6, 0.8, 0.35])
probs = softmax(logits)

def gumbel_max_sample_with_logits(logits):
    # argmax(logits + Gumbel noise) samples proportionally to softmax(logits)
    epsilon = np.random.uniform(0, 1, len(logits))
    idx = np.argmax(logits - np.log(-np.log(epsilon)))
    return idx

def gumbel_max_sample_with_probs(probs):
    # same trick expressed with log-probabilities instead of logits
    epsilon = np.random.uniform(0, 1, len(probs))
    idx = np.argmax(np.log(probs) - np.log(-np.log(epsilon)))
    return idx

def random_choice_with_probs(probs):
    # baseline: sample directly from the categorical distribution
    idx = np.random.choice(list(range(len(probs))), p=probs)
    return idx

sample_data = {'gumbel_max_sample_with_logits': [],
               'gumbel_max_sample_with_probs': [],
               'random_choice_with_probs': []}
sample_num = 100000
for _ in range(sample_num):
    idx1 = gumbel_max_sample_with_logits(logits)
    idx2 = gumbel_max_sample_with_probs(probs)
    idx3 = random_choice_with_probs(probs)
    sample_data['gumbel_max_sample_with_logits'].append(idx1)
    sample_data['gumbel_max_sample_with_probs'].append(idx2)
    sample_data['random_choice_with_probs'].append(idx3)

def get_rate(data):
    # empirical frequency of each category over sample_num draws
    rate = []
    cnt = Counter(data)
    for idx in range(len(probs)):
        rate.append(cnt[idx] / sample_num)
    return rate

print('origin distribution:\t\t', probs)
for method in sample_data:
    data = sample_data[method]
    rate = get_rate(data)
    print(f'{method}:\t', rate)
```
Output:

```
origin distribution:            [0.18262246 0.2724407  0.33275982 0.21217703]
gumbel_max_sample_with_logits:  [0.18119, 0.27366, 0.33384, 0.21131]
gumbel_max_sample_with_probs:   [0.18087, 0.27535, 0.33311, 0.21067]
random_choice_with_probs:       [0.18201, 0.2719, 0.33411, 0.21198]
```