여러 분포를 결합하여 데이터의 분포를 모델링하는 Mixture Model

여러 분포를 결합하여 데이터의 분포를 모델링하는 Mixture Model
Photo by Rick Mason / Unsplash

Mixture Model is

  • Mixture Model은 여러개의 분포를 결합하여 데이터의 전체 분포를 모델링함
  • 이 때 각 분포는 Component라고 보통 불리며 데이터가 각 Component로부터 생성될 확률을 가중치로 갖음
  • 통상 실무에서 Mixture Model이라고 하면 대체로 GMM(Gaussian Mixture Model)임

Motivation

  • 대부분의 현실 데이터는 단일 분포로 설명하기 어렵기 때문에 Mixture Model이 쓸모가 있음
  • 히스토그램을 그렸을 때 쌍봉이 나오는 경우처럼, 여러 그룹의 데이터가 섞인 경우 이 경우 각 그룹을 별도의 분포로 모델링하는게 더욱 정확할 수 있음

Formula

  • Mixture Model의 PDF(Probability Density Function)은 다음과 같음
    $$p(x) = \sum_{k=1}^{K} \pi_k \cdot p_k(x \mid \theta_k)$$

  • $K$는 Component의 수

  • $\pi_k$​는 각 컴포넌트 kkk의 가중치로, $\sum_{k=1}^{K} \pi_k = 1$

  • $p_k(x \mid \theta_k)$ 는 Component $k$의 확률 밀도 함수

  • $theta_k$​는 Component $k$의 파라미터

$$p(x) = \sum_{k=1}^{K} \pi_k \cdot \mathcal{N}(x \mid \mu_k, \Sigma_k)$$

  • 위 수식은 Gaussian Mixture Model의 PDF(Probability Density Function)은 다음과 같음
  • $N(x∣μk​,Σk​)$는 평균이 $\mu_k$​이고 공분산 행렬이 $\Sigma_k$​인 가우시안 분포

Pros & Cons

Pros

  • 단일분포로 설명할 수 있는 복잡한 데이터 분포를 설명할 수 있음
  • 여러개의 Component를 결합하여 다양한 형태의 데이터 분포를 섦여할 수 있음
  • 특히 각 데이터가 특정 분포로 관측될 확률을 게종하기 때문에 데이터 이해도를 높일 수 있음

Cons

  • 최적의 분포를 구하기 위해서는 여러번 실행하면서 적합시켜야할 필요가 있음
  • 계산이 복잡하기 때문에 계산 비용이 Component의 수나 데이터의 크기에 민감하게 반응함
  • 그리고 너무 데이터의 분포에 최적화하다보면 모델이 과적홥될 수 있음

Alternatives

  • K Means Clustering: 간단하지만 데이터가 구형의 클러스터 안에 잘 속해야 성능이 좋다.
  • Hierarchical Clustering: 계층적으로 접근하는 클러스터링
  • DBScan: 클러스터링이 모양이 불규칙할 수록 성능이 좋으나 계산비용이 큰 편

Sample

from sklearn.mixture import GaussianMixture  
import numpy as np  
import matplotlib.pyplot as plt  
import seaborn as sns  
  
# 데이터 생성  
np.random.seed(0)  
data1 = np.random.normal(0, 1, 500)  
data2 = np.random.normal(5, 1, 200)  
data3 = np.random.exponential(2, 100)  
data = np.concatenate([data1, data2, data3])  
  
# GMM 모델링  
gmm = GaussianMixture(n_components=3)  
gmm.fit(data.reshape(-1, 1))  
  
# GMM 결과 출력  
print("Means:", gmm.means_)  
print("Covariances:", gmm.covariances_)  
print("Weights:", gmm.weights_)  
  
# 데이터 분포 시각화  
plt.figure(figsize=(12, 6))  
sns.histplot(data, bins=30, kde=False, color='g', stat='density', label='Data')  
x = np.linspace(min(data), max(data), 1000).reshape(-1, 1)  
logprob = gmm.score_samples(x)  
pdf = np.exp(logprob)  
plt.plot(x, pdf, '-k', label='GMM')  
plt.title('Data with GMM Fit')  
plt.xlabel('Value')  
plt.ylabel('Density')  
plt.legend()  
  
# 각 컴포넌트의 기여도 시각화  
for i in range(gmm.n_components):  
    pdf_i = gmm.weights_[i] * np.exp(gmm._estimate_log_prob(x)[:, i])  
    plt.plot(x, pdf_i, '--', label=f'Component {i+1}')  
  
plt.legend()  
plt.show()  
  
# 데이터의 특성 분석  
print("Data Mean:", np.mean(data))  
print("Data Std Dev:", np.std(data))