Supplementary source code for a double Gaussian

 1"""Demo code to fit a double Gaussian model with soft histograms and jax.grad"""
 2
 3from jax import numpy as jnp
 4from jax import random as jran
 5from jax import jit as jjit
 6from jax import value_and_grad
 7from collections import namedtuple
 8from diffsky.signdhist_lomem import nnsig_ndhist
 9
# Parameters of the double Gaussian model. Component k is sampled as
# ``normal * sigk + muk``; a point is taken from component 0 when a uniform
# draw falls below ``frac0``.
DGParams = namedtuple(
    "DGParams",
    [
        "mu0",  # location of component 0
        "sig0",  # scale of component 0
        "mu1",  # location of component 1
        "sig1",  # scale of component 1
        "frac0",  # weight of component 0; component 1 gets 1 - frac0
    ],
)

# Default parameter values for the model
DEFAULT_PARAMS = DGParams(
    mu0=-1.0,
    sig0=0.5,
    mu1=1.0,
    sig1=1.0,
    frac0=0.75,
)

# Number of random draws generated per Monte Carlo realization
NPTS = 20_000
14
15
@jjit
def mc_double_gaussian(params, ran_key):
    """Generate one stochastic Monte Carlo sample from a double Gaussian.

    Parameters
    ----------
    params : namedtuple
        Must provide the attributes mu0, sig0, mu1, sig1 and frac0
        (e.g. a DGParams instance).
    ran_key : jax.random key
        Split internally into three keys: one for the component selection
        and one for each Gaussian component.

    Returns
    -------
    xdata : array of shape (NPTS,)
        Each entry is taken from component 0 where the uniform draw is
        below params.frac0, and from component 1 otherwise.
    """
    key_select, key_comp0, key_comp1 = jran.split(ran_key, 3)

    # Candidate draws from both components are generated for every point
    draws_comp0 = params.mu0 + params.sig0 * jran.normal(key_comp0, shape=(NPTS,))
    draws_comp1 = params.mu1 + params.sig1 * jran.normal(key_comp1, shape=(NPTS,))

    # frac0 enters only through this boolean mask, i.e. a hard selection
    selector = jran.uniform(key_select, shape=(NPTS,), minval=0, maxval=1)
    from_comp0 = selector < params.frac0

    return jnp.where(from_comp0, draws_comp0, draws_comp1)
26
27
@jjit
def predict_soft_xhist_mc(params, xbins, ran_key):
    """Predict soft-histogram counts from a stochastic Monte Carlo sample.

    A realization is drawn with mc_double_gaussian(params, ran_key) and then
    binned with soft_xhist using the bin edges in xbins.

    Parameters
    ----------
    params : namedtuple
        Double Gaussian parameters, forwarded to mc_double_gaussian.
    xbins : array
        Bin edges, forwarded to soft_xhist.
    ran_key : jax.random key
        Forwarded to mc_double_gaussian.

    Returns
    -------
    xhist : result of soft_xhist applied to the Monte Carlo sample
    """
    return soft_xhist(mc_double_gaussian(params, ran_key), xbins)
35
36
@jjit
def predict_soft_xhist_weighted(params, xbins, ran_key):
    """Predict soft-histogram counts from a PDF-weighted Monte Carlo sample.

    Both Gaussian components are sampled with NPTS points each and binned
    separately; the two histograms are then combined with weights frac0 and
    1 - frac0. Unlike a hard per-point selection, frac0 enters the result
    as a multiplicative weight.

    Parameters
    ----------
    params : namedtuple
        Must provide the attributes mu0, sig0, mu1, sig1 and frac0.
    xbins : array
        Bin edges, forwarded to soft_xhist.
    ran_key : jax.random key
        Split internally into one key per component.

    Returns
    -------
    xhist : weighted sum of the two per-component soft histograms
    """
    key_comp0, key_comp1 = jran.split(ran_key, 2)

    draws_comp0 = params.mu0 + params.sig0 * jran.normal(key_comp0, shape=(NPTS,))
    draws_comp1 = params.mu1 + params.sig1 * jran.normal(key_comp1, shape=(NPTS,))

    weight0 = params.frac0
    weight1 = 1.0 - params.frac0

    hist_comp0 = soft_xhist(draws_comp0, xbins)
    hist_comp1 = soft_xhist(draws_comp1, xbins)
    return weight0 * hist_comp0 + weight1 * hist_comp1
