In machine learning, we often encounter objective functions of the form:

\begin{equation}L_{\theta}=\mathbb{E}_{z\sim p_{\theta}(z)}[f(z)]\end{equation}

To minimize L_\theta, we need its gradient with respect to \theta:

\begin{equation}\nabla_{\theta}L_{\theta}=\nabla_{\theta}\mathbb{E}_{z\sim p_{\theta}(z)}[f(z)]=\nabla_{\theta} \int p_{\theta}(z) f(z) dz\end{equation}

Estimating this gradient requires sampling from p_{\theta}; if we sample directly, the samples carry no gradient information with respect to \theta. The reparameterization trick uses a change of variables for random variables so that the sampling can be performed without losing the gradient.

Change of variables for random variables

Theorem: Let the random variable X have distribution function F_{X}(x) and density function p_{X}(x), and let Y=g(X), where g(\cdot) is strictly monotonic with derivative g^{\prime}(\cdot). Writing h for the inverse function of g, the density of Y is

\begin{equation}p_{Y}(y)=p_{X}(h(y))\left|h^{\prime}(y)\right|\end{equation}

Proof: Since Y=g(X) with g strictly monotonic (strictly increasing or strictly decreasing), the inverse function X=h(Y) exists; and since g is differentiable, so is h. To be concrete, first suppose g(X) is strictly increasing in X. Then

\begin{equation}\begin{aligned}F_{Y}(y)&=P(Y \leqslant y)=P(g(X) \leqslant y) \\&=P(X \leqslant h(y))=F_{X}(h(y)) \\p_{Y}(y) &=p_{X}(h(y)) \cdot h^{\prime}(y)\end{aligned}\end{equation}

If g(X) is strictly decreasing, the event "g(X) \leqslant y" is equivalent to "X \geqslant h(y)", so in the strictly decreasing case we have

\begin{equation}\begin{aligned}F_{Y}(y) &=P(Y \leqslant y)=P(g(X) \leqslant y) \\&=P(X \geqslant h(y))=1-F_{X}(h(y)) \\p_{Y}(y) &=-p_{X}(h(y)) \cdot h^{\prime}(y)\end{aligned}\end{equation}

When g is strictly decreasing, its inverse h is also strictly decreasing, so h^{\prime}(y)<0 and p_{Y}(y) remains non-negative. Combining the two cases gives the theorem stated above.
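As a concrete illustration (this worked example is added here for clarity and is not part of the original proof): take X \sim N(0,1) with density \varphi(x), and Y=\sigma X+\mu with \sigma>0. Then h(y)=(y-\mu)/\sigma and |h^{\prime}(y)|=1/\sigma, so

\begin{equation}p_{Y}(y)=\varphi\!\left(\frac{y-\mu}{\sigma}\right)\cdot\frac{1}{\sigma}=\frac{1}{\sqrt{2\pi}\,\sigma}\exp\!\left(-\frac{(y-\mu)^{2}}{2\sigma^{2}}\right)\end{equation}

which is exactly the N(\mu,\sigma^{2}) density; this is the transform used in the normal-distribution application below.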

Now let the random variable \epsilon satisfy z=g(\epsilon;\theta). By the theorem above, p_{\theta}(z)=p(\epsilon)\cdot|d\epsilon/dz|. Substituting into the objective gives

\begin{equation}\begin{aligned}\mathbb{E}_{z\sim p_{\theta}(z)}[f(z)]&=\int p_\theta(z)f(z)dz\\&=\int p(\epsilon)f(g(\epsilon;\theta))d\epsilon\\&=\mathbb{E}_{\epsilon \sim p(\epsilon)}[f(g(\epsilon;\theta))]\end{aligned}\end{equation}

Optimizing the objective \mathbb{E}_{\epsilon \sim p(\epsilon)}[f(g(\epsilon;\theta))] instead, the sampling is now over \epsilon, whose distribution does not depend on \theta, so the gradient can flow through g(\epsilon;\theta) and sampling no longer destroys it.
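A minimal NumPy sketch of this idea (the choices f(z)=z^{2} and z \sim N(\mu,\sigma^{2}) are picked here only for illustration; for this case \mathbb{E}[z^{2}]=\mu^{2}+\sigma^{2}, so the exact gradients are 2\mu and 2\sigma): with z=\mu+\sigma\epsilon, the gradient estimate is simply the sample average of the gradient of f(\mu+\sigma\epsilon).

import numpy as np

np.random.seed(0)
mu, sigma = 1.5, 0.8
eps = np.random.randn(100000)      # epsilon ~ N(0, 1), independent of the parameters
z = mu + sigma * eps               # reparameterized samples, z ~ N(mu, sigma^2)

# f(z) = z^2, so df/dmu = 2*z * dz/dmu = 2*z and df/dsigma = 2*z * dz/dsigma = 2*z*eps
grad_mu = np.mean(2 * z)           # Monte Carlo estimate of d/dmu E[z^2]
grad_sigma = np.mean(2 * z * eps)  # Monte Carlo estimate of d/dsigma E[z^2]

print(grad_mu, 2 * mu)             # estimate should be close to the exact value 2*mu
print(grad_sigma, 2 * sigma)       # estimate should be close to the exact value 2*sigma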

Applications

Reparameterizing sampling from a normal distribution:

When z follows a normal distribution, z \sim N(\mu,\sigma^2), let \epsilon \sim N(0,1); then z=g(\epsilon)=\epsilon \cdot \sigma+\mu.
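A quick numerical check of this transform (the parameter values and sample size here are arbitrary):

import numpy as np

mu, sigma = 2.0, 0.5
eps = np.random.randn(100000)   # epsilon ~ N(0, 1)
z = eps * sigma + mu            # z = g(epsilon) = epsilon * sigma + mu

print(z.mean(), z.std())        # should be close to mu = 2.0 and sigma = 0.5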

Reparameterizing sampling from a categorical distribution:

When z follows a categorical distribution, z \sim Categorical(p), where p is a probability vector, draw \epsilon_i \sim U(0,1) independently for each category; then z=g(\epsilon)=argmax_i(\log p_i-\log(-\log \epsilon_i)) (the Gumbel-max trick). The experiment code is as follows:

import numpy as np
from collections import Counter

def softmax(x):
    x = np.exp(x)
    return x / x.sum()

logits = np.array([0.2,0.6,0.8,0.35])
probs = softmax(logits)

def gumbel_max_sample_with_logits(logits):
    # Gumbel-max trick on unnormalized logits: argmax is invariant to the
    # log-normalizing constant, so the softmax step can be skipped
    epsilon = np.random.uniform(0, 1, len(logits))
    idx = np.argmax(logits - np.log(-np.log(epsilon)))
    return idx
 
def gumbel_max_sample_with_probs(probs):
    # Gumbel-max trick: add Gumbel(0, 1) noise -log(-log(epsilon)) to the
    # log-probabilities and take the argmax
    epsilon = np.random.uniform(0, 1, len(probs))
    idx = np.argmax(np.log(probs) - np.log(-np.log(epsilon)))
    return idx

def random_choice_with_probs(probs):
    # baseline: sample the index directly from the categorical distribution
    idx = np.random.choice(list(range(len(probs))), p=probs)
    return idx

sample_data = {'gumbel_max_sample_with_logits':[], 'gumbel_max_sample_with_probs':[], 'random_choice_with_probs':[]}
sample_num = 100000
for _ in range(sample_num):
    idx1 = gumbel_max_sample_with_logits(logits)
    idx2 = gumbel_max_sample_with_probs(probs)
    idx3 = random_choice_with_probs(probs)
    
    sample_data['gumbel_max_sample_with_logits'].append(idx1)
    sample_data['gumbel_max_sample_with_probs'].append(idx2)
    sample_data['random_choice_with_probs'].append(idx3)

def get_rate(data):
    rate = []
    cnt = Counter(data)
    for idx in range(len(probs)):
        rate.append(cnt[idx]/sample_num)
    return rate

print('origin distribution:\t\t',probs)
for method in sample_data:
    data = sample_data[method]
    rate = get_rate(data)
    print(F'{method}:\t', rate)
Output:
origin distribution:		 [0.18262246 0.2724407  0.33275982 0.21217703]
gumbel_max_sample_with_logits:	 [0.18119, 0.27366, 0.33384, 0.21131]
gumbel_max_sample_with_probs:	 [0.18087, 0.27535, 0.33311, 0.21067]
random_choice_with_probs:	 [0.18201, 0.2719, 0.33411, 0.21198]