Processing math: 0%
본문 바로가기
AI & Optimization/Machine Learning

[ML] Importance Sampling (중요도 샘플링)

by SIES 2023. 12. 3.
반응형

확률 모델에 기반한 머신 러닝에서 함수 f(x)의 확률분포 p(x)의 기댓값을 구해야 하는 경우가 있습니다. 하지만 기댓값 \mathbb{E}_{x\sim p}[f(x)]=\int f(x)p(x)dx 을 수식적으로 계산하기 어려운 경우 큰 수의 법칙(Law of large numbers)에 따라 sampling을 통해 x^{(n)}을 추출한 후 아래와 같이 기댓값을 근사할 수 있습니다. 이러한 방법을 Monte Carlo 기법이라고 합니다.

 

\mathbb{E}_{x\sim p}[f(x)]\simeq\frac{1}{N}\sum_{n=1}^N f(x^{(n)})

 

Importance Sampling이란?

Importance sampling은 이러한 상황에서 본래의 분포 p(x)가 아닌 다른 확률분포 q(x)에서 추출된 sample들을 이용하여 기댓값 \mathbb{E}_{x\sim p}[f(x)]를 계산하는 방법입니다. 기댓값 수식에서 분자와 분모에 q(x)를 곱해주면 아래의 식을 얻을 수 있습니다.

 

\begin{align}\mathbb{E}_\textcolor{red}{x\sim p}[f(x)]&={\int} f(x)\textcolor{red}{p(x)}dx \\ &={\int}\left(\frac{p(x)}{q(x)}f(x)\right) \textcolor{green}{q(x)}dx \end{align}

 

이를 다시 q(x) 분포에 대한 sampling 근사를 수행하면 아래와 같이 importance sampling에 기반한 기댓값을 유도할 수 있습니다.

 

\mathbb{E}_\textcolor{red}{x\sim p}[f(x)] =\mathbb{E}_\textcolor{green}{x\sim q}\left[\frac{p(x)}{q(x)}f(x)\right] \simeq\frac{1}{N}\sum_{n=1}^N \frac{p(x)}{q(x)}f(x^{(n)})

 

이때 \frac{p(x)}{q(x)}를 importance weight라고 하며 다른 분포 q(x)로부터 sampling을 수행하여 \mathbb{E}_{x\sim p}[f(x)]를 계산했을 때 발생하는 에러를 보정하는 역할을 합니다. 기존 확률분포 p(x)를 nominal distribution, 다른 확률분포 q(x)를 importance distribution이라고 합니다.

 

Figure 1. 함수 f(x)의 확률분포 p(x)와 importance distribution q(x)의 예시

 

Importance Sampling를 사용하는 이유

기댓값 \mathbb{E}_{x\sim p}[f(x)]을 계산하기 위해 importance sampling를 사용하는 경우는 다음과 같습니다.

 

■ x\sim p로부터 sampling을 할 수 없는 경우

확률분포 p로부터 직접적으로 sampling 할 수 없으면 큰 수의 법칙을 적용하여 기댓값을 계산할 수 없습니다. 대신 다른 확률분포 q로부터 sample을 얻을 수 있다면 importance sampling을 통해 기댓값을 계산할 수 있습니다.

 

예를 들어 강화학습의 PPO 알고리즘에서 이용됩니다. 현재 policy \pi로부터 sampling된 목적함수의 기댓값 \mathbb{E}_{x\sim \pi}[f(x)]을 구하고 싶은데, 과거 policy \pi_\text{old}로부터 얻은 sample x\sim \pi_\text{old}만 있는 경우 importance sampling이 이용됩니다.

 

■ x\sim p로부터 sampling이 비효율적인 경우

Importance sampling은 variance를 줄이기 위한 방법으로 사용될 수 있습니다. 큰 수의 법칙에 따라서 획득한 sample 평균으로 기댓값을 계산하면, 계산 시마다 획득한 sample에 따라 기댓값이 달라지는 variance error가 발생하게 됩니다.

 

Figure 2. Importance sampling을 통한 variance redcution

[그림 2]에서 제일 왼쪽의 그림은 기댓값을 계산할 때 high variance가 발생하는 경우입니다. f(x)가 큰 값을 가진 중앙 부근에서 확률분포 pdf(x)가 낮은값을 가집니다. 따라서 낮은 확률로 sampling되는 이 중앙 부근의 sampling 빈도 수에 따라 f(x)의 기댓값 계산이 크게 달라지게 됩니다.

 

극단적인 경우에는 우리가 관심 있는 부분 (ex. 중앙 부근 값)이 한번도 sampling 되지 않아 기댓값을 구하기 어려울 수 있습니다. 따라서 희소한 sample의 관측을 위해 많은 sampling이 필요하며 비효율적인 computation이 소모될 수 있습니다.

 

[그림 2]에서 왼쪽 그림에서 오른쪽으로 갈수록 variance는 낮아집니다. 중앙의 그림은 uniform distribution을 이용하여 sampling하는 경우이며, 제일 왼쪽 그림에 비해 중앙 부근이 더 자주 sampling 되어 variance가 감소합니다. 제일 오른쪽 그림은 기댓값 계산 시 영향이 큰 중앙 부근의 sampling 확률을 가장 높게 가져가 variance를 더 낮출 수 있습니다.

 

위와 같이 더 관심이 있거나 중요한 부분을 더 많이 sampling 하기 때문에 importance sampling이라고 부릅니다. 따라서 importance sampling에서는 기존 확률분포 p(x)보다 중요한 값이 잘 뽑히도록 확률분포 q(x)를 선택해야 합니다.

 

참고로 importance sampling ratio \frac{p(x)}{q(x)} 없이 x\sim q로 획득한 sample을 이용하여 기댓값을 계산하면 variance는 낮아지지만 biased 된 estimator를 얻게 될 것입니다. 하지만 importance sampling ratio로 기댓값 계산을 보정해 주기 때문에 variance는 낮지만 unbiased 된 estimator를 얻을 수 있습니다.

 

References

[1] C. M. Bishop, "Pattern Recognition and Machine Learning," Springer, 2006.

[2] https://en.wikipedia.org/wiki/Importance_sampling


오타나 잘못된 부분 있으면 댓글 부탁드립니다. 도움이 되셨다면 공감 눌러 주시면 감사하겠습니다 :)

반응형

댓글