본문 바로가기
AI & Optimization/Machine Learning

[ML] Importance Sampling (중요도 샘플링)

by SIES 2023. 12. 3.
반응형

확률 모델에 기반한 머신 러닝에서 함수 $f(x)$의 확률분포 $p(x)$의 기댓값을 구해야 하는 경우가 있습니다. 하지만 기댓값 $\mathbb{E}_{x\sim p}[f(x)]=\int f(x)p(x)dx$ 을 수식적으로 계산하기 어려운 경우 큰 수의 법칙(Law of large numbers)에 따라 sampling을 통해 $x^{(n)}$을 추출한 후 아래와 같이 기댓값을 근사할 수 있습니다. 이러한 방법을 Monte Carlo 기법이라고 합니다.

 

$$ \mathbb{E}_{x\sim p}[f(x)]\simeq\frac{1}{N}\sum_{n=1}^N f(x^{(n)})$$

 

Importance Sampling이란?

Importance sampling은 이러한 상황에서 본래의 분포 $p(x)$가 아닌 다른 확률분포 $q(x)$에서 추출된 sample들을 이용하여 기댓값 $\mathbb{E}_{x\sim p}[f(x)]$를 계산하는 방법입니다. 기댓값 수식에서 분자와 분모에 $q(x)$를 곱해주면 아래의 식을 얻을 수 있습니다.

 

$$\begin{align}\mathbb{E}_\textcolor{red}{x\sim p}[f(x)]&={\int} f(x)\textcolor{red}{p(x)}dx \\ &={\int}\left(\frac{p(x)}{q(x)}f(x)\right) \textcolor{green}{q(x)}dx \end{align}$$

 

이를 다시 $q(x)$ 분포에 대한 sampling 근사를 수행하면 아래와 같이 importance sampling에 기반한 기댓값을 유도할 수 있습니다.

 

$$ \mathbb{E}_\textcolor{red}{x\sim p}[f(x)] =\mathbb{E}_\textcolor{green}{x\sim q}\left[\frac{p(x)}{q(x)}f(x)\right] \simeq\frac{1}{N}\sum_{n=1}^N \frac{p(x)}{q(x)}f(x^{(n)}) $$

 

이때 $\frac{p(x)}{q(x)}$를 importance weight라고 하며 다른 분포 $q(x)$로부터 sampling을 수행하여 $\mathbb{E}_{x\sim p}[f(x)]$를 계산했을 때 발생하는 에러를 보정하는 역할을 합니다. 기존 확률분포 $p(x)$를 nominal distribution, 다른 확률분포 $q(x)$를 importance distribution이라고 합니다.

 

Figure 1. 함수 $f(x)$의 확률분포 $p(x)$와 importance distribution $q(x)$의 예시

 

Importance Sampling를 사용하는 이유

기댓값 $\mathbb{E}_{x\sim p}[f(x)]$을 계산하기 위해 importance sampling를 사용하는 경우는 다음과 같습니다.

 

■ $x\sim p$로부터 sampling을 할 수 없는 경우

확률분포 $p$로부터 직접적으로 sampling 할 수 없으면 큰 수의 법칙을 적용하여 기댓값을 계산할 수 없습니다. 대신 다른 확률분포 $q$로부터 sample을 얻을 수 있다면 importance sampling을 통해 기댓값을 계산할 수 있습니다.

 

예를 들어 강화학습의 PPO 알고리즘에서 이용됩니다. 현재 policy $\pi$로부터 sampling된 목적함수의 기댓값 $\mathbb{E}_{x\sim \pi}[f(x)]$을 구하고 싶은데, 과거 policy $\pi_\text{old}$로부터 얻은 sample $x\sim \pi_\text{old}$만 있는 경우 importance sampling이 이용됩니다.

 

■ $x\sim p$로부터 sampling이 비효율적인 경우

Importance sampling은 variance를 줄이기 위한 방법으로 사용될 수 있습니다. 큰 수의 법칙에 따라서 획득한 sample 평균으로 기댓값을 계산하면, 계산 시마다 획득한 sample에 따라 기댓값이 달라지는 variance error가 발생하게 됩니다.

 

Figure 2. Importance sampling을 통한 variance redcution

[그림 2]에서 제일 왼쪽의 그림은 기댓값을 계산할 때 high variance가 발생하는 경우입니다. $f(x)$가 큰 값을 가진 중앙 부근에서 확률분포 $pdf(x)$가 낮은값을 가집니다. 따라서 낮은 확률로 sampling되는 이 중앙 부근의 sampling 빈도 수에 따라 $f(x)$의 기댓값 계산이 크게 달라지게 됩니다.

 

극단적인 경우에는 우리가 관심 있는 부분 (ex. 중앙 부근 값)이 한번도 sampling 되지 않아 기댓값을 구하기 어려울 수 있습니다. 따라서 희소한 sample의 관측을 위해 많은 sampling이 필요하며 비효율적인 computation이 소모될 수 있습니다.

 

[그림 2]에서 왼쪽 그림에서 오른쪽으로 갈수록 variance는 낮아집니다. 중앙의 그림은 uniform distribution을 이용하여 sampling하는 경우이며, 제일 왼쪽 그림에 비해 중앙 부근이 더 자주 sampling 되어 variance가 감소합니다. 제일 오른쪽 그림은 기댓값 계산 시 영향이 큰 중앙 부근의 sampling 확률을 가장 높게 가져가 variance를 더 낮출 수 있습니다.

 

위와 같이 더 관심이 있거나 중요한 부분을 더 많이 sampling 하기 때문에 importance sampling이라고 부릅니다. 따라서 importance sampling에서는 기존 확률분포 $p(x)$보다 중요한 값이 잘 뽑히도록 확률분포 $q(x)$를 선택해야 합니다.

 

참고로 importance sampling ratio $\frac{p(x)}{q(x)}$ 없이 $x\sim q$로 획득한 sample을 이용하여 기댓값을 계산하면 variance는 낮아지지만 biased 된 estimator를 얻게 될 것입니다. 하지만 importance sampling ratio로 기댓값 계산을 보정해 주기 때문에 variance는 낮지만 unbiased 된 estimator를 얻을 수 있습니다.

 

References

[1] C. M. Bishop, "Pattern Recognition and Machine Learning," Springer, 2006.

[2] https://en.wikipedia.org/wiki/Importance_sampling


오타나 잘못된 부분 있으면 댓글 부탁드립니다. 도움이 되셨다면 공감 눌러 주시면 감사하겠습니다 :)

반응형

댓글