"""Demo code to fit a double-Gaussian (two-component, 1d) model with soft histograms and jax.grad"""
2
from collections import namedtuple

from jax import jit as jjit
from jax import numpy as jnp
from jax import random as jran
from jax import tree_util
from jax import value_and_grad

from diffsky.signdhist_lomem import nnsig_ndhist
9
# Double-Gaussian model parameters: means (mu*) and widths (sig*) of the two
# components, plus frac0, the probability that a point comes from component 0.
DGParams = namedtuple("DGParams", "mu0 sig0 mu1 sig1 frac0")
DEFAULT_PARAMS = DGParams(
    mu0=-1.0,
    sig0=0.5,
    mu1=1.0,
    sig1=1.0,
    frac0=0.75,
)

# Number of Monte Carlo points drawn per realization
NPTS = 20000
14
15
@jjit
def mc_double_gaussian(params, ran_key):
    """Draw one stochastic Monte Carlo sample of NPTS points from a double Gaussian.

    Each point comes from component 0 (``mu0``, ``sig0``) when a uniform
    draw falls below ``params.frac0``, and from component 1 (``mu1``,
    ``sig1``) otherwise.

    Parameters
    ----------
    params : DGParams
        Model parameters.
    ran_key : jax PRNG key
        Split into three independent keys (uniform selector, component 0,
        component 1).

    Returns
    -------
    xdata : array of shape (NPTS,)
    """
    key_u, key_0, key_1 = jran.split(ran_key, 3)
    shape = (NPTS,)

    # Draw both components for every point, then select one per point, so
    # every array has a fixed shape regardless of how the points split.
    draws_0 = params.mu0 + params.sig0 * jran.normal(key_0, shape=shape)
    draws_1 = params.mu1 + params.sig1 * jran.normal(key_1, shape=shape)

    uniforms = jran.uniform(key_u, shape=shape, minval=0, maxval=1)
    in_component_0 = uniforms < params.frac0
    return jnp.where(in_component_0, draws_0, draws_1)
26
27
@jjit
def predict_soft_xhist_mc(params, xbins, ran_key):
    """Soft-histogram a single stochastic Monte Carlo draw of the double Gaussian.

    Parameters
    ----------
    params : DGParams
        Model parameters passed to ``mc_double_gaussian``.
    xbins : array
        Bin edges passed to ``soft_xhist``.
    ran_key : jax PRNG key
        Key for the Monte Carlo draw.

    Returns
    -------
    xhist : the soft histogram returned by ``soft_xhist``
    """
    return soft_xhist(mc_double_gaussian(params, ran_key), xbins)
35
36
@jjit
def predict_soft_xhist_weighted(params, xbins, ran_key):
    """Soft-histogram a PDF-weighted Monte Carlo realization of the double Gaussian.

    NPTS points are drawn from each component separately. The two soft
    histograms are combined with weights ``frac0`` and ``1 - frac0``.

    Parameters
    ----------
    params : DGParams
        Model parameters.
    xbins : array
        Bin edges passed to ``soft_xhist``.
    ran_key : jax PRNG key
        Split into two independent keys, one per component.

    Returns
    -------
    xhist : mixture-weighted combination of the two component histograms
    """
    key_0, key_1 = jran.split(ran_key, 2)
    shape = (NPTS,)
    draws_0 = params.mu0 + params.sig0 * jran.normal(key_0, shape=shape)
    draws_1 = params.mu1 + params.sig1 * jran.normal(key_1, shape=shape)

    # frac0 enters as a smooth multiplicative weight here. In the stochastic
    # version it appears only inside a hard `<` comparison, which carries no
    # gradient with respect to frac0.
    weight_0 = params.frac0
    weight_1 = 1.0 - weight_0
    return weight_0 * soft_xhist(draws_0, xbins) + weight_1 * soft_xhist(draws_1, xbins)
48
49
@jjit
def soft_xhist(xdata, xbins):
    """Soft histogram of 1d data.

    Thin wrapper that reshapes 1d inputs into the single-column layout
    passed to diffsky's ``nnsig_ndhist``.

    Parameters
    ----------
    xdata : array
        Data values; reshaped to a single column of shape (-1, 1).
    xbins : array of shape (nbins,)
        Bin edges; consecutive edges define nbins - 1 bins.

    Returns
    -------
    xhist : whatever ``nnsig_ndhist`` returns for these data and bins
    """
    n_edges = xbins.shape[0]
    lo_edges = xbins[:-1].reshape((n_edges - 1, 1))
    hi_edges = xbins[1:].reshape((n_edges - 1, 1))

    # The smoothing width is half of the *mean* bin width and is shared by
    # every bin, so non-uniform bins all get the same width.
    half_width = jnp.diff(xbins).mean() / 2
    # NOTE(review): sig has one row per bin, shape (nbins - 1, 1); confirm
    # this is the layout nnsig_ndhist expects for its sigma argument.
    sig = jnp.zeros_like(lo_edges) + half_width

    column_data = xdata.reshape((-1, 1))
    return nnsig_ndhist(column_data, sig, lo_edges, hi_edges)
62
63
64@jjit
65def _mae_kern(x, y):
66 """Mean absolute error"""
67 abs_diff = jnp.abs(y - x)
68 return jnp.mean(abs_diff)
69
70
@jjit
def weighted_mae_loss(params, loss_data):
    """MAE loss between the PDF-weighted soft-histogram prediction and a target.

    Parameters
    ----------
    params : DGParams
        Model parameters being fit.
    loss_data : tuple
        ``(xhist_target, xbins, ran_key)``.

    Returns
    -------
    loss : scalar mean absolute error
    """
    xhist_target, xbins, ran_key = loss_data
    prediction = predict_soft_xhist_weighted(params, xbins, ran_key)
    return _mae_kern(prediction, xhist_target)
78
79
@jjit
def mc_mae_loss(params, loss_data):
    """MAE loss between the stochastic Monte Carlo soft-histogram prediction and a target.

    Parameters
    ----------
    params : DGParams
        Model parameters being fit.
    loss_data : tuple
        ``(xhist_target, xbins, ran_key)``.

    Returns
    -------
    loss : scalar mean absolute error
    """
    xhist_target, xbins, ran_key = loss_data
    prediction = predict_soft_xhist_mc(params, xbins, ran_key)
    return _mae_kern(prediction, xhist_target)
87
88
@jjit
def param_update(params, grads, learning_rate):
    """Take one gradient-descent step on namedtuple ``params``.

    Each field is updated independently as ``p - g * learning_rate``. Fields
    are never stacked into one array, so they may be scalars or arrays of
    differing shapes and dtypes.

    Parameters
    ----------
    params : namedtuple
        Current parameter values.
    grads : namedtuple or sequence
        One gradient entry per field of ``params``, for example the gradient
        from ``jax.value_and_grad``. A flat array with one element per field
        is also accepted.
    learning_rate : float
        Step size.

    Returns
    -------
    new_params : namedtuple of the same type as ``params``
    """
    # Coerce grads to params' namedtuple type so a flat sequence/array of
    # per-field gradients works the same as namedtuple grads.
    grads = params._make(grads)
    return tree_util.tree_map(lambda p, g: p - g * learning_rate, params, grads)
94
95
# Jitted (loss, grads) evaluators. Gradients are taken with respect to the
# first argument (params) and share its namedtuple structure.
weighted_mae_loss_and_grad = jjit(value_and_grad(weighted_mae_loss))
mc_mae_loss_and_grad = jjit(value_and_grad(mc_mae_loss))