다항분포의 사전분포로 사용되는 다변량 확률분포, Dirichlet 분포

다항분포의 사전분포로 사용되는 다변량 확률분포, Dirichlet 분포
Photo by CHUTTERSNAP / Unsplash

정의

  • Dirichlet 분포는 다항 분포의 사전 분포로서 사용되는 다변량 확률 분포입니다.
  • 각 항목이 0과 1 사이의 값을 가지며 모든 항목의 합이 1이 되는 특성을 가집니다.
  • 특히 베이지안 통계에서 다항 분포의 사전 분포로 자주 사용됩니다. 이는 켤레 분포(conjugate prior)의 특징을 갖습니다.
    • Dirichlet 분포의 켤레성은 다음과 같은 베이지안 업데이트 규칙을 가집니다$$\alpha_i' = \alpha_i + x_i$$
      • 여기서 $x_i$​는 관측된 데이터의 빈도입니다.
  • PDF 함수는 다음과 같습니다.
    $$f(x_1, x_2, \ldots, x_K; \alpha_1, \alpha_2, \ldots, \alpha_K) = \frac{1}{B(\alpha)} \prod_{i=1}^{K} x_i^{\alpha_i - 1}$$
    $$B(\alpha) = \frac{\prod_{i=1}^{K} \Gamma(\alpha_i)}{\Gamma\left(\sum_{i=1}^{K} \alpha_i\right)}$$
  • 베타 함수는 감마 함수 $\Gamma$ 를 사용하여 계산됩니다. $\alpha = (\alpha_1, \alpha_2, \ldots, \alpha_K)$는 파라미터 벡터입니다.
    • 감마 함수는 $0z>0$인 모든 실수에 대해 정의되며, 팩토리얼 함수의 일반화로 볼 수 있습니다. 즉, $n$이 자연수일 때 감마 함수는 다음과 같은 관계를 가집니다:
    • 감마 함수의 정의는 지수 함수 $e^{-t}$와 다항식 $t^{z-1}$의 곱으로 구성된 적분입니다.
  • 이는 감마 함수가 특정 함수의 곱의 면적을 계산한다는 것을 의미합니다.
  • 지수 함수는 빠르게 감소하며, 다항식은 ttt에 따라 증가하거나 감소할 수 있습니다. 이 두 함수의 곱을 적분함으로써 감마 함수의 값을 구할 수 있습니다.

Pros & Cons

Pros

  • 다양한 표현력: Dirichlet 분포는 다양한 형태의 데이터 분포를 모델링할 수 있습니다.
  • 사전 분포로 유용: 베이지안 통계에서 다항 분포의 사전 분포로 자주 사용됩니다.
  • 해석 용이: 파라미터 $\alpha$의 값에 따라 분포의 모양이 직관적으로 변경됩니다.
import numpy as np  
import matplotlib.pyplot as plt  
from scipy.stats import dirichlet  
 
fig = plt.figure(figsize=(18, 6))  
 
# 첫 번째 예시: 균등한 alpha 값  
alpha1 = [1, 1, 1]  
data1 = dirichlet.rvs(alpha1, size=5000)  
ax1 = fig.add_subplot(131, projection='3d')  
ax1.scatter(data1[:, 0], data1[:, 1], data1[:, 2], alpha=0.3)  
ax1.set_xlabel('X1')  
ax1.set_ylabel('X2')  
ax1.set_zlabel('X3')  
ax1.set_title('alpha=[1, 1, 1]')  
 
# 두 번째 예시: 특정 범주를 선호하는 alpha 값  
alpha2 = [10, 1, 1]  
data2 = dirichlet.rvs(alpha2, size=5000)  
ax2 = fig.add_subplot(132, projection='3d')  
ax2.scatter(data2[:, 0], data2[:, 1], data2[:, 2], alpha=0.3)  
ax2.set_xlabel('X1')  
ax2.set_ylabel('X2')  
ax2.set_zlabel('X3')  
ax2.set_title('alpha=[10, 1, 1]')  
 
# 세 번째 예시: 매우 높은 확률 집중  
alpha3 = [10, 10, 10]  
data3 = dirichlet.rvs(alpha3, size=5000)  
ax3 = fig.add_subplot(133, projection='3d')  
ax3.scatter(data3[:, 0], data3[:, 1], data3[:, 2], alpha=0.3)  
ax3.set_xlabel('X1')  
ax3.set_ylabel('X2')  
ax3.set_zlabel('X3')  
ax3.set_title('alpha=[10, 10, 10]')  
 
plt.suptitle('Dirichlet Distribution with Different alpha Values')  
plt.show()

  • 베이지안 업데이트: 관측 데이터에 따라 파라미터를 쉽게 갱신할 수 있습니다.
  • 상호작용 모델링: 여러 범주 간의 상호작용을 모델링하는 데 유용합니다.

Cons

  • 복잡한 계산: 고차원에서는 계산이 복잡할 수 있습니다.
  • 제한된 표현력: 특정 형태의 데이터에는 적합하지 않을 수 있습니다.
  • 파라미터 민감도: 파라미터 설정에 따라 결과가 크게 달라질 수 있습니다.
  • 해석의 어려움: 파라미터가 많아질수록 해석이 어려워질 수 있습니다.

Alternatives

  • Multinomial 분포: Dirichlet 분포의 대안으로, 단일 시도의 결과를 여러 범주로 분류할 때 사용됩니다.
  • Beta 분포: 이항 분포의 사전 분포로 사용되며, Dirichlet 분포의 2차원 버전입니다.
  • Gaussian Mixture Model (GMM): 연속형 데이터의 클러스터링에 사용되며, 각 클러스터가 가우시안 분포를 따릅니다.
  • Categorical 분포: 각 범주의 확률을 직접 모델링하는 데 사용됩니다.

Sample

import numpy as np  
import matplotlib.pyplot as plt  
from scipy.stats import dirichlet  
  
# Dirichlet 분포의 초기 파라미터 설정  
alpha = [1, 2, 3]  
  
# 초기 Dirichlet 분포에서 샘플링  
initial_data = dirichlet.rvs(alpha, size=5000)  
  
# 시각화  
fig = plt.figure(figsize=(12, 6))  
ax1 = fig.add_subplot(121, projection='3d')  
ax1.scatter(initial_data[:, 0], initial_data[:, 1], initial_data[:, 2], alpha=0.3)  
ax1.set_xlabel('X1')  
ax1.set_ylabel('X2')  
ax1.set_zlabel('X3')  
ax1.set_title('Initial Dirichlet Distribution')  
  
# 관측 데이터 추가 (베이지안 업데이트)  
observed_counts = [10, 5, 15]  
alpha_updated = np.add(alpha, observed_counts)  
  
# 업데이트된 Dirichlet 분포에서 샘플링  
updated_data = dirichlet.rvs(alpha_updated, size=5000)  
  
# 시각화  
ax2 = fig.add_subplot(122, projection='3d')  
ax2.scatter(updated_data[:, 0], updated_data[:, 1], updated_data[:, 2], alpha=0.3)  
ax2.set_xlabel('X1')  
ax2.set_ylabel('X2')  
ax2.set_zlabel('X3')  
ax2.set_title('Updated Dirichlet Distribution')  
  
plt.show()