<НА ГЛАВНУЮ

Реализация иерархической байесовской регрессии с NumPyro

Пошаговое руководство по иерархической байесовской регрессии с использованием NumPyro.

Обзор

В этом руководстве мы исследуем иерархическую байесовскую регрессию с NumPyro и подробно разберем весь рабочий процесс. Мы начинаем с генерации синтетических данных, затем определяем вероятностную модель, захватывающую как глобальные паттерны, так и групповые вариации. На каждом этапе мы настраиваем вывод с помощью NUTS, анализируем постериорные распределения и выполняем проверки постериорного предсказания.

Настройка окружения

Прежде чем погрузиться в моделирование, мы настраиваем окружение, устанавливая NumPyro и импортируя необходимые библиотеки. Это гарантирует, что наша сессия Colab полностью готова для иерархического моделирования.

try:
   import numpyro
except ImportError:
   !pip install -q "llvmlite>=0.45.1" "numpyro[cpu]" matplotlib pandas
 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import random
import numpyro
def generate_data(key, n_groups=8, n_per_group=40):
   # Функция
 
key = random.PRNGKey(0)
df, truth = generate_data(key)

Генерация синтетических данных

Далее мы генерируем синтетические иерархические данные, имитирующие реальную вариацию на уровне группы. Мы подготавливаем данные для обработки в JAX:

def generate_data(key, n_groups=8, n_per_group=40):
   # Функция
 
# Генерация данных
key = random.PRNGKey(0)
df, truth = generate_data(key)

Определение модели иерархической регрессии

Мы определяем нашу модель иерархической регрессии и запускаем выборку на основе NUTS:

def hierarchical_regression_model(x, group_idx, n_groups, y=None):
   # Функция
 
nuts = NUTS(hierarchical_regression_model, target_accept_prob=0.9)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=1)
mcmc.run(random.PRNGKey(1), x=x, group_idx=groups, n_groups=n_groups, y=y)
samples = mcmc.get_samples()

Анализ постериорных выборок

Мы вычисляем сводные данные о наших постериорных выборках и проводим проверки постериорного предсказания:

def param_summary(arr):
   # Функция
 
for name in ["mu_alpha", "mu_beta", "sigma_alpha", "sigma_beta", "sigma_obs"]:
   m, lo, hi = param_summary(samples[name])
   print(f"{name}: mean={m:.3f}, HPDI=[{lo:.3f}, {hi:.3f}]")

Визуализация результатов

Наконец, мы визуализируем оцененные перехваты и наклоны на уровне группы, чтобы сравнить их с истинными значениями:

alpha_g = np.asarray(samples["alpha_g"]).mean(axis=0)
beta_g = np.asarray(samples["beta_g"]).mean(axis=0)
 
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].bar(range(n_groups), alpha_g)
axes[0].axhline(truth["true_alpha"], linestyle="--")
plt.show()

Заключение

Мы реализовали комплексный рабочий процесс иерархической байесовской регрессии с помощью NumPyro. Этот процесс позволяет нам эффективно моделировать иерархические взаимосвязи, используя преимущества вывода на основе JAX.

🇬🇧

Switch Language

Read this article in English

Switch to English