import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import bayesflow as bf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
Diffusion models in BayesFlow
Wald response times, Racing diffusion model
def evidence_accumulation(nu, max_t, dt):
    # discretize time and accumulate Gaussian increments around the drift
    timesteps = int(max_t / dt)
    t = np.linspace(0, max_t, timesteps)
    noise = np.random.normal(0, 1, size=timesteps) * np.sqrt(dt)
    evidence = nu * t + np.cumsum(noise)
    return t, evidence
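Under the hood, this is an Euler-Maruyama discretization of a Wiener process with drift: over each step of length \(\Delta t\), the accumulated evidence changes by

\[ X_{t+\Delta t} = X_t + \nu \, \Delta t + \sqrt{\Delta t} \, \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, 1), \]

which the code implements in vectorized form via np.cumsum.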
t, evidence = evidence_accumulation(nu=2.5, max_t=1.0, dt=0.01)
plt.plot(t, evidence)
plt.xlabel("Time (s)")
plt.ylabel("Evidence")
Simple response time (Wald model)
Here we will train a simple response time model based on a single accumulator that diffuses with drift \(\nu\) towards a decision threshold \(\alpha\). For more background on the Wald model, see Anders, Alario, & van Maanen (2016). We will estimate only the drift and the decision threshold, with no non-decision time and no variability in the starting point. In this case, the mean and standard deviation of the response times are sufficient statistics, so we do not need a summary network.
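For reference, the first-passage time of a unit-variance diffusion with drift \(\nu\) through threshold \(\alpha\) follows the Wald (inverse Gaussian) distribution with density

\[ f(t \mid \nu, \alpha) = \frac{\alpha}{\sqrt{2\pi t^3}} \exp\!\left(-\frac{(\alpha - \nu t)^2}{2t}\right), \qquad \mathbb{E}[T] = \frac{\alpha}{\nu}. \]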
def prior():
    # drift rate
    nu = np.random.gamma(shape=10, scale=0.25)
    # decision threshold
    alpha = np.random.gamma(shape=10, scale=0.1)
    return dict(nu=nu, alpha=alpha)
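As a quick orientation (an added check, not part of the original tutorial), the implied prior means follow from shape × scale:

# Gamma(shape=10, scale=0.25) has mean 2.5 (drift),
# Gamma(shape=10, scale=0.1) has mean 1.0 (threshold)
draws = np.array([[d["nu"], d["alpha"]] for d in (prior() for _ in range(2000))])
print(draws.mean(axis=0))  # approximately [2.5, 1.0]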
# generate data for a single trial
def trial(nu, alpha, max_t, dt):
    t, evidence = evidence_accumulation(nu, max_t, dt)
    # first index where evidence exceeds the threshold;
    # np.argmax returns 0 if the threshold is never crossed
    passage = np.argmax(evidence > alpha)
    rt = max_t if passage == 0 else t[passage]
    return rt
# generate data for n trials
def likelihood(nu, alpha, n=250, max_t=3.0, dt=0.02):
    rt = np.zeros(n)
    for i in range(n):
        rt[i] = trial(nu, alpha, max_t, dt)
    return dict(rt=rt)
# sufficient statistics: mean, sd
def summary(rt):
    return dict(
        mean=np.mean(rt),
        sd=np.std(rt)
    )
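Since the mean first-passage time is \(\alpha/\nu\) (see the Wald density above), a quick simulation check, added here for illustration, should roughly recover this ratio; the discretization dt and the truncation at max_t introduce small deviations:

# compare the analytic mean first-passage time (alpha / nu)
# with the simulated mean RT; expect rough agreement
theta = prior()
sim = likelihood(theta["nu"], theta["alpha"], n=1000)
print(theta["alpha"] / theta["nu"], np.mean(sim["rt"]))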
simulator = bf.make_simulator([prior, likelihood, summary])
df = simulator.sample(1_000)
f = bf.diagnostics.pairs_samples(
    df,
    variable_keys=["nu", "alpha", "mean", "sd"],
    variable_names=[r"$\nu$", r"$\alpha$", "mean RT", "sd RT"]
)
adapter = (bf.Adapter()
    .constrain(["nu", "alpha"], lower=0)
    .concatenate(["nu", "alpha"], into="inference_variables")
    .concatenate(["mean", "sd"], into="inference_conditions")
    .drop("rt")
)
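To see what the adapter actually produces, one can apply it to a small batch. This is a quick sketch; it assumes, as in recent BayesFlow releases, that an Adapter instance can be called directly on a dict of samples:

# apply the adapter to a few simulations; after constrain / concatenate / drop,
# only "inference_variables" and "inference_conditions" should remain
processed = adapter(simulator.sample(4))
print({key: value.shape for key, value in processed.items()})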
workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=bf.networks.CouplingFlow(permutation="swap", subnet_kwargs=dict(dropout=False)),
    inference_variables=["nu", "alpha"],
    inference_conditions=["mean", "sd"]
)
train_data = simulator.sample(5_000)
validation_data = simulator.sample(1_000)
history = workflow.fit_offline(
    data=train_data,
    epochs=50,
    batch_size=250,
    validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 6s 31ms/step - loss: 4.3237 - loss/inference_loss: 4.3237 - val_loss: 3.2620 - val_loss/inference_loss: 3.2620
Epoch 2/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 2.9751 - loss/inference_loss: 2.9751 - val_loss: 2.6640 - val_loss/inference_loss: 2.6640
Epoch 3/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 2.5618 - loss/inference_loss: 2.5618 - val_loss: 2.4644 - val_loss/inference_loss: 2.4644
Epoch 4/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - loss: 2.3460 - loss/inference_loss: 2.3460 - val_loss: 2.2887 - val_loss/inference_loss: 2.2887
Epoch 5/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - loss: 2.1960 - loss/inference_loss: 2.1960 - val_loss: 2.0773 - val_loss/inference_loss: 2.0773
Epoch 6/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 2.0349 - loss/inference_loss: 2.0349 - val_loss: 1.9428 - val_loss/inference_loss: 1.9428
Epoch 7/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 1.8103 - loss/inference_loss: 1.8103 - val_loss: 1.6752 - val_loss/inference_loss: 1.6752
Epoch 8/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 1.4347 - loss/inference_loss: 1.4347 - val_loss: 1.0306 - val_loss/inference_loss: 1.0306
Epoch 9/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.6755 - loss/inference_loss: 0.6755 - val_loss: 0.1332 - val_loss/inference_loss: 0.1332
Epoch 10/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: -0.0722 - loss/inference_loss: -0.0722 - val_loss: -0.5113 - val_loss/inference_loss: -0.5113
Epoch 11/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: -0.6561 - loss/inference_loss: -0.6561 - val_loss: -0.9005 - val_loss/inference_loss: -0.9005
Epoch 12/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: -0.9802 - loss/inference_loss: -0.9802 - val_loss: -1.0020 - val_loss/inference_loss: -1.0020
Epoch 13/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: -1.0407 - loss/inference_loss: -1.0407 - val_loss: -1.1719 - val_loss/inference_loss: -1.1719
Epoch 14/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: -1.0787 - loss/inference_loss: -1.0787 - val_loss: -1.0381 - val_loss/inference_loss: -1.0381
Epoch 15/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: -1.0662 - loss/inference_loss: -1.0662 - val_loss: -1.1802 - val_loss/inference_loss: -1.1802
Epoch 16/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: -1.1526 - loss/inference_loss: -1.1526 - val_loss: -1.1777 - val_loss/inference_loss: -1.1777
Epoch 17/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: -1.1973 - loss/inference_loss: -1.1973 - val_loss: -1.3017 - val_loss/inference_loss: -1.3017
Epoch 18/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: -1.2158 - loss/inference_loss: -1.2158 - val_loss: -1.1972 - val_loss/inference_loss: -1.1972
Epoch 19/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.2340 - loss/inference_loss: -1.2340 - val_loss: -1.1711 - val_loss/inference_loss: -1.1711
Epoch 20/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.2433 - loss/inference_loss: -1.2433 - val_loss: -1.3729 - val_loss/inference_loss: -1.3729
Epoch 21/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.2589 - loss/inference_loss: -1.2589 - val_loss: -1.2672 - val_loss/inference_loss: -1.2672
Epoch 22/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.2627 - loss/inference_loss: -1.2627 - val_loss: -1.1731 - val_loss/inference_loss: -1.1731
Epoch 23/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.2662 - loss/inference_loss: -1.2662 - val_loss: -1.2593 - val_loss/inference_loss: -1.2593
Epoch 24/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: -1.2929 - loss/inference_loss: -1.2929 - val_loss: -1.3146 - val_loss/inference_loss: -1.3146
Epoch 25/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.2873 - loss/inference_loss: -1.2873 - val_loss: -1.1808 - val_loss/inference_loss: -1.1808
Epoch 26/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: -1.2965 - loss/inference_loss: -1.2965 - val_loss: -1.2206 - val_loss/inference_loss: -1.2206
Epoch 27/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: -1.2928 - loss/inference_loss: -1.2928 - val_loss: -1.2815 - val_loss/inference_loss: -1.2815
Epoch 28/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.3085 - loss/inference_loss: -1.3085 - val_loss: -1.2836 - val_loss/inference_loss: -1.2836
Epoch 29/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.3052 - loss/inference_loss: -1.3052 - val_loss: -1.3027 - val_loss/inference_loss: -1.3027
Epoch 30/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.3032 - loss/inference_loss: -1.3032 - val_loss: -1.3187 - val_loss/inference_loss: -1.3187
Epoch 31/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3289 - loss/inference_loss: -1.3289 - val_loss: -1.3867 - val_loss/inference_loss: -1.3867
Epoch 32/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.3229 - loss/inference_loss: -1.3229 - val_loss: -1.3211 - val_loss/inference_loss: -1.3211
Epoch 33/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3265 - loss/inference_loss: -1.3265 - val_loss: -1.3229 - val_loss/inference_loss: -1.3229
Epoch 34/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3420 - loss/inference_loss: -1.3420 - val_loss: -1.4363 - val_loss/inference_loss: -1.4363
Epoch 35/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3360 - loss/inference_loss: -1.3360 - val_loss: -1.3724 - val_loss/inference_loss: -1.3724
Epoch 36/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: -1.3250 - loss/inference_loss: -1.3250 - val_loss: -1.4013 - val_loss/inference_loss: -1.4013
Epoch 37/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3528 - loss/inference_loss: -1.3528 - val_loss: -1.4209 - val_loss/inference_loss: -1.4209
Epoch 38/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3484 - loss/inference_loss: -1.3484 - val_loss: -1.3968 - val_loss/inference_loss: -1.3968
Epoch 39/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3481 - loss/inference_loss: -1.3481 - val_loss: -1.3297 - val_loss/inference_loss: -1.3297
Epoch 40/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3504 - loss/inference_loss: -1.3504 - val_loss: -1.4910 - val_loss/inference_loss: -1.4910
Epoch 41/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3585 - loss/inference_loss: -1.3585 - val_loss: -1.2492 - val_loss/inference_loss: -1.2492
Epoch 42/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: -1.3609 - loss/inference_loss: -1.3609 - val_loss: -1.3917 - val_loss/inference_loss: -1.3917
Epoch 43/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3535 - loss/inference_loss: -1.3535 - val_loss: -1.3360 - val_loss/inference_loss: -1.3360
Epoch 44/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3583 - loss/inference_loss: -1.3583 - val_loss: -1.3302 - val_loss/inference_loss: -1.3302
Epoch 45/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3607 - loss/inference_loss: -1.3607 - val_loss: -1.3974 - val_loss/inference_loss: -1.3974
Epoch 46/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3665 - loss/inference_loss: -1.3665 - val_loss: -1.3183 - val_loss/inference_loss: -1.3183
Epoch 47/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3582 - loss/inference_loss: -1.3582 - val_loss: -1.2993 - val_loss/inference_loss: -1.2993
Epoch 48/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: -1.3658 - loss/inference_loss: -1.3658 - val_loss: -1.2053 - val_loss/inference_loss: -1.2053
Epoch 49/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: -1.3575 - loss/inference_loss: -1.3575 - val_loss: -1.3078 - val_loss/inference_loss: -1.3078
Epoch 50/50
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: -1.3586 - loss/inference_loss: -1.3586 - val_loss: -1.3565 - val_loss/inference_loss: -1.3565
test_data = simulator.sample(1_000)
plots = workflow.plot_default_diagnostics(test_data=test_data)
Two choice task (Racing diffusion model)
We’ll assume a simple RDM with two choice alternatives (Tillman et al., 2020).
Here we simplify the model to include no bias and no variability in starting points. Instead of modeling one accumulator for each of the [left, right] responses, we simply model one accumulator for the “incorrect” and one for the “correct” response. This makes the model a bit easier to simulate from (we do not need to simulate stimuli).
The model has four parameters: two drift rates (\(\nu_0\) for the incorrect and \(\nu_1\) for the correct accumulator), the decision threshold \(\alpha\), and the non-decision time \(\tau\).
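Formally, each accumulator \(i \in \{0, 1\}\) runs an independent diffusion towards the shared threshold, and the observed response is whichever accumulator arrives first:

\[ dX_i(t) = \nu_i \, dt + dW_i(t), \qquad T_i = \inf\{t : X_i(t) \geq \alpha\}, \]
\[ \mathrm{RT} = \min_i T_i + \tau, \qquad \mathrm{response} = \arg\min_i T_i. \]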
def context(n=None):
    if n is None:
        n = np.random.randint(200, 351)
    return dict(n=n)
def prior(nu=None, alpha=None, tau=None):
    if nu is None:
        # split a gamma-distributed total drift between the two
        # accumulators via a Dirichlet weight vector
        nu = np.random.dirichlet([2, 2])
        nu = np.random.gamma(shape=5, scale=0.5) * nu
    if alpha is None:
        alpha = np.random.gamma(shape=5, scale=0.2)
    if tau is None:
        tau = np.random.exponential(0.15)
    else:
        tau = tau.item()
    return dict(nu=nu, alpha=alpha, tau=tau)
# generate data for a single trial
def trial(nu, alpha, tau, max_t, dt):
    response = -1
    min_t = max_t
    # loop over accumulators;
    # if an accumulator has a smaller passage time than the current minimum,
    # save it as the fastest accumulator (response)
    for resp, drift in enumerate(nu):
        t, evidence = evidence_accumulation(drift, max_t, dt)
        passage = np.argmax(evidence > alpha)
        t = max_t if passage == 0 else t[passage]
        if t < min_t:
            min_t = t
            response = resp
    return min_t + tau, response
# generate data for n trials
# keep the data shape always at max_n; the rest is filled with 0s
def likelihood(n, nu, alpha, tau, max_t=3.0, dt=0.02, max_n=350):
    rt = np.zeros(max_n)
    response = np.zeros(max_n)
    observed = np.zeros(max_n)
    for i in range(n):
        result = trial(nu, alpha, tau, max_t, dt)
        rt[i] = result[0]
        response[i] = result[1]
        observed[i] = 1
    return dict(rt=rt, response=response, observed=observed)
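As a quick sanity check of the padding scheme (an addition for illustration), one can simulate a single dataset and verify that all outputs have length max_n while observed marks the real trials:

# simulate one padded dataset; rt/response/observed all have length max_n,
# and the number of observed trials equals the sampled n
sim = likelihood(**context(), **prior())
print(sim["rt"].shape, sim["response"].shape, int(sim["observed"].sum()))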
simulator = bf.make_simulator([context, prior, likelihood])
adapter = (bf.Adapter()
    .as_set(["rt", "response", "observed"])
    .constrain(["nu", "alpha", "tau"], lower=0)
    # standardization acts after .constrain, i.e., on the unconstrained scale
    .standardize(include="nu", mean=0.7, std=1.2)
    .standardize(include="alpha", mean=0.5, std=0.7)
    .standardize(include="tau", mean=-2.5, std=1.3)
    .concatenate(["nu", "alpha", "tau"], into="inference_variables")
    .concatenate(["rt", "response", "observed"], into="summary_variables")
    .rename("n", "inference_conditions")
)
workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=bf.networks.CouplingFlow(
        permutation="swap",
        subnet_kwargs=dict(dropout=False)
    ),
    summary_network=bf.networks.DeepSet(
        base_distribution="normal",
        dropout=False
    ),
    inference_variables=["nu", "alpha", "tau"],
    inference_conditions=["n"],
    summary_variables=["rt", "response", "observed"]
)
train_data = simulator.sample(5_000)
validation_data = simulator.sample(1_000)
history = workflow.fit_offline(
    data=train_data,
    epochs=100,
    batch_size=250,
    validation_data=validation_data
)
INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 34s 1s/step - loss: 14.3609 - loss/inference_loss: 14.0609 - loss/summary_loss: 0.3000 - val_loss: 6.1114 - val_loss/inference_loss: 5.9787 - val_loss/summary_loss: 0.1327
Epoch 2/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 633ms/step - loss: 5.8206 - loss/inference_loss: 5.6786 - loss/summary_loss: 0.1419 - val_loss: 5.5208 - val_loss/inference_loss: 5.3216 - val_loss/summary_loss: 0.1992
Epoch 3/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 645ms/step - loss: 5.2590 - loss/inference_loss: 5.0502 - loss/summary_loss: 0.2088 - val_loss: 5.1275 - val_loss/inference_loss: 4.9059 - val_loss/summary_loss: 0.2216
Epoch 4/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 649ms/step - loss: 4.7636 - loss/inference_loss: 4.5833 - loss/summary_loss: 0.1803 - val_loss: 4.5547 - val_loss/inference_loss: 4.4057 - val_loss/summary_loss: 0.1490
Epoch 5/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 607ms/step - loss: 4.4397 - loss/inference_loss: 4.2925 - loss/summary_loss: 0.1471 - val_loss: 4.3656 - val_loss/inference_loss: 4.2313 - val_loss/summary_loss: 0.1344
Epoch 6/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 604ms/step - loss: 4.3418 - loss/inference_loss: 4.1914 - loss/summary_loss: 0.1504 - val_loss: 4.1845 - val_loss/inference_loss: 4.0441 - val_loss/summary_loss: 0.1405
Epoch 7/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 605ms/step - loss: 4.3237 - loss/inference_loss: 4.1712 - loss/summary_loss: 0.1525 - val_loss: 4.3240 - val_loss/inference_loss: 4.1473 - val_loss/summary_loss: 0.1767
Epoch 8/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 613ms/step - loss: 4.1682 - loss/inference_loss: 4.0072 - loss/summary_loss: 0.1610 - val_loss: 3.9488 - val_loss/inference_loss: 3.8090 - val_loss/summary_loss: 0.1398
Epoch 9/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 607ms/step - loss: 3.9366 - loss/inference_loss: 3.7888 - loss/summary_loss: 0.1479 - val_loss: 4.2893 - val_loss/inference_loss: 4.1522 - val_loss/summary_loss: 0.1371
Epoch 10/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 607ms/step - loss: 3.8633 - loss/inference_loss: 3.7098 - loss/summary_loss: 0.1535 - val_loss: 4.0335 - val_loss/inference_loss: 3.9091 - val_loss/summary_loss: 0.1245
Epoch 11/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 607ms/step - loss: 3.7262 - loss/inference_loss: 3.5701 - loss/summary_loss: 0.1562 - val_loss: 3.7161 - val_loss/inference_loss: 3.5764 - val_loss/summary_loss: 0.1397
Epoch 12/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 609ms/step - loss: 3.6239 - loss/inference_loss: 3.4630 - loss/summary_loss: 0.1610 - val_loss: 3.5772 - val_loss/inference_loss: 3.4350 - val_loss/summary_loss: 0.1422
Epoch 13/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 622ms/step - loss: 3.4660 - loss/inference_loss: 3.2979 - loss/summary_loss: 0.1681 - val_loss: 3.2676 - val_loss/inference_loss: 3.1305 - val_loss/summary_loss: 0.1371
Epoch 14/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 622ms/step - loss: 3.4642 - loss/inference_loss: 3.3057 - loss/summary_loss: 0.1586 - val_loss: 3.6891 - val_loss/inference_loss: 3.5483 - val_loss/summary_loss: 0.1408
Epoch 15/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 612ms/step - loss: 3.3781 - loss/inference_loss: 3.2192 - loss/summary_loss: 0.1588 - val_loss: 3.4303 - val_loss/inference_loss: 3.2832 - val_loss/summary_loss: 0.1471
Epoch 16/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 607ms/step - loss: 3.1897 - loss/inference_loss: 3.0271 - loss/summary_loss: 0.1626 - val_loss: 3.6857 - val_loss/inference_loss: 3.5336 - val_loss/summary_loss: 0.1521
Epoch 17/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 611ms/step - loss: 3.1959 - loss/inference_loss: 3.0351 - loss/summary_loss: 0.1608 - val_loss: 3.4023 - val_loss/inference_loss: 3.2598 - val_loss/summary_loss: 0.1426
Epoch 18/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 624ms/step - loss: 3.0804 - loss/inference_loss: 2.9221 - loss/summary_loss: 0.1583 - val_loss: 2.8908 - val_loss/inference_loss: 2.7297 - val_loss/summary_loss: 0.1611
Epoch 19/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 618ms/step - loss: 3.0177 - loss/inference_loss: 2.8528 - loss/summary_loss: 0.1649 - val_loss: 3.3557 - val_loss/inference_loss: 3.1970 - val_loss/summary_loss: 0.1587
Epoch 20/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 620ms/step - loss: 2.9631 - loss/inference_loss: 2.7981 - loss/summary_loss: 0.1650 - val_loss: 3.7985 - val_loss/inference_loss: 3.6307 - val_loss/summary_loss: 0.1678
Epoch 21/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 606ms/step - loss: 2.9434 - loss/inference_loss: 2.7783 - loss/summary_loss: 0.1652 - val_loss: 2.8577 - val_loss/inference_loss: 2.7084 - val_loss/summary_loss: 0.1493
Epoch 22/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 12s 618ms/step - loss: 2.9067 - loss/inference_loss: 2.7405 - loss/summary_loss: 0.1662 - val_loss: 3.5980 - val_loss/inference_loss: 3.4340 - val_loss/summary_loss: 0.1640
Epoch 23/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 664ms/step - loss: 2.9271 - loss/inference_loss: 2.7585 - loss/summary_loss: 0.1686 - val_loss: 3.5134 - val_loss/inference_loss: 3.3457 - val_loss/summary_loss: 0.1677
Epoch 24/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 666ms/step - loss: 2.8916 - loss/inference_loss: 2.7192 - loss/summary_loss: 0.1724 - val_loss: 2.5384 - val_loss/inference_loss: 2.3735 - val_loss/summary_loss: 0.1649
Epoch 25/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 655ms/step - loss: 2.6834 - loss/inference_loss: 2.5162 - loss/summary_loss: 0.1673 - val_loss: 2.9493 - val_loss/inference_loss: 2.7742 - val_loss/summary_loss: 0.1752
Epoch 26/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 663ms/step - loss: 2.7069 - loss/inference_loss: 2.5339 - loss/summary_loss: 0.1729 - val_loss: 2.7834 - val_loss/inference_loss: 2.6410 - val_loss/summary_loss: 0.1424
Epoch 27/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 631ms/step - loss: 2.7247 - loss/inference_loss: 2.5504 - loss/summary_loss: 0.1743 - val_loss: 2.7583 - val_loss/inference_loss: 2.6091 - val_loss/summary_loss: 0.1491
Epoch 28/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 655ms/step - loss: 3.0438 - loss/inference_loss: 2.8668 - loss/summary_loss: 0.1770 - val_loss: 2.7321 - val_loss/inference_loss: 2.5490 - val_loss/summary_loss: 0.1831
Epoch 29/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 646ms/step - loss: 3.3625 - loss/inference_loss: 3.1734 - loss/summary_loss: 0.1891 - val_loss: 3.2479 - val_loss/inference_loss: 3.0610 - val_loss/summary_loss: 0.1869
Epoch 30/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 644ms/step - loss: 2.8979 - loss/inference_loss: 2.7022 - loss/summary_loss: 0.1956 - val_loss: 3.0476 - val_loss/inference_loss: 2.8501 - val_loss/summary_loss: 0.1975
Epoch 31/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 642ms/step - loss: 2.8198 - loss/inference_loss: 2.6258 - loss/summary_loss: 0.1939 - val_loss: 2.5580 - val_loss/inference_loss: 2.3881 - val_loss/summary_loss: 0.1699
Epoch 32/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 652ms/step - loss: 2.6621 - loss/inference_loss: 2.4773 - loss/summary_loss: 0.1847 - val_loss: 2.7204 - val_loss/inference_loss: 2.5579 - val_loss/summary_loss: 0.1626
Epoch 33/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 641ms/step - loss: 2.3376 - loss/inference_loss: 2.1612 - loss/summary_loss: 0.1764 - val_loss: 2.3137 - val_loss/inference_loss: 2.1466 - val_loss/summary_loss: 0.1671
Epoch 34/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 640ms/step - loss: 2.3064 - loss/inference_loss: 2.1392 - loss/summary_loss: 0.1671 - val_loss: 2.1724 - val_loss/inference_loss: 2.0077 - val_loss/summary_loss: 0.1647
Epoch 35/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 638ms/step - loss: 2.6591 - loss/inference_loss: 2.4902 - loss/summary_loss: 0.1689 - val_loss: 2.8983 - val_loss/inference_loss: 2.7284 - val_loss/summary_loss: 0.1699
Epoch 36/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 644ms/step - loss: 2.2897 - loss/inference_loss: 2.0997 - loss/summary_loss: 0.1900 - val_loss: 2.1468 - val_loss/inference_loss: 1.9434 - val_loss/summary_loss: 0.2034
Epoch 37/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 672ms/step - loss: 2.3447 - loss/inference_loss: 2.1427 - loss/summary_loss: 0.2020 - val_loss: 2.2293 - val_loss/inference_loss: 2.0126 - val_loss/summary_loss: 0.2167
Epoch 38/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 650ms/step - loss: 2.1639 - loss/inference_loss: 1.9669 - loss/summary_loss: 0.1970 - val_loss: 2.2016 - val_loss/inference_loss: 2.0265 - val_loss/summary_loss: 0.1751
Epoch 39/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 631ms/step - loss: 1.9992 - loss/inference_loss: 1.8057 - loss/summary_loss: 0.1936 - val_loss: 1.9904 - val_loss/inference_loss: 1.8133 - val_loss/summary_loss: 0.1771
Epoch 40/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 637ms/step - loss: 1.8934 - loss/inference_loss: 1.7005 - loss/summary_loss: 0.1929 - val_loss: 2.1794 - val_loss/inference_loss: 2.0040 - val_loss/summary_loss: 0.1755
Epoch 41/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 638ms/step - loss: 1.8562 - loss/inference_loss: 1.6690 - loss/summary_loss: 0.1872 - val_loss: 1.9737 - val_loss/inference_loss: 1.7985 - val_loss/summary_loss: 0.1752
Epoch 42/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 632ms/step - loss: 1.7466 - loss/inference_loss: 1.5612 - loss/summary_loss: 0.1854 - val_loss: 2.0639 - val_loss/inference_loss: 1.8752 - val_loss/summary_loss: 0.1887
Epoch 43/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 648ms/step - loss: 1.7693 - loss/inference_loss: 1.5840 - loss/summary_loss: 0.1853 - val_loss: 2.5070 - val_loss/inference_loss: 2.3081 - val_loss/summary_loss: 0.1989
Epoch 44/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 674ms/step - loss: 1.7723 - loss/inference_loss: 1.5713 - loss/summary_loss: 0.2011 - val_loss: 1.9744 - val_loss/inference_loss: 1.7753 - val_loss/summary_loss: 0.1991
Epoch 45/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 634ms/step - loss: 1.6406 - loss/inference_loss: 1.4498 - loss/summary_loss: 0.1909 - val_loss: 1.7236 - val_loss/inference_loss: 1.5469 - val_loss/summary_loss: 0.1768
Epoch 46/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 638ms/step - loss: 1.6229 - loss/inference_loss: 1.4347 - loss/summary_loss: 0.1882 - val_loss: 1.7695 - val_loss/inference_loss: 1.5858 - val_loss/summary_loss: 0.1838
Epoch 47/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 635ms/step - loss: 1.9651 - loss/inference_loss: 1.7786 - loss/summary_loss: 0.1865 - val_loss: 2.2095 - val_loss/inference_loss: 2.0418 - val_loss/summary_loss: 0.1676
Epoch 48/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 638ms/step - loss: 1.7954 - loss/inference_loss: 1.6079 - loss/summary_loss: 0.1875 - val_loss: 1.7250 - val_loss/inference_loss: 1.5349 - val_loss/summary_loss: 0.1901
Epoch 49/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 640ms/step - loss: 1.4650 - loss/inference_loss: 1.2654 - loss/summary_loss: 0.1996 - val_loss: 1.5375 - val_loss/inference_loss: 1.3673 - val_loss/summary_loss: 0.1703
Epoch 50/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 644ms/step - loss: 1.4561 - loss/inference_loss: 1.2723 - loss/summary_loss: 0.1838 - val_loss: 1.3710 - val_loss/inference_loss: 1.2069 - val_loss/summary_loss: 0.1641
Epoch 51/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 652ms/step - loss: 1.3965 - loss/inference_loss: 1.2208 - loss/summary_loss: 0.1757 - val_loss: 1.6446 - val_loss/inference_loss: 1.4805 - val_loss/summary_loss: 0.1641
Epoch 52/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 647ms/step - loss: 1.5803 - loss/inference_loss: 1.4120 - loss/summary_loss: 0.1683 - val_loss: 2.0261 - val_loss/inference_loss: 1.8793 - val_loss/summary_loss: 0.1468
Epoch 53/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 646ms/step - loss: 1.4886 - loss/inference_loss: 1.3165 - loss/summary_loss: 0.1721 - val_loss: 1.4217 - val_loss/inference_loss: 1.2650 - val_loss/summary_loss: 0.1567
Epoch 54/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 649ms/step - loss: 1.3557 - loss/inference_loss: 1.1763 - loss/summary_loss: 0.1794 - val_loss: 1.1852 - val_loss/inference_loss: 1.0335 - val_loss/summary_loss: 0.1517
Epoch 55/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 647ms/step - loss: 1.3222 - loss/inference_loss: 1.1358 - loss/summary_loss: 0.1863 - val_loss: 1.4570 - val_loss/inference_loss: 1.2818 - val_loss/summary_loss: 0.1752
Epoch 56/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 651ms/step - loss: 1.2665 - loss/inference_loss: 1.0845 - loss/summary_loss: 0.1819 - val_loss: 1.3704 - val_loss/inference_loss: 1.2134 - val_loss/summary_loss: 0.1570
Epoch 57/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 650ms/step - loss: 1.1603 - loss/inference_loss: 0.9831 - loss/summary_loss: 0.1771 - val_loss: 1.2693 - val_loss/inference_loss: 1.1043 - val_loss/summary_loss: 0.1650
Epoch 58/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 667ms/step - loss: 1.1930 - loss/inference_loss: 1.0246 - loss/summary_loss: 0.1684 - val_loss: 1.1312 - val_loss/inference_loss: 0.9607 - val_loss/summary_loss: 0.1705
Epoch 59/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 646ms/step - loss: 1.3012 - loss/inference_loss: 1.1337 - loss/summary_loss: 0.1675 - val_loss: 0.9646 - val_loss/inference_loss: 0.8219 - val_loss/summary_loss: 0.1427
Epoch 60/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 638ms/step - loss: 1.1655 - loss/inference_loss: 0.9962 - loss/summary_loss: 0.1692 - val_loss: 1.2670 - val_loss/inference_loss: 1.1023 - val_loss/summary_loss: 0.1647
Epoch 61/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 642ms/step - loss: 1.0884 - loss/inference_loss: 0.9162 - loss/summary_loss: 0.1722 - val_loss: 1.0760 - val_loss/inference_loss: 0.9093 - val_loss/summary_loss: 0.1666
Epoch 62/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 644ms/step - loss: 1.0608 - loss/inference_loss: 0.8926 - loss/summary_loss: 0.1682 - val_loss: 1.2358 - val_loss/inference_loss: 1.0848 - val_loss/summary_loss: 0.1510
Epoch 63/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 645ms/step - loss: 0.9816 - loss/inference_loss: 0.8134 - loss/summary_loss: 0.1682 - val_loss: 1.1988 - val_loss/inference_loss: 1.0484 - val_loss/summary_loss: 0.1503
Epoch 64/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 652ms/step - loss: 0.9914 - loss/inference_loss: 0.8276 - loss/summary_loss: 0.1638 - val_loss: 1.1647 - val_loss/inference_loss: 1.0099 - val_loss/summary_loss: 0.1549
Epoch 65/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 663ms/step - loss: 0.9703 - loss/inference_loss: 0.8106 - loss/summary_loss: 0.1597 - val_loss: 1.0473 - val_loss/inference_loss: 0.8970 - val_loss/summary_loss: 0.1503
Epoch 66/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 649ms/step - loss: 0.9324 - loss/inference_loss: 0.7728 - loss/summary_loss: 0.1596 - val_loss: 1.1476 - val_loss/inference_loss: 0.9971 - val_loss/summary_loss: 0.1505
Epoch 67/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 646ms/step - loss: 0.9491 - loss/inference_loss: 0.7905 - loss/summary_loss: 0.1585 - val_loss: 0.9906 - val_loss/inference_loss: 0.8568 - val_loss/summary_loss: 0.1338
Epoch 68/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 645ms/step - loss: 0.9203 - loss/inference_loss: 0.7623 - loss/summary_loss: 0.1580 - val_loss: 0.8555 - val_loss/inference_loss: 0.7052 - val_loss/summary_loss: 0.1503
Epoch 69/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 647ms/step - loss: 0.9002 - loss/inference_loss: 0.7434 - loss/summary_loss: 0.1567 - val_loss: 0.9529 - val_loss/inference_loss: 0.7949 - val_loss/summary_loss: 0.1580
Epoch 70/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 642ms/step - loss: 0.8696 - loss/inference_loss: 0.7143 - loss/summary_loss: 0.1553 - val_loss: 1.0420 - val_loss/inference_loss: 0.8976 - val_loss/summary_loss: 0.1444
Epoch 71/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 635ms/step - loss: 0.8365 - loss/inference_loss: 0.6796 - loss/summary_loss: 0.1569 - val_loss: 0.7756 - val_loss/inference_loss: 0.6244 - val_loss/summary_loss: 0.1512
Epoch 72/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 641ms/step - loss: 0.7884 - loss/inference_loss: 0.6326 - loss/summary_loss: 0.1558 - val_loss: 0.9432 - val_loss/inference_loss: 0.8026 - val_loss/summary_loss: 0.1406
Epoch 73/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 654ms/step - loss: 0.8110 - loss/inference_loss: 0.6554 - loss/summary_loss: 0.1556 - val_loss: 1.2307 - val_loss/inference_loss: 1.0674 - val_loss/summary_loss: 0.1633
Epoch 74/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 645ms/step - loss: 0.8237 - loss/inference_loss: 0.6710 - loss/summary_loss: 0.1527 - val_loss: 1.0769 - val_loss/inference_loss: 0.9334 - val_loss/summary_loss: 0.1435
Epoch 75/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 640ms/step - loss: 0.7520 - loss/inference_loss: 0.6001 - loss/summary_loss: 0.1519 - val_loss: 0.5952 - val_loss/inference_loss: 0.4440 - val_loss/summary_loss: 0.1512
Epoch 76/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 638ms/step - loss: 0.7704 - loss/inference_loss: 0.6198 - loss/summary_loss: 0.1506 - val_loss: 1.0217 - val_loss/inference_loss: 0.8737 - val_loss/summary_loss: 0.1481
Epoch 77/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 633ms/step - loss: 0.8600 - loss/inference_loss: 0.7117 - loss/summary_loss: 0.1483 - val_loss: 0.8531 - val_loss/inference_loss: 0.7029 - val_loss/summary_loss: 0.1502
Epoch 78/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 646ms/step - loss: 0.7354 - loss/inference_loss: 0.5853 - loss/summary_loss: 0.1500 - val_loss: 0.8464 - val_loss/inference_loss: 0.6999 - val_loss/summary_loss: 0.1465
Epoch 79/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 656ms/step - loss: 0.6935 - loss/inference_loss: 0.5461 - loss/summary_loss: 0.1474 - val_loss: 1.1493 - val_loss/inference_loss: 1.0016 - val_loss/summary_loss: 0.1477
Epoch 80/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 640ms/step - loss: 0.6776 - loss/inference_loss: 0.5293 - loss/summary_loss: 0.1483 - val_loss: 0.8535 - val_loss/inference_loss: 0.7131 - val_loss/summary_loss: 0.1404
Epoch 81/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 641ms/step - loss: 0.6657 - loss/inference_loss: 0.5185 - loss/summary_loss: 0.1471 - val_loss: 0.6595 - val_loss/inference_loss: 0.5300 - val_loss/summary_loss: 0.1295
Epoch 82/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 655ms/step - loss: 0.6525 - loss/inference_loss: 0.5066 - loss/summary_loss: 0.1459 - val_loss: 0.8119 - val_loss/inference_loss: 0.6816 - val_loss/summary_loss: 0.1303
Epoch 83/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 652ms/step - loss: 0.6473 - loss/inference_loss: 0.5016 - loss/summary_loss: 0.1457 - val_loss: 1.1176 - val_loss/inference_loss: 0.9759 - val_loss/summary_loss: 0.1417
Epoch 84/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 651ms/step - loss: 0.6252 - loss/inference_loss: 0.4787 - loss/summary_loss: 0.1464 - val_loss: 0.9897 - val_loss/inference_loss: 0.8482 - val_loss/summary_loss: 0.1415
Epoch 85/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 648ms/step - loss: 0.6253 - loss/inference_loss: 0.4791 - loss/summary_loss: 0.1461 - val_loss: 0.9309 - val_loss/inference_loss: 0.7894 - val_loss/summary_loss: 0.1415
Epoch 86/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 653ms/step - loss: 0.6400 - loss/inference_loss: 0.4954 - loss/summary_loss: 0.1446 - val_loss: 0.7634 - val_loss/inference_loss: 0.6060 - val_loss/summary_loss: 0.1574
Epoch 87/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 634ms/step - loss: 0.6005 - loss/inference_loss: 0.4567 - loss/summary_loss: 0.1438 - val_loss: 0.9929 - val_loss/inference_loss: 0.8591 - val_loss/summary_loss: 0.1337
Epoch 88/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 644ms/step - loss: 0.5821 - loss/inference_loss: 0.4381 - loss/summary_loss: 0.1440 - val_loss: 0.7725 - val_loss/inference_loss: 0.6350 - val_loss/summary_loss: 0.1375
Epoch 89/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 656ms/step - loss: 0.5848 - loss/inference_loss: 0.4408 - loss/summary_loss: 0.1440 - val_loss: 0.7279 - val_loss/inference_loss: 0.5971 - val_loss/summary_loss: 0.1309
Epoch 90/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 654ms/step - loss: 0.5676 - loss/inference_loss: 0.4243 - loss/summary_loss: 0.1433 - val_loss: 0.4774 - val_loss/inference_loss: 0.3479 - val_loss/summary_loss: 0.1295
Epoch 91/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 650ms/step - loss: 0.5667 - loss/inference_loss: 0.4229 - loss/summary_loss: 0.1438 - val_loss: 0.8972 - val_loss/inference_loss: 0.7705 - val_loss/summary_loss: 0.1267
Epoch 92/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 644ms/step - loss: 0.5633 - loss/inference_loss: 0.4198 - loss/summary_loss: 0.1435 - val_loss: 0.5660 - val_loss/inference_loss: 0.4268 - val_loss/summary_loss: 0.1392
Epoch 93/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 651ms/step - loss: 0.5625 - loss/inference_loss: 0.4200 - loss/summary_loss: 0.1425 - val_loss: 0.8851 - val_loss/inference_loss: 0.7444 - val_loss/summary_loss: 0.1407
Epoch 94/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 653ms/step - loss: 0.5464 - loss/inference_loss: 0.4034 - loss/summary_loss: 0.1429 - val_loss: 0.7550 - val_loss/inference_loss: 0.6080 - val_loss/summary_loss: 0.1470
Epoch 95/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 659ms/step - loss: 0.5542 - loss/inference_loss: 0.4115 - loss/summary_loss: 0.1427 - val_loss: 1.0188 - val_loss/inference_loss: 0.8732 - val_loss/summary_loss: 0.1456
Epoch 96/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 651ms/step - loss: 0.5547 - loss/inference_loss: 0.4116 - loss/summary_loss: 0.1431 - val_loss: 0.7003 - val_loss/inference_loss: 0.5620 - val_loss/summary_loss: 0.1383
Epoch 97/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 650ms/step - loss: 0.5643 - loss/inference_loss: 0.4210 - loss/summary_loss: 0.1432 - val_loss: 0.5191 - val_loss/inference_loss: 0.3852 - val_loss/summary_loss: 0.1339
Epoch 98/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 633ms/step - loss: 0.5524 - loss/inference_loss: 0.4096 - loss/summary_loss: 0.1428 - val_loss: 0.7010 - val_loss/inference_loss: 0.5694 - val_loss/summary_loss: 0.1316
Epoch 99/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 640ms/step - loss: 0.5481 - loss/inference_loss: 0.4046 - loss/summary_loss: 0.1436 - val_loss: 0.8737 - val_loss/inference_loss: 0.7477 - val_loss/summary_loss: 0.1261
Epoch 100/100
20/20 ━━━━━━━━━━━━━━━━━━━━ 13s 637ms/step - loss: 0.5406 - loss/inference_loss: 0.3977 - loss/summary_loss: 0.1429 - val_loss: 0.6171 - val_loss/inference_loss: 0.4793 - val_loss/summary_loss: 0.1378
test_data = simulator.sample(1_000)
plots = workflow.plot_default_diagnostics(test_data=test_data)
Application to real data
Here we will work with data from Forstmann et al. (2008), which is also included in the R package EMC2 (Stevenson, Donzallaz, & Heathcote, 2025). We use a slightly reshaped dataset in which each subject × condition combination is padded to a length of 350 trials (the missing trials are filled with 0s). This allows us to estimate the posterior for each subject in each condition with a single pass through the posterior approximator.
data_inference = pd.read_csv("forstmann.csv")
data_inference['condition'].unique()
array(['accuracy', 'speed', 'neutral'], dtype=object)
data_inference_grouped = data_inference.groupby(["subject", "condition"])
data_inference_dict = {
    key: np.array([group[key].values.reshape(350, 1) for _, group in data_inference_grouped])
    for key in ['rt', 'response', 'observed']
}
data_inference_dict["n"] = np.sum(data_inference_dict["observed"], axis=1)
print({key: value.shape for key, value in data_inference_dict.items()})
{'rt': (57, 350, 1), 'response': (57, 350, 1), 'observed': (57, 350, 1), 'n': (57, 1)}
posterior_samples = workflow.sample(conditions=data_inference_dict, num_samples=1_000)
# pick the first participant, first condition
posterior = {key: value[0] for key, value in posterior_samples.items()}
data = data_inference_grouped.get_group(('as1t', 'accuracy'))
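Before plotting, it can help to glance at the posterior means for this participant and condition (a small optional check; it assumes each entry of posterior holds an array of draws with the sample dimension first):

# posterior means per parameter for the selected participant/condition
print({key: value.mean(axis=0) for key, value in posterior.items()})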
f = bf.diagnostics.pairs_posterior(estimates=posterior)
def ecdf(rt, response, observed, **kwargs):
    observed_mask = (observed == 1)
    response_0_mask = ((response == 0) & observed_mask)
    response_1_mask = ((response == 1) & observed_mask)

    # proportions of each response type among the observed trials
    response_0_prop = np.sum(response_0_mask) / np.sum(observed_mask)
    response_1_prop = np.sum(response_1_mask) / np.sum(observed_mask)

    # "defective" cumulative distributions: each ECDF is scaled by its
    # response proportion, so the two curves jointly approach 1
    response_0_ecdf = stats.ecdf(rt[response_0_mask]).cdf
    response_0_ecdf = response_0_prop * response_0_ecdf.evaluate(np.linspace(0, 1, 101))

    response_1_ecdf = stats.ecdf(rt[response_1_mask]).cdf
    response_1_ecdf = response_1_prop * response_1_ecdf.evaluate(np.linspace(0, 1, 101))

    return response_0_ecdf, response_1_ecdf
plot_data = ecdf(**data)
posterior_predictives = simulator.sample(1000, **posterior)
plot_data_predictive = []
for i in range(1000):
    x = {key: value[i:i+1, ...] for key, value in posterior_predictives.items()}
    plot_data_predictive.append(ecdf(**x))
plot_data_predictive = np.array(plot_data_predictive)
plot_data_quantiles = np.quantile(
    plot_data_predictive,
    q=[0.25, 0.5, 0.75],
    axis=0
)
plot_data_quantiles.shape
(3, 2, 101)
t = np.linspace(0, 1, 101)
cols = ["red", "blue"]
for i, lab in enumerate(["Incorrect", "Correct"]):
    plt.plot(t, plot_data[i], label=lab, color=cols[i])
    plt.plot(t, plot_data_quantiles[1, i, :], color=cols[i], alpha=0.5, label="median predictive")
    plt.fill_between(
        t,
        plot_data_quantiles[0, i, :],
        plot_data_quantiles[-1, i, :],
        label="50% predictive interval",
        color=cols[i],
        alpha=0.3
    )
f = plt.legend()
Further exercises
- The diagnostics show that there are some situations where we learn basically nothing from the data. Find out whether there is some common pattern between these simulations and whether you could fix this.
- Each participant in the experiment was exposed to three different conditions (speed, neutral, accuracy). Here we fitted a simple model to each condition separately. However, it would make much more sense to fit a single model that encompasses all three conditions. Think about which parameters of the decision-making model should theoretically be affected by the experimental condition, and implement a model that quantifies these differences.
- The data we worked with here are recoded such that we only have correct vs. incorrect responses. However, the actual data contain two columns in which the stimulus takes values “left” vs. “right”, and so does the response. Think about how to reparametrize the model such that, instead of one accumulator for correct and one for incorrect responses, you have one accumulator for “left” and one for “right”, and how you would make sure that the model responds to changes in the stimulus. Train a BayesFlow model based on data from such a simulator.
- The function that simulates the evidence accumulation process (evidence_accumulation) generates a standard Wiener diffusion with drift. Look for alternative ways to represent this process and try to implement them in a statistical model (examples: linear accumulation, Ornstein-Uhlenbeck process, Levy flight, …); a minimal sketch of one such variant is given below.
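To get you started on the last exercise, here is a minimal sketch of a leaky (Ornstein-Uhlenbeck) accumulator; the leak parameter is an assumed addition, and this is one illustration rather than the tutorial's implementation:

# hypothetical leaky accumulator: dX = (nu - leak * X) dt + dW
def evidence_accumulation_ou(nu, leak, max_t, dt):
    timesteps = int(max_t / dt)
    t = np.linspace(0, max_t, timesteps)
    evidence = np.zeros(timesteps)
    for i in range(1, timesteps):
        # Euler-Maruyama step: the leak pulls evidence back towards zero
        evidence[i] = (evidence[i - 1]
                       + (nu - leak * evidence[i - 1]) * dt
                       + np.random.normal(0, 1) * np.sqrt(dt))
    return t, evidence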
References
Anders, R., Alario, F., & Van Maanen, L. (2016). The shifted Wald distribution for response time data analysis. Psychological Methods, 21(3), 309.
Forstmann, B. U., Dutilh, G., Brown, S., Neumann, J., Von Cramon, D. Y., Ridderinkhof, K. R., & Wagenmakers, E. J. (2008). Striatum and pre-SMA facilitate decision-making under time pressure. Proceedings of the National Academy of Sciences, 105(45), 17538-17542.
Stevenson N., Donzallaz M., Heathcote A. (2025). EMC2: Bayesian Hierarchical Analysis of Cognitive Models of Choice. R package version 3.1.0, https://github.com/ampl-psych/emc2.
Tillman, G., Van Zandt, T., & Logan, G. D. (2020). Sequential sampling models without random between-trial variability: The racing diffusion model of speeded decision making. Psychonomic Bulletin & Review, 27(5), 911-936.