So-net無料ブログ作成
  • ブログをはじめる
  • ログイン

Stan: ガウス過程をためしてみたメモ(2) [統計]

こんどは、2次元の入力に対するクラス分類をためしてみました。

Stanコードです。マニュアルのモデルを合成して つくりました。

// This model is derived from Chapter 18 of
// Stan Modeling Language User's Guide Reference Manual
// by Stan Development Team
// http://mc-stan.org/

functions {
  matrix L_cov_exp_quad_ARD(vector[] x,
                            real alpha,
                            vector rho,
                            real delta) {
    int N = size(x);
    matrix[N, N] K;
    real sq_alpha = square(alpha);
    for (i in 1:(N - 1)) {
      K[i, i] = sq_alpha + delta;
      for (j in (i + 1):N) {
        K[i, j] = sq_alpha
                      * exp(-0.5 * dot_self((x[i] - x[j]) ./ rho));
        K[j, i] = K[i, j];
      }
    }
    K[N, N] = sq_alpha + delta;
    return cholesky_decompose(K);
  }
}

data {
  int<lower=1> N1;
  int<lower=1> D;
  vector[D] X1[N1];
  int Y1[N1];
  int<lower=1> N2;
  vector[D] X2[N2];
}

transformed data {
  real delta = 1e-9;
  int<lower=1> N = N1 + N2;
  vector[D] x[N];
  for (n1 in 1:N1) x[n1] = X1[n1];
  for (n2 in 1:N2) x[N1 + n2] = X2[n2];
}

parameters {
  vector<lower=0>[D] rho;
  real<lower=0> alpha;
  real a;
  vector[N] eta;
}

transformed parameters {
  vector[N] f;
  {
    matrix[N, N] L_K = L_cov_exp_quad_ARD(x, alpha, rho, delta);
    f = L_K * eta;
  }
}

model {
  rho ~ inv_gamma(5, 5);
  alpha ~ normal(0, 1);
  a ~ normal(0, 1);
  eta ~ normal(0, 1);

  Y1 ~ bernoulli_logit(a + f[1:N1]);
}

generated quantities {
  int y2[N2];
  for (n2 in 1:N2)
    y2[n2] = bernoulli_logit_rng(a + f[N1 + n2]);
}

テストに用意したデータです。

library(dplyr)
library(rstan)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

model_file <- "GP3.stan"

set.seed(109)
N <- 64
x <- matrix(runif(N * 2, -1, 1), ncol = 2)
y <- xor(x[, 1] > 0, x[, 2] > 0)

data.frame(x1 = x[, 1], x2 = x[, 2], y = y) %>%
  ggplot() +
  geom_point(aes(x = x1, y = x2, colour = y)) +
  coord_fixed()

Rplot02.png

それぞれの象限に新データをおいて、予測をおこないます。

x_new <- matrix(c(0.5,  0.5, -0.5, -0.5,
                  0.5, -0.5,  0.5, -0.5), ncol = 2)

fit <- stan(model_file,
            data = list(N1 = N,
                        D = 2,
                        X1 = x,
                        Y1 = as.integer(y),
                        N2 = nrow(x_new),
                        X2 = x_new))
print(fit, pars = c("rho", "alpha", "a", "y2", "lp__"))

結果です。

Inference for Stan model: GP3.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

         mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
rho[1]   0.68    0.00 0.17   0.42   0.56   0.65   0.77   1.06  2452    1
rho[2]   0.69    0.00 0.17   0.44   0.57   0.66   0.78   1.08  3174    1
alpha    2.79    0.01 0.58   1.75   2.37   2.75   3.17   4.00  4000    1
a        0.01    0.02 0.86  -1.65  -0.58   0.02   0.58   1.67  2909    1
y2[1]    0.06    0.00 0.25   0.00   0.00   0.00   0.00   1.00  3939    1
y2[2]    0.89    0.00 0.31   0.00   1.00   1.00   1.00   1.00  4000    1
y2[3]    0.94    0.00 0.24   0.00   1.00   1.00   1.00   1.00  4000    1
y2[4]    0.04    0.00 0.20   0.00   0.00   0.00   0.00   1.00  4000    1
lp__   -66.28    0.14 5.90 -78.79 -70.08 -65.89 -62.22 -55.47  1714    1

タグ:STAn
nice!(1)  コメント(0) 
共通テーマ:日記・雑感

nice! 1

コメント 0

コメントを書く

お名前:
URL:
コメント:
画像認証:
下の画像に表示されている文字を入力してください。

Facebook コメント