Relearning Turing.jl

Author

Laus Wullum

Published

July 30, 2022

Relearning Turing.jl

The Julia PPL Turing.jl has had some updates. This blog post serves as a very basic introduction on how to work with a Turing.jl model today.

Data

We will use a dataset from a chapter on longitudinal data in the excellent online book on Applied Modelling in Drug Development written by statisticians from Novartis.

The data comes from a fictional longitudinal phase II dermatology study, focused on the PASI score, which measures psoriasis severity. Specifically, we will model the treatment effect compared to placebo on the PASI score at week 12, adjusting for baseline scores. For simplicity, we ignore the longitudinal aspect of the study, as the goal is to demonstrate Bayesian inference using Turing.jl in Julia with real, non-simulated data.

using Turing
using CSV
using DataFrames
using TidierData 
using TidierPlots
using CategoricalArrays
using FillArrays
using LinearAlgebra 
using AlgebraOfGraphics 
using CairoMakie 


# Load data from the parent folder
adpasi = CSV.read(joinpath("..", "..", "data", "longitudinal.csv"), DataFrame)
#adpasi = CSV.read(joinpath("data", "longitudinal.csv"), DataFrame)

# Name datasets and filter for PASI endpoint at Week12
pasi_data = @chain adpasi begin
    @filter(TRT01P in ["PBO", "TRT"])
    @filter(PARAMCD == "PASITSCO")
    @arrange(AVISITN)
end

xmat = @chain pasi_data begin
    @filter(AVISIT == "Week 12")
    @mutate(TRT01P1 = TRT01P == "TRT")
    @select(TRT01P1, TRT01P, BASE, AVAL)
end

# Design matrix
covariate_mat = [
    xmat[!, "TRT01P1"] parse.(Float64, xmat[!, "BASE"])
]

# Outcome measurements at week 12
pasi_score_week_12 = xmat[!, "AVAL"]
108-element Vector{Float64}:
 36.7
  2.0
  5.9
  3.8
  0.5
  3.9
  7.7
  6.4
  4.4
 42.6
  ⋮
 20.4
  0.3
 13.2
  4.9
  5.1
  5.4
  2.7
  3.8
  6.7

Overview

data(pasi_data) *
    visual(Violin) *
    mapping(
        :AVISIT => "Visit",
        :AVAL => "PASI score",
        color = :TRT01P => "Treatment",
        dodge = :TRT01P,
    ) |>
    draw

Bayesian linear model

We are going to posit the following model, where TRT01P1 denotes treatment and BASE is the baseline PASI score, and AVAL is the outcome score at week 12. The specific prior does not concern us.

More programmatically,

AVAL ~ BASE + TRT01P

\[ \begin{align} PASI_i &\sim \mathcal{N}(\beta^T X_i, \sigma^2)\\ \sigma^2 &\sim \mathcal{N}_+ (0, 10^2) \\ \beta_i &\sim \mathcal{N}(0, 10) \end{align} \]

Turing.jl model

In Turing.jl this becomes:

@model function lin_reg(x)
    sigma_sq ~ truncated(Normal(0, 10); lower = 0)
    intercept ~ Normal(0, 10)
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(Zeros(nfeatures), 10.0 * I)
    mu = intercept .+ x * coefficients
    y ~ MvNormal(mu, sigma_sq * I)

    return mean(y)
end
lin_reg (generic function with 2 methods)

Model object

Now we can define the model unconditionally.

model = lin_reg(covariate_mat)
DynamicPPL.Model{typeof(lin_reg), (:x,), (), (), Tuple{Matrix{Float64}}, Tuple{}, DynamicPPL.DefaultContext, false}(lin_reg, (x = [0.0 16.0; 1.0 20.3; … ; 1.0 21.8; 1.0 19.1],), NamedTuple(), DynamicPPL.DefaultContext())

With this we can sample from the prior predictive distribution:

pp_data = rand(model)
(sigma_sq = 2.6299689043287127, intercept = 5.731365886684755, coefficients = [4.270294035982309, -1.118487027986863], y = [-9.692395739250431, -10.750549634135126, -24.384185242507645, -42.668522176141394, -23.822141556459396, -23.194242219158593, -11.284806802748049, -9.243327951806984, -5.288356451009132, -9.382948615984645  …  -10.82608651696414, -7.827229384270633, -4.744048122236251, -8.771013855500414, -14.392909563770079, 0.12309371547625364, -12.602665152478432, -9.31108131611646, -15.380616188154741, -11.480693537641972])

This yields a named tuple with parameters drawn and data sampled.

We can plot the data envisioned by our prior specification. (The prior predictive distribution is not very good, but that is beside the point here.)

xmat.AVAL_PP1 = pp_data.y
data(xmat) *
    visual(Violin) *
    mapping(
        :TRT01P => "Treatment",
        :AVAL_PP1 => "Prior predictive PASI score at week 12",
        color = :TRT01P => "Treatment",
        dodge = :TRT01P,
    ) |>
    draw

We can also fix the parameters at specific values in the model.

params_gen = (sigma_sq = 1, coefficients = [0.1, 0.3], intercept = 1)

model_gen = fix(model, params_gen)

rand(model_gen)
(y = [5.205597313744559, 6.771334506466306, 10.689555673368991, 15.113111600246173, 9.248387157608361, 10.66684337583699, 6.710425394521913, 3.92662816761732, 7.901384011266981, 6.3228237142482575  …  5.853938311832917, 5.27177248983979, 6.552936721489295, 6.875741667809666, 7.0022151742315035, 2.5348891654338037, 6.48132319067105, 6.178698337691365, 7.802349082142385, 7.3161834763919495],)

We can posit a conditional model on the observed data, to perform inference on the posterior distribution of the parameters.

model_cond = model | (y = pasi_score_week_12,)
DynamicPPL.Model{typeof(lin_reg), (:x,), (), (), Tuple{Matrix{Float64}}, Tuple{}, DynamicPPL.ConditionContext{@NamedTuple{y::Vector{Float64}}, DynamicPPL.DefaultContext}, false}(lin_reg, (x = [0.0 16.0; 1.0 20.3; … ; 1.0 21.8; 1.0 19.1],), NamedTuple(), ConditionContext((y = [36.7, 2.0, 5.9, 3.8, 0.5, 3.9, 7.7, 6.4, 4.4, 42.6, 10.8, 10.9, 7.8, 0.7, 5.5, 10.8, 5.8, 9.5, 6.2, 0.5, 41.6, 9.9, 17.7, 32.4, 9.1, 0.4, 32.9, 11.4, 4.5, 7.7, 4.2, 2.6, 2.6, 4.5, 7.0, 0.4, 1.2, 37.8, 3.7, 7.3, 3.5, 5.5, 9.5, 5.2, -0.2, 3.1, 7.5, 5.2, 4.8, 10.9, -0.4, 1.0, 4.3, 4.6, 4.5, 0.3, 37.7, 12.3, 3.0, 2.0, 11.5, 8.4, 4.0, 7.7, 4.3, 4.2, 3.8, 2.9, 9.6, 29.6, 5.2, 1.9, 5.1, 12.6, 27.8, 44.1, 5.0, 40.5, 8.4, 14.5, 3.5, 7.8, 23.5, 17.6, 5.7, 27.1, 0.5, 5.6, 4.2, 17.0, 1.6, 18.2, 1.1, 28.0, 4.8, 27.0, -0.2, 16.7, 33.7, 20.4, 0.3, 13.2, 4.9, 5.1, 5.4, 2.7, 3.8, 6.7],), DynamicPPL.DefaultContext()))

This object can be used to perform inference.

post_inf_chain = sample(model_cond, NUTS(), 1000)
Sampling   0%|█                                         |  ETA: N/A
Info: Found initial step size
  ϵ = 0.0125
Sampling   1%|█                                         |  ETA: 1:19:03
Sampling   1%|█                                         |  ETA: 0:42:04
Sampling   2%|█                                         |  ETA: 0:27:17
Sampling   2%|█                                         |  ETA: 0:20:50
Sampling   3%|██                                        |  ETA: 0:16:22
Sampling   3%|██                                        |  ETA: 0:13:46
Sampling   4%|██                                        |  ETA: 0:11:38
Sampling   4%|██                                        |  ETA: 0:10:13
Sampling   5%|██                                        |  ETA: 0:08:58
Sampling   5%|███                                       |  ETA: 0:08:10
Sampling   6%|███                                       |  ETA: 0:07:20
Sampling   6%|███                                       |  ETA: 0:06:44
Sampling   7%|███                                       |  ETA: 0:06:09
Sampling   7%|███                                       |  ETA: 0:05:43
Sampling   8%|████                                      |  ETA: 0:05:17
Sampling   8%|████                                      |  ETA: 0:04:57
Sampling   9%|████                                      |  ETA: 0:04:37
Sampling   9%|████                                      |  ETA: 0:04:21
Sampling  10%|█████                                     |  ETA: 0:04:05
Sampling  10%|█████                                     |  ETA: 0:03:53
Sampling  11%|█████                                     |  ETA: 0:03:40
Sampling  11%|█████                                     |  ETA: 0:03:29
Sampling  12%|█████                                     |  ETA: 0:03:18
Sampling  12%|██████                                    |  ETA: 0:03:10
Sampling  13%|██████                                    |  ETA: 0:03:00
Sampling  13%|██████                                    |  ETA: 0:02:53
Sampling  14%|██████                                    |  ETA: 0:02:45
Sampling  14%|██████                                    |  ETA: 0:02:39
Sampling  15%|███████                                   |  ETA: 0:02:32
Sampling  15%|███████                                   |  ETA: 0:02:27
Sampling  16%|███████                                   |  ETA: 0:02:21
Sampling  16%|███████                                   |  ETA: 0:02:16
Sampling  17%|███████                                   |  ETA: 0:02:12
Sampling  17%|████████                                  |  ETA: 0:02:08
Sampling  18%|████████                                  |  ETA: 0:02:03
Sampling  18%|████████                                  |  ETA: 0:01:59
Sampling  19%|████████                                  |  ETA: 0:01:55
Sampling  19%|████████                                  |  ETA: 0:01:52
Sampling  20%|█████████                                 |  ETA: 0:01:48
Sampling  20%|█████████                                 |  ETA: 0:01:45
Sampling  21%|█████████                                 |  ETA: 0:01:42
Sampling  21%|█████████                                 |  ETA: 0:01:39
Sampling  22%|██████████                                |  ETA: 0:01:36
Sampling  22%|██████████                                |  ETA: 0:01:33
Sampling  23%|██████████                                |  ETA: 0:01:30
Sampling  23%|██████████                                |  ETA: 0:01:28
Sampling  24%|██████████                                |  ETA: 0:01:25
Sampling  24%|███████████                               |  ETA: 0:01:23
Sampling  25%|███████████                               |  ETA: 0:01:21
Sampling  25%|███████████                               |  ETA: 0:01:19
Sampling  26%|███████████                               |  ETA: 0:01:17
Sampling  26%|███████████                               |  ETA: 0:01:15
Sampling  27%|████████████                              |  ETA: 0:01:13
Sampling  27%|████████████                              |  ETA: 0:01:11
Sampling  28%|████████████                              |  ETA: 0:01:09
Sampling  28%|████████████                              |  ETA: 0:01:08
Sampling  29%|████████████                              |  ETA: 0:01:06
Sampling  29%|█████████████                             |  ETA: 0:01:04
Sampling  30%|█████████████                             |  ETA: 0:01:03
Sampling  30%|█████████████                             |  ETA: 0:01:01
Sampling  31%|█████████████                             |  ETA: 0:01:00
Sampling  31%|██████████████                            |  ETA: 0:00:59
Sampling  32%|██████████████                            |  ETA: 0:00:57
Sampling  32%|██████████████                            |  ETA: 0:00:56
Sampling  33%|██████████████                            |  ETA: 0:00:55
Sampling  33%|██████████████                            |  ETA: 0:00:53
Sampling  34%|███████████████                           |  ETA: 0:00:54
Sampling  34%|███████████████                           |  ETA: 0:00:53
Sampling  35%|███████████████                           |  ETA: 0:00:51
Sampling  35%|███████████████                           |  ETA: 0:00:50
Sampling  36%|███████████████                           |  ETA: 0:00:49
Sampling  36%|████████████████                          |  ETA: 0:00:48
Sampling  37%|████████████████                          |  ETA: 0:00:51
Sampling  37%|████████████████                          |  ETA: 0:00:52
Sampling  38%|████████████████                          |  ETA: 0:00:51
Sampling  38%|████████████████                          |  ETA: 0:00:50
Sampling  39%|█████████████████                         |  ETA: 0:00:49
Sampling  39%|█████████████████                         |  ETA: 0:00:48
Sampling  40%|█████████████████                         |  ETA: 0:00:47
Sampling  40%|█████████████████                         |  ETA: 0:00:46
Sampling  41%|██████████████████                        |  ETA: 0:00:45
Sampling  41%|██████████████████                        |  ETA: 0:00:44
Sampling  42%|██████████████████                        |  ETA: 0:00:43
Sampling  42%|██████████████████                        |  ETA: 0:00:42
Sampling  43%|██████████████████                        |  ETA: 0:00:41
Sampling  43%|███████████████████                       |  ETA: 0:00:41
Sampling  44%|███████████████████                       |  ETA: 0:00:40
Sampling  44%|███████████████████                       |  ETA: 0:00:39
Sampling  45%|███████████████████                       |  ETA: 0:00:38
Sampling  45%|███████████████████                       |  ETA: 0:00:37
Sampling  46%|████████████████████                      |  ETA: 0:00:37
Sampling  46%|████████████████████                      |  ETA: 0:00:36
Sampling  47%|████████████████████                      |  ETA: 0:00:35
Sampling  47%|████████████████████                      |  ETA: 0:00:35
Sampling  48%|████████████████████                      |  ETA: 0:00:34
Sampling  48%|█████████████████████                     |  ETA: 0:00:33
Sampling  49%|█████████████████████                     |  ETA: 0:00:33
Sampling  49%|█████████████████████                     |  ETA: 0:00:32
Sampling  50%|█████████████████████                     |  ETA: 0:00:31
Sampling  50%|██████████████████████                    |  ETA: 0:00:31
Sampling  51%|██████████████████████                    |  ETA: 0:00:30
Sampling  51%|██████████████████████                    |  ETA: 0:00:29
Sampling  52%|██████████████████████                    |  ETA: 0:00:29
Sampling  52%|██████████████████████                    |  ETA: 0:00:28
Sampling  53%|███████████████████████                   |  ETA: 0:00:28
Sampling  53%|███████████████████████                   |  ETA: 0:00:27
Sampling  54%|███████████████████████                   |  ETA: 0:00:27
Sampling  54%|███████████████████████                   |  ETA: 0:00:26
Sampling  55%|███████████████████████                   |  ETA: 0:00:26
Sampling  55%|████████████████████████                  |  ETA: 0:00:25
Sampling  56%|████████████████████████                  |  ETA: 0:00:25
Sampling  56%|████████████████████████                  |  ETA: 0:00:24
Sampling  57%|████████████████████████                  |  ETA: 0:00:24
Sampling  57%|████████████████████████                  |  ETA: 0:00:23
Sampling  58%|█████████████████████████                 |  ETA: 0:00:23
Sampling  58%|█████████████████████████                 |  ETA: 0:00:22
Sampling  59%|█████████████████████████                 |  ETA: 0:00:22
Sampling  59%|█████████████████████████                 |  ETA: 0:00:21
Sampling  60%|██████████████████████████                |  ETA: 0:00:21
Sampling  60%|██████████████████████████                |  ETA: 0:00:20
Sampling  61%|██████████████████████████                |  ETA: 0:00:20
Sampling  61%|██████████████████████████                |  ETA: 0:00:20
Sampling  62%|██████████████████████████                |  ETA: 0:00:19
Sampling  62%|███████████████████████████               |  ETA: 0:00:19
Sampling  63%|███████████████████████████               |  ETA: 0:00:18
Sampling  63%|███████████████████████████               |  ETA: 0:00:18
Sampling  64%|███████████████████████████               |  ETA: 0:00:18
Sampling  64%|███████████████████████████               |  ETA: 0:00:17
Sampling  65%|████████████████████████████              |  ETA: 0:00:17
Sampling  65%|████████████████████████████              |  ETA: 0:00:17
Sampling  66%|████████████████████████████              |  ETA: 0:00:16
Sampling  66%|████████████████████████████              |  ETA: 0:00:16
Sampling  67%|████████████████████████████              |  ETA: 0:00:15
Sampling  67%|█████████████████████████████             |  ETA: 0:00:15
Sampling  68%|█████████████████████████████             |  ETA: 0:00:15
Sampling  68%|█████████████████████████████             |  ETA: 0:00:14
Sampling  69%|█████████████████████████████             |  ETA: 0:00:14
Sampling  69%|█████████████████████████████             |  ETA: 0:00:14
Sampling  70%|██████████████████████████████            |  ETA: 0:00:13
Sampling  70%|██████████████████████████████            |  ETA: 0:00:13
Sampling  71%|██████████████████████████████            |  ETA: 0:00:13
Sampling  71%|██████████████████████████████            |  ETA: 0:00:13
Sampling  72%|███████████████████████████████           |  ETA: 0:00:12
Sampling  72%|███████████████████████████████           |  ETA: 0:00:12
Sampling  73%|███████████████████████████████           |  ETA: 0:00:12
Sampling  73%|███████████████████████████████           |  ETA: 0:00:11
Sampling  74%|███████████████████████████████           |  ETA: 0:00:11
Sampling  74%|████████████████████████████████          |  ETA: 0:00:11
Sampling  75%|████████████████████████████████          |  ETA: 0:00:11
Sampling  75%|████████████████████████████████          |  ETA: 0:00:10
Sampling  76%|████████████████████████████████          |  ETA: 0:00:10
Sampling  76%|████████████████████████████████          |  ETA: 0:00:10
Sampling  77%|█████████████████████████████████         |  ETA: 0:00:09
Sampling  77%|█████████████████████████████████         |  ETA: 0:00:09
Sampling  78%|█████████████████████████████████         |  ETA: 0:00:09
Sampling  78%|█████████████████████████████████         |  ETA: 0:00:09
Sampling  79%|█████████████████████████████████         |  ETA: 0:00:08
Sampling  79%|██████████████████████████████████        |  ETA: 0:00:08
Sampling  80%|██████████████████████████████████        |  ETA: 0:00:08
Sampling  80%|██████████████████████████████████        |  ETA: 0:00:08
Sampling  81%|██████████████████████████████████        |  ETA: 0:00:07
Sampling  81%|███████████████████████████████████       |  ETA: 0:00:07
Sampling  82%|███████████████████████████████████       |  ETA: 0:00:07
Sampling  82%|███████████████████████████████████       |  ETA: 0:00:07
Sampling  83%|███████████████████████████████████       |  ETA: 0:00:07
Sampling  83%|███████████████████████████████████       |  ETA: 0:00:06
Sampling  84%|████████████████████████████████████      |  ETA: 0:00:06
Sampling  84%|████████████████████████████████████      |  ETA: 0:00:06
Sampling  85%|████████████████████████████████████      |  ETA: 0:00:06
Sampling  85%|████████████████████████████████████      |  ETA: 0:00:05
Sampling  86%|████████████████████████████████████      |  ETA: 0:00:05
Sampling  86%|█████████████████████████████████████     |  ETA: 0:00:05
Sampling  87%|█████████████████████████████████████     |  ETA: 0:00:05
Sampling  87%|█████████████████████████████████████     |  ETA: 0:00:05
Sampling  88%|█████████████████████████████████████     |  ETA: 0:00:04
Sampling  88%|█████████████████████████████████████     |  ETA: 0:00:04
Sampling  89%|██████████████████████████████████████    |  ETA: 0:00:04
Sampling  89%|██████████████████████████████████████    |  ETA: 0:00:04
Sampling  90%|██████████████████████████████████████    |  ETA: 0:00:04
Sampling  90%|██████████████████████████████████████    |  ETA: 0:00:03
Sampling  91%|███████████████████████████████████████   |  ETA: 0:00:03
Sampling  91%|███████████████████████████████████████   |  ETA: 0:00:03
Sampling  92%|███████████████████████████████████████   |  ETA: 0:00:03
Sampling  92%|███████████████████████████████████████   |  ETA: 0:00:03
Sampling  93%|███████████████████████████████████████   |  ETA: 0:00:02
Sampling  93%|████████████████████████████████████████  |  ETA: 0:00:02
Sampling  94%|████████████████████████████████████████  |  ETA: 0:00:02
Sampling  94%|████████████████████████████████████████  |  ETA: 0:00:02
Sampling  95%|████████████████████████████████████████  |  ETA: 0:00:02
Sampling  95%|████████████████████████████████████████  |  ETA: 0:00:02
Sampling  96%|█████████████████████████████████████████ |  ETA: 0:00:01
Sampling  96%|█████████████████████████████████████████ |  ETA: 0:00:01
Sampling  97%|█████████████████████████████████████████ |  ETA: 0:00:01
Sampling  97%|█████████████████████████████████████████ |  ETA: 0:00:01
Sampling  98%|█████████████████████████████████████████ |  ETA: 0:00:01
Sampling  98%|██████████████████████████████████████████|  ETA: 0:00:01
Sampling  99%|██████████████████████████████████████████|  ETA: 0:00:00
Sampling  99%|██████████████████████████████████████████|  ETA: 0:00:00
Sampling 100%|██████████████████████████████████████████|  ETA: 0:00:00
Sampling 100%|██████████████████████████████████████████| Time: 0:00:30
Sampling 100%|██████████████████████████████████████████| Time: 0:00:45
Chains MCMC chain (1000×18×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 33.87 seconds
Compute duration  = 33.87 seconds
parameters        = sigma_sq, intercept, coefficients[1], coefficients[2]
internals         = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, logprior, loglikelihood, logjoint

Use `describe(chains)` for summary statistics and quantiles.

The samples are collected into an MCMCchain.jl object, which can be used for convergence diagnostics.

To obtain posterior predictive samples we call predict on the model using the posterior draws.

predict(model, post_inf_chain)
Chains MCMC chain (1000×108×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
parameters        = y[1], y[2], y[3], y[4], y[5], y[6], y[7], y[8], y[9], y[10], y[11], y[12], y[13], y[14], y[15], y[16], y[17], y[18], y[19], y[20], y[21], y[22], y[23], y[24], y[25], y[26], y[27], y[28], y[29], y[30], y[31], y[32], y[33], y[34], y[35], y[36], y[37], y[38], y[39], y[40], y[41], y[42], y[43], y[44], y[45], y[46], y[47], y[48], y[49], y[50], y[51], y[52], y[53], y[54], y[55], y[56], y[57], y[58], y[59], y[60], y[61], y[62], y[63], y[64], y[65], y[66], y[67], y[68], y[69], y[70], y[71], y[72], y[73], y[74], y[75], y[76], y[77], y[78], y[79], y[80], y[81], y[82], y[83], y[84], y[85], y[86], y[87], y[88], y[89], y[90], y[91], y[92], y[93], y[94], y[95], y[96], y[97], y[98], y[99], y[100], y[101], y[102], y[103], y[104], y[105], y[106], y[107], y[108]
internals         = 

Use `describe(chains)` for summary statistics and quantiles.