Racing diffusion model with a covariate

import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
import bayesflow as bf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats

Two-choice task (racing diffusion model)

We’ll assume a simple racing diffusion model (RDM) with two choice alternatives (Tillman et al., 2020).

Here we will simplify the model to include no bias and no variability in starting points. Instead of modeling one accumulator for each of the [left, right] responses, we will simply model one accumulator for the “incorrect” response and one for the “correct” response. This makes the model a bit easier to simulate from, since we do not need to simulate stimuli.

Additionally, we assume that the decision threshold changes over the course of the experiment: it decreases linearly with trial number, according to \(\text{threshold} = \alpha + \beta \cdot \text{trial}\).

The model therefore has five parameters: two drift rates (incorrect \(\nu_0\), correct \(\nu_1\)), the decision threshold intercept \(\alpha\), the decision threshold slope \(\beta\), and the non-decision time \(\tau\).
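To get a feel for the collapsing threshold, here is a small numerical illustration (the round values of \(\alpha\) and \(\beta\) are only loosely based on the prior means defined below, not estimates):

# illustrative values only, roughly in line with the prior means used below
alpha_example, beta_example = 1.0, -0.005
trials = np.arange(200)
thresholds = alpha_example + beta_example * trials
print(thresholds[0], thresholds[100], thresholds[-1])  # 1.0, 0.5, 0.005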

# simulate a single noisy accumulator: drift nu plus Gaussian diffusion noise
def evidence_accumulation(nu, max_t, dt):
    timesteps = int(max_t / dt)
    t = np.linspace(0, max_t, timesteps)

    # Gaussian increments with variance dt; their cumulative sum is the diffusion path
    noise = np.random.normal(0, 1, size=timesteps) * np.sqrt(dt)
    evidence = nu * t + np.cumsum(noise)

    return t, evidence
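As a quick check, we can simulate and plot a single accumulator path (a small sketch; the drift value here is arbitrary):

# plot one simulated accumulator path with an arbitrary drift of 2.0
t_example, evidence_example = evidence_accumulation(nu=2.0, max_t=3.0, dt=0.02)
plt.plot(t_example, evidence_example)
plt.axhline(1.0, linestyle="--", color="gray")  # an example threshold
plt.xlabel("time (s)")
plt.ylabel("evidence")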

# trial counts: n is effectively fixed at 200 here, matching the number of trials
# per participant and condition in the real data
def context(n=None):
    if n is None:
        n = np.random.randint(200, 201)
    return dict(n=n)

def prior(nu=None, alpha=None, beta=None, tau=None):
    if nu is None:
        # split an overall drift magnitude (Gamma) across the two accumulators (Dirichlet)
        nu = np.random.dirichlet([2, 2])
        nu = np.random.gamma(shape=5, scale=0.5) * nu
    if alpha is None:
        alpha = np.random.gamma(shape=5, scale=0.2)
    if beta is None:
        # the slope is non-positive: the threshold can only collapse over trials
        beta = np.random.beta(a=1, b=1) * -0.01
    if tau is None:
        tau = np.random.exponential(0.15)
    else:
        # when tau is passed in (e.g., as a one-element array while re-simulating
        # with posterior draws), convert it to a scalar
        tau = tau.item()

    return dict(nu=nu, alpha=alpha, beta=beta, tau=tau)
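Drawing a few samples from the prior is an easy sanity check on the implied parameter ranges (just a sketch):

# draw a few prior samples to eyeball plausible parameter values
for _ in range(3):
    print(prior())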

# generate data for a single trial
def trial(nu, alpha, tau, max_t, dt):
    response = -1
    min_t = max_t
    # loop over accumulators;
    # if an accumulator has a smaller passage time than the current minimum,
    # save it as the fastest accumulator (response)
    for resp, drift in enumerate(nu):
        t, evidence = evidence_accumulation(drift, max_t, dt)
        # first index at which the evidence exceeds the threshold;
        # argmax returns 0 if the threshold is never crossed, which we treat as no response
        passage = np.argmax(evidence > alpha)
        t = max_t if passage == 0 else t[passage]

        if t < min_t:
            min_t = t
            response = resp

    return min_t + tau, response
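A single simulated trial returns the response time (passage time plus non-decision time) and the index of the winning accumulator. For example, with hand-picked values (chosen only for illustration):

# one simulated trial; with drifts [1.0, 3.0] the "correct" accumulator (index 1) should usually win
rt_example, response_example = trial(nu=[1.0, 3.0], alpha=1.0, tau=0.2, max_t=3.0, dt=0.02)
print(rt_example, response_example)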

# generate data for n trials
# keep the data.shape always to max_n
# the rest is filled with 0s
def likelihood(n, nu, alpha, beta, tau, max_t=3.0, dt=0.02, max_n=200):
    rt = np.zeros(max_n)
    response = np.zeros(max_n)
    observed = np.zeros(max_n)
    for i in range(n):
        threshold = alpha + beta * i  # the threshold is a function of the trial number
        # if the threshold drops below zero, the participant storms out of the lab because they ran out of patience
        if threshold < 0:
            break
        result = trial(nu, threshold, tau, max_t, dt)
        rt[i] = result[0]
        response[i] = result[1]
        observed[i] = 1

    return dict(rt=rt, response=response, trial=range(n), observed=observed)
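Simulating a full data set for one parameter draw shows the padding: all arrays have length max_n=200, and `observed` marks which entries correspond to real trials (a quick sketch):

# simulate 200 trials for a single prior draw and inspect shapes and accuracy
theta = prior()
sim = likelihood(n=200, **theta)
print({key: np.shape(value) for key, value in sim.items()})
print("proportion correct:", sim["response"][sim["observed"] == 1].mean())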

simulator = bf.make_simulator([context, prior, likelihood])
# for the constrained parameters (nu, alpha, tau) the standardize means/stds below
# refer to the unconstrained scale, since standardize runs after constrain(lower=0)
adapter = (bf.Adapter()
    .as_set(["rt", "response", "trial", "observed"])
    .constrain(["nu", "alpha", "tau"], lower=0)
    .standardize(include="nu",    mean=0.7,   std=1.2)
    .standardize(include="alpha", mean=0.5,   std=0.7)
    .standardize(include="beta",  mean=-0.05, std=0.02)
    .standardize(include="tau",   mean=-2.5,  std=1.3)
    .standardize(include="trial", mean=100,   std=60)
    .concatenate(["nu", "alpha", "beta", "tau"], into="inference_variables")
    .concatenate(["rt", "response", "trial", "observed"], into="summary_variables")
    .rename("n", "inference_conditions")
)
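As a quick shape check, we can pass a small simulated batch through the adapter (a sketch; this assumes the Adapter instance can be applied directly to a dict of simulated variables):

# pass a small simulated batch through the adapter and inspect the resulting shapes
# (assumes the adapter is directly callable on a dict of simulations)
check_batch = simulator.sample(8)
processed = adapter(check_batch)
print({key: value.shape for key, value in processed.items()})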
workflow = bf.BasicWorkflow(
    simulator = simulator,
    adapter = adapter,
    inference_network = bf.networks.CouplingFlow(
        depth=15,
        subnet_kwargs=dict(dropout=False)
    ),
    summary_network=bf.networks.DeepSet(
        depth=15,
        base_distribution="normal",
        dropout=False
    ),
    initial_learning_rate = 1e-3,
    inference_variables = ["nu", "alpha", "beta", "tau"],
    inference_conditions = ["n"],
    summary_variables = ["rt", "response","trial","observed"]
)
train_data = simulator.sample(2_500)
validation_data = simulator.sample(500)
test_data = simulator.sample(500)
history = workflow.fit_offline(
    data=train_data, 
    epochs=50, 
    batch_size=250, 
    validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 447s 10s/step - loss: 2481876.2500 - loss/inference_loss: 2481875.7500 - loss/summary_loss: 0.7316 - val_loss: 463.8929 - val_loss/inference_loss: 463.2145 - val_loss/summary_loss: 0.6784
Epoch 2/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 79s 8s/step - loss: 59.6651 - loss/inference_loss: 58.9418 - loss/summary_loss: 0.7233 - val_loss: 11.2292 - val_loss/inference_loss: 10.7928 - val_loss/summary_loss: 0.4364
Epoch 3/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 79s 8s/step - loss: 7.9351 - loss/inference_loss: 7.6685 - loss/summary_loss: 0.2666 - val_loss: 6.5161 - val_loss/inference_loss: 6.3241 - val_loss/summary_loss: 0.1920
Epoch 4/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 6.3099 - loss/inference_loss: 6.1192 - loss/summary_loss: 0.1907 - val_loss: 5.9470 - val_loss/inference_loss: 5.7425 - val_loss/summary_loss: 0.2046
Epoch 5/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 81s 8s/step - loss: 5.9675 - loss/inference_loss: 5.7735 - loss/summary_loss: 0.1940 - val_loss: 6.3504 - val_loss/inference_loss: 6.1765 - val_loss/summary_loss: 0.1739
Epoch 6/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 6.5348 - loss/inference_loss: 6.2088 - loss/summary_loss: 0.3259 - val_loss: 6.0203 - val_loss/inference_loss: 5.7255 - val_loss/summary_loss: 0.2947
Epoch 7/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 77s 8s/step - loss: 6.0858 - loss/inference_loss: 5.7369 - loss/summary_loss: 0.3489 - val_loss: 6.0578 - val_loss/inference_loss: 5.7619 - val_loss/summary_loss: 0.2960
Epoch 8/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 75s 8s/step - loss: 5.7880 - loss/inference_loss: 5.5090 - loss/summary_loss: 0.2790 - val_loss: 5.7461 - val_loss/inference_loss: 5.5063 - val_loss/summary_loss: 0.2398
Epoch 9/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 76s 8s/step - loss: 5.6276 - loss/inference_loss: 5.3413 - loss/summary_loss: 0.2863 - val_loss: 5.7012 - val_loss/inference_loss: 5.4693 - val_loss/summary_loss: 0.2319
Epoch 10/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 77s 8s/step - loss: 5.5524 - loss/inference_loss: 5.2789 - loss/summary_loss: 0.2735 - val_loss: 5.5155 - val_loss/inference_loss: 5.2490 - val_loss/summary_loss: 0.2665
Epoch 11/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 5.4444 - loss/inference_loss: 5.1554 - loss/summary_loss: 0.2891 - val_loss: 6.0276 - val_loss/inference_loss: 5.8132 - val_loss/summary_loss: 0.2143
Epoch 12/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 77s 8s/step - loss: 5.4726 - loss/inference_loss: 5.1981 - loss/summary_loss: 0.2746 - val_loss: 5.7865 - val_loss/inference_loss: 5.4634 - val_loss/summary_loss: 0.3231
Epoch 13/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 79s 8s/step - loss: 5.5793 - loss/inference_loss: 5.2628 - loss/summary_loss: 0.3165 - val_loss: 5.5894 - val_loss/inference_loss: 5.3168 - val_loss/summary_loss: 0.2726
Epoch 14/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 82s 8s/step - loss: 5.4236 - loss/inference_loss: 5.1230 - loss/summary_loss: 0.3005 - val_loss: 5.0572 - val_loss/inference_loss: 4.8010 - val_loss/summary_loss: 0.2563
Epoch 15/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 5.2587 - loss/inference_loss: 4.9706 - loss/summary_loss: 0.2881 - val_loss: 6.0406 - val_loss/inference_loss: 5.8010 - val_loss/summary_loss: 0.2396
Epoch 16/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 79s 8s/step - loss: 5.3925 - loss/inference_loss: 5.1212 - loss/summary_loss: 0.2713 - val_loss: 5.2179 - val_loss/inference_loss: 4.9951 - val_loss/summary_loss: 0.2227
Epoch 17/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 5.2271 - loss/inference_loss: 4.9291 - loss/summary_loss: 0.2980 - val_loss: 5.3982 - val_loss/inference_loss: 5.1697 - val_loss/summary_loss: 0.2285
Epoch 18/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 83s 8s/step - loss: 5.1509 - loss/inference_loss: 4.8469 - loss/summary_loss: 0.3040 - val_loss: 4.9293 - val_loss/inference_loss: 4.6883 - val_loss/summary_loss: 0.2410
Epoch 19/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 4.9497 - loss/inference_loss: 4.6393 - loss/summary_loss: 0.3105 - val_loss: 4.9728 - val_loss/inference_loss: 4.7078 - val_loss/summary_loss: 0.2650
Epoch 20/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 76s 8s/step - loss: 5.2719 - loss/inference_loss: 4.9577 - loss/summary_loss: 0.3142 - val_loss: 5.7347 - val_loss/inference_loss: 5.4575 - val_loss/summary_loss: 0.2772
Epoch 21/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 78s 8s/step - loss: 4.9919 - loss/inference_loss: 4.6927 - loss/summary_loss: 0.2992 - val_loss: 5.3573 - val_loss/inference_loss: 5.1180 - val_loss/summary_loss: 0.2393
Epoch 22/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 4.8316 - loss/inference_loss: 4.5589 - loss/summary_loss: 0.2728 - val_loss: 4.4284 - val_loss/inference_loss: 4.2041 - val_loss/summary_loss: 0.2243
Epoch 23/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 82s 8s/step - loss: 4.8698 - loss/inference_loss: 4.5886 - loss/summary_loss: 0.2812 - val_loss: 4.3555 - val_loss/inference_loss: 4.1169 - val_loss/summary_loss: 0.2386
Epoch 24/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 76s 8s/step - loss: 4.6208 - loss/inference_loss: 4.3398 - loss/summary_loss: 0.2809 - val_loss: 4.8037 - val_loss/inference_loss: 4.4913 - val_loss/summary_loss: 0.3123
Epoch 25/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 79s 8s/step - loss: 4.4561 - loss/inference_loss: 4.1520 - loss/summary_loss: 0.3041 - val_loss: 4.3030 - val_loss/inference_loss: 4.0572 - val_loss/summary_loss: 0.2458
Epoch 26/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 4.2359 - loss/inference_loss: 3.9303 - loss/summary_loss: 0.3056 - val_loss: 4.0600 - val_loss/inference_loss: 3.8294 - val_loss/summary_loss: 0.2306
Epoch 27/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 81s 8s/step - loss: 4.3664 - loss/inference_loss: 4.0516 - loss/summary_loss: 0.3148 - val_loss: 4.3168 - val_loss/inference_loss: 4.0267 - val_loss/summary_loss: 0.2901
Epoch 28/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 77s 8s/step - loss: 4.6645 - loss/inference_loss: 4.3383 - loss/summary_loss: 0.3262 - val_loss: 5.2905 - val_loss/inference_loss: 4.9483 - val_loss/summary_loss: 0.3422
Epoch 29/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 78s 8s/step - loss: 4.7504 - loss/inference_loss: 4.3712 - loss/summary_loss: 0.3792 - val_loss: 4.3816 - val_loss/inference_loss: 4.0359 - val_loss/summary_loss: 0.3456
Epoch 30/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 79s 8s/step - loss: 4.0697 - loss/inference_loss: 3.6948 - loss/summary_loss: 0.3749 - val_loss: 4.3868 - val_loss/inference_loss: 4.0687 - val_loss/summary_loss: 0.3181
Epoch 31/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 3.8667 - loss/inference_loss: 3.4951 - loss/summary_loss: 0.3716 - val_loss: 3.5801 - val_loss/inference_loss: 3.2828 - val_loss/summary_loss: 0.2973
Epoch 32/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 3.8138 - loss/inference_loss: 3.4788 - loss/summary_loss: 0.3350 - val_loss: 3.9115 - val_loss/inference_loss: 3.6640 - val_loss/summary_loss: 0.2475
Epoch 33/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 78s 8s/step - loss: 4.2117 - loss/inference_loss: 3.9011 - loss/summary_loss: 0.3106 - val_loss: 3.7474 - val_loss/inference_loss: 3.5055 - val_loss/summary_loss: 0.2419
Epoch 34/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 3.6181 - loss/inference_loss: 3.3031 - loss/summary_loss: 0.3150 - val_loss: 3.6324 - val_loss/inference_loss: 3.3973 - val_loss/summary_loss: 0.2351
Epoch 35/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 77s 8s/step - loss: 3.3222 - loss/inference_loss: 2.9961 - loss/summary_loss: 0.3261 - val_loss: 3.5649 - val_loss/inference_loss: 3.2783 - val_loss/summary_loss: 0.2866
Epoch 36/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 3.1659 - loss/inference_loss: 2.8447 - loss/summary_loss: 0.3212 - val_loss: 3.3711 - val_loss/inference_loss: 3.0596 - val_loss/summary_loss: 0.3115
Epoch 37/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 78s 8s/step - loss: 3.2770 - loss/inference_loss: 2.9556 - loss/summary_loss: 0.3214 - val_loss: 3.1489 - val_loss/inference_loss: 2.8473 - val_loss/summary_loss: 0.3016
Epoch 38/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 81s 8s/step - loss: 3.0455 - loss/inference_loss: 2.7263 - loss/summary_loss: 0.3192 - val_loss: 3.1324 - val_loss/inference_loss: 2.8601 - val_loss/summary_loss: 0.2723
Epoch 39/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 2.9168 - loss/inference_loss: 2.5954 - loss/summary_loss: 0.3214 - val_loss: 3.0987 - val_loss/inference_loss: 2.8284 - val_loss/summary_loss: 0.2703
Epoch 40/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 79s 8s/step - loss: 2.7912 - loss/inference_loss: 2.4690 - loss/summary_loss: 0.3222 - val_loss: 3.0236 - val_loss/inference_loss: 2.7303 - val_loss/summary_loss: 0.2933
Epoch 41/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 79s 8s/step - loss: 2.6836 - loss/inference_loss: 2.3553 - loss/summary_loss: 0.3282 - val_loss: 2.8232 - val_loss/inference_loss: 2.5458 - val_loss/summary_loss: 0.2774
Epoch 42/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 82s 8s/step - loss: 2.5994 - loss/inference_loss: 2.2767 - loss/summary_loss: 0.3227 - val_loss: 3.1882 - val_loss/inference_loss: 2.9103 - val_loss/summary_loss: 0.2779
Epoch 43/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 83s 8s/step - loss: 2.5226 - loss/inference_loss: 2.1888 - loss/summary_loss: 0.3339 - val_loss: 2.5523 - val_loss/inference_loss: 2.2820 - val_loss/summary_loss: 0.2703
Epoch 44/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 81s 8s/step - loss: 2.4778 - loss/inference_loss: 2.1490 - loss/summary_loss: 0.3288 - val_loss: 2.8602 - val_loss/inference_loss: 2.5891 - val_loss/summary_loss: 0.2712
Epoch 45/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 2.4191 - loss/inference_loss: 2.0947 - loss/summary_loss: 0.3243 - val_loss: 2.9102 - val_loss/inference_loss: 2.6145 - val_loss/summary_loss: 0.2957
Epoch 46/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 80s 8s/step - loss: 2.3937 - loss/inference_loss: 2.0576 - loss/summary_loss: 0.3361 - val_loss: 2.7691 - val_loss/inference_loss: 2.5065 - val_loss/summary_loss: 0.2626
Epoch 47/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 76s 8s/step - loss: 2.3785 - loss/inference_loss: 2.0467 - loss/summary_loss: 0.3318 - val_loss: 2.7930 - val_loss/inference_loss: 2.4943 - val_loss/summary_loss: 0.2987
Epoch 48/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 95s 10s/step - loss: 2.3532 - loss/inference_loss: 2.0243 - loss/summary_loss: 0.3288 - val_loss: 2.6434 - val_loss/inference_loss: 2.3573 - val_loss/summary_loss: 0.2861
Epoch 49/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 4652s 511s/step - loss: 2.3410 - loss/inference_loss: 2.0102 - loss/summary_loss: 0.3307 - val_loss: 2.5182 - val_loss/inference_loss: 2.2583 - val_loss/summary_loss: 0.2600
Epoch 50/50
10/10 ━━━━━━━━━━━━━━━━━━━━ 81s 8s/step - loss: 2.3267 - loss/inference_loss: 1.9953 - loss/summary_loss: 0.3313 - val_loss: 2.7635 - val_loss/inference_loss: 2.5100 - val_loss/summary_loss: 0.2535
plots = workflow.plot_default_diagnostics(test_data=test_data)

Application to data

Here we will work with data from Katsimpokis et al. (2021), Experiment 2. In this experiment, participants are asked to indicate which of two flashing circles has the higher flash rate. Additionally, they are asked either to be Fast or Accurate, and to respond either before a deadline or without one, in a full factorial design. The file flash.csv contains only the data from the Accuracy instruction.

data_inference = pd.read_csv("flash.csv")
data_inference['condition'].unique()
array(['Deadline', 'No Deadline'], dtype=object)
data_inference_grouped = data_inference.groupby(["subject", "condition"])
data_inference_dict = {
    key: np.array([group[key].values.reshape(200, 1) for _, group in data_inference_grouped])
    for key in ['rt', 'response', 'trial', 'observed']
}
data_inference_dict["n"] = np.sum(data_inference_dict["observed"], axis=1)
print({key: value.shape for key, value in data_inference_dict.items()})
{'rt': (48, 200, 1), 'response': (48, 200, 1), 'trial': (48, 200, 1), 'observed': (48, 200, 1), 'n': (48, 1)}
posterior_samples = workflow.sample(conditions=data_inference_dict, num_samples=1000)
# pick the first participant, first condition
posterior = {key: value[0] for key, value in posterior_samples.items()}
posterior['nu'].shape
(1000, 2)
# prior draws (taken from the simulated test set) to overlay in the pairs plot
priors = dict(nu=test_data["nu"], alpha=test_data["alpha"], beta=test_data["beta"], tau=test_data["tau"])
f = bf.diagnostics.pairs_posterior(estimates=posterior, priors=priors)

The model is still quite uncertain about the parameter estimates, as shown by the large overlap between the prior and the posterior. There is a bit of a trade-off between the \(\alpha\) and \(\beta\) parameters (though not as large as I expected), because they jointly explain the same aspect of the data (response caution). A higher \(\alpha\) means that the participant is more conservative on the first trial, so to reach the same average level of caution, \(\beta\) needs to be more negative, hence the negative relation.
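To put a number on this trade-off, we can compute the posterior correlation between \(\alpha\) and \(\beta\) for this participant (a quick sketch using the posterior draws from above):

# posterior correlation between the threshold intercept and slope
r = np.corrcoef(posterior["alpha"].squeeze(), posterior["beta"].squeeze())[0, 1]
print(f"posterior correlation between alpha and beta: {r:.2f}")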

Let’s look at another participant:

# pick second participant, second condition
posterior = {key: value[3] for key, value in posterior_samples.items()}
f = bf.diagnostics.pairs_posterior(estimates=posterior, priors=priors)

# defective ECDFs: each response's ECDF is weighted by the proportion of that response,
# evaluated on a grid of RTs from 0 to 1 second
def ecdf(rt, response, observed, **kwargs):
    observed_mask = (observed == 1)
    response_0_mask = ((response == 0) & observed_mask)
    response_1_mask = ((response == 1) & observed_mask)

    response_0_prop = np.sum(response_0_mask) / np.sum(observed_mask)
    response_1_prop = np.sum(response_1_mask) / np.sum(observed_mask)

    response_0_ecdf = stats.ecdf(rt[response_0_mask]).cdf
    response_0_ecdf = response_0_prop * response_0_ecdf.evaluate(np.linspace(0, 1, 101))

    response_1_ecdf = stats.ecdf(rt[response_1_mask]).cdf
    response_1_ecdf = response_1_prop * response_1_ecdf.evaluate(np.linspace(0, 1, 101))

    return response_0_ecdf, response_1_ecdf

Select a participant for whom to display the goodness-of-fit:

data = data_inference_grouped.get_group((1, 'No Deadline'))
# use the posterior of the same (subject, condition) group for the predictive check
group_index = list(data_inference_grouped.groups.keys()).index((1, 'No Deadline'))
posterior = {key: value[group_index] for key, value in posterior_samples.items()}
plot_data = ecdf(**data)
posterior_predictives = simulator.sample(500, **posterior)
plot_data_predictive = []
for i in range(100):
    x = { key: value[i:i+1,...] for key, value in posterior_predictives.items()}
    plot_data_predictive.append(ecdf(**x))
    
plot_data_predictive = np.array(plot_data_predictive)
plot_data_quantiles = np.quantile(
    plot_data_predictive, 
    q = [0.25, 0.5, 0.75],
    axis=0
)
plot_data_quantiles.shape
(3, 2, 101)
t = np.linspace(0, 1, 101)
cols = ["red", "blue"]
for i, lab in enumerate(["Incorrect", "Correct"]):
    plt.plot(t, plot_data[i], label=lab, color=cols[i])
    # label the predictive summaries only once to avoid duplicate legend entries
    plt.plot(t, plot_data_quantiles[1, i, :], color=cols[i], alpha=0.5,
             label="median predictive" if i == 0 else None)
    plt.fill_between(
        t,
        plot_data_quantiles[0,  i, :],
        plot_data_quantiles[-1, i, :],
        label="50% predictive interval" if i == 0 else None,
        color=cols[i],
        alpha=0.3
    )
plt.xlabel("RT (s)")
plt.ylabel("defective cumulative probability")
f = plt.legend()

The model fits the data of this participant reasonably well. If you change the participant id or condition, you will see that this is not the case for everyone.
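As a rough sketch of how to repeat the check for another group (the subject id 2 used here is an assumption about the coding in flash.csv; groups are sorted by subject and condition):

# repeat the posterior predictive check for another (subject, condition) group
group_key = (2, 'Deadline')  # assumed to exist in flash.csv
group_index = list(data_inference_grouped.groups.keys()).index(group_key)
data = data_inference_grouped.get_group(group_key)
posterior = {key: value[group_index] for key, value in posterior_samples.items()}
posterior_predictives = simulator.sample(500, **posterior)
# ...then rebuild plot_data, plot_data_predictive, and the ECDF plot as above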