48
49
@jjit
def soft_xhist(xdata, xbins):
    """Soft histogram of 1d data.

    Thin wrapper around diffsky.signdhist_lomem.nnsig_ndhist that reshapes
    1d inputs into the column layout passed to that function.

    Parameters
    ----------
    xdata : array
        Data to histogram; reshaped to (-1, 1) before the call.
    xbins : array of shape (n_edges,)
        Bin edges. Consecutive pairs of edges define the n_edges - 1 bins.

    Returns
    -------
    xhist : output of nnsig_ndhist
        NOTE(review): presumably one soft count per bin -- confirm against
        the diffsky documentation.
    """
    n_edges = xbins.shape[0]
    n_bins = n_edges - 1

    # Lower and upper edge of each bin as column vectors
    lo_edges = xbins[:-1].reshape((n_bins, 1))
    hi_edges = xbins[1:].reshape((n_bins, 1))

    # Second argument of nnsig_ndhist: half the mean bin width, repeated for
    # every bin (presumably the smoothing scale -- confirm against diffsky)
    mean_width = jnp.mean(jnp.diff(xbins))
    half_width = jnp.zeros_like(lo_edges) + mean_width / 2

    column_data = xdata.reshape((-1, 1))
    return nnsig_ndhist(column_data, half_width, lo_edges, hi_edges)
62
63
@jjit
def _mae_kern(x, y):
    """Return the mean absolute error between x and y.

    Computes mean(|y - x|) over all elements of the difference.
    """
    return jnp.mean(jnp.abs(y - x))
69
70
@jjit
def weighted_mae_loss(params, loss_data):
    """Mean-absolute-error loss using the PDF-weighted soft histogram.

    Parameters
    ----------
    params : namedtuple
        Double Gaussian parameters, forwarded to predict_soft_xhist_weighted.
    loss_data : tuple
        Three-element sequence (xhist_target, xbins, ran_key).

    Returns
    -------
    loss : mean absolute difference between predicted and target histograms
    """
    target_hist, bin_edges, key = loss_data
    predicted_hist = predict_soft_xhist_weighted(params, bin_edges, key)
    return _mae_kern(predicted_hist, target_hist)
78
79
@jjit
def mc_mae_loss(params, loss_data):
    """Mean-absolute-error loss using the stochastic Monte Carlo soft histogram.

    Parameters
    ----------
    params : namedtuple
        Double Gaussian parameters, forwarded to predict_soft_xhist_mc.
    loss_data : tuple
        Three-element sequence (xhist_target, xbins, ran_key).

    Returns
    -------
    loss : mean absolute difference between predicted and target histograms
    """
    target_hist, bin_edges, key = loss_data
    predicted_hist = predict_soft_xhist_mc(params, bin_edges, key)
    return _mae_kern(predicted_hist, target_hist)
87
88
@jjit
def param_update(params, grads, learning_rate):
    """Take one gradient-descent step on a namedtuple of parameters.

    Parameters
    ----------
    params : namedtuple
        Current parameter values.
    grads : namedtuple
        Gradient of the loss with respect to each entry of params, in the
        same field order.
    learning_rate : float
        Step size multiplying the gradient.

    Returns
    -------
    new_params : namedtuple of the same type as params
        Holds params - learning_rate * grads, field by field.
    """
    # Stack the fields into flat arrays so the update is one vector operation
    current = jnp.array(params)
    step = jnp.array(grads) * learning_rate
    updated = current - step

    # Rebuild the same namedtuple type from the updated values
    return params._make(updated)
94
95
# Jitted callables that return the pair (loss, gradient). value_and_grad
# differentiates with respect to the first positional argument by default
# (argnums=0), i.e. with respect to params; loss_data is held fixed.
weighted_mae_loss_and_grad = jjit(value_and_grad(weighted_mae_loss))
mc_mae_loss_and_grad = jjit(value_and_grad(mc_mae_loss))