"""Demo code to fit a 1-d Gaussian model with soft histograms and jax.grad."""
2
from collections import namedtuple

from diffsky.signdhist_lomem import nnsig_ndhist
from jax import jit as jjit
from jax import numpy as jnp
from jax import random as jran
from jax import value_and_grad
9
# Parameters of the 1-d Gaussian model: ``mu`` is the additive shift and
# ``sig`` the multiplicative scale applied to standard-normal draws.
GParams = namedtuple("GParams", ("mu", "sig"))

# Default parameter values.
DEFAULT_PARAMS = GParams(mu=-1.0, sig=1.0)

# Number of points drawn in each Monte Carlo realization.
NPTS = 20_000
14
15
@jjit
def mc_single_gaussian(params, ran_key):
    """Draw a Monte Carlo realization of a Gaussian.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` and ``sig``
        Standard-normal draws are multiplied by ``sig`` and shifted by ``mu``.
    ran_key : jax random key
        Key handed to ``jax.random.normal``.

    Returns
    -------
    xdata : array of shape (NPTS,)
        The scaled and shifted draws.
    """
    unit_draws = jran.normal(ran_key, shape=(NPTS,))
    return unit_draws * params.sig + params.mu
21
22
@jjit
def mc_predict_hard_edged_xhist(params, xbins, ran_key):
    """Predict histogram counts with hard-edged bins.

    A Monte Carlo realization of the Gaussian is drawn with
    ``mc_single_gaussian`` and binned with ``jnp.histogram``.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` and ``sig``
    xbins : array
        Passed to ``jnp.histogram`` as ``bins``.
    ran_key : jax random key

    Returns
    -------
    xhist : array
        First output of ``jnp.histogram``; the returned bin edges are dropped.
    """
    # NOTE(review): hard-edged counts are presumably piecewise-constant in
    # params, which looks like the motivation for the soft variant - confirm.
    xdata = mc_single_gaussian(params, ran_key)
    xhist, _edges = jnp.histogram(xdata, bins=xbins)
    return xhist
30
31
@jjit
def mc_predict_soft_xhist(params, xbins, ran_key):
    """Predict histogram counts with a soft histogram.

    A Monte Carlo realization of the Gaussian is drawn with
    ``mc_single_gaussian`` and passed to ``soft_xhist``.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` and ``sig``
    xbins : array
        Bin edges forwarded to ``soft_xhist``.
    ran_key : jax random key

    Returns
    -------
    xhist : array
        Whatever ``soft_xhist`` returns for the realization.
    """
    xdata = mc_single_gaussian(params, ran_key)
    # Present the 1-d sample as a single-column 2-d array
    xdata_col = xdata.reshape((-1, 1))
    return soft_xhist(xdata_col, xbins)
41
42
@jjit
def soft_xhist(xdata, xbins):
    """Soft histogram function for 1-d data.

    This is a wrapper around ``nnsig_ndhist`` (imported from
    ``diffsky.signdhist_lomem``).

    Parameters
    ----------
    xdata : array
        Data to histogram; reshaped to a single column before the call.
    xbins : array
        Bin edges. ``xbins[:-1]`` and ``xbins[1:]`` are used as the lower and
        upper edges of each bin.

    Returns
    -------
    xhist : array
        Output of ``nnsig_ndhist``.
    """
    n_cells = xbins.shape[0] - 1
    xbins_lo = xbins[:-1].reshape((n_cells, 1))
    xbins_hi = xbins[1:].reshape((n_cells, 1))

    # One entry per bin, each equal to half of the mean bin width.
    # NOTE(review): presumably the smoothing scale nnsig_ndhist expects -
    # confirm the required shape and meaning against diffsky.
    dx = jnp.diff(xbins).mean()
    ndsig = jnp.zeros_like(xbins_lo) + dx / 2

    xdata = xdata.reshape((-1, 1))
    return nnsig_ndhist(xdata, ndsig, xbins_lo, xbins_hi)
55
56
@jjit
def _mae_kern(x, y):
    """Mean absolute error: the mean of ``|y - x|`` over all elements."""
    return jnp.mean(jnp.abs(y - x))
62
63
@jjit
def hard_edged_xhist_loss(params, loss_data):
    """Loss function based on a histogram with hard-edged bins.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` and ``sig``
    loss_data : sequence of three items
        ``(xhist_target, xbins, ran_key)``: the target histogram, the bin
        edges, and the random key used for the Monte Carlo prediction.

    Returns
    -------
    loss : scalar
        Mean absolute error between the predicted and target histograms.
    """
    xhist_target, xbins, ran_key = loss_data
    xhist_pred = mc_predict_hard_edged_xhist(params, xbins, ran_key)
    return _mae_kern(xhist_pred, xhist_target)
71
72
@jjit
def soft_xhist_loss(params, loss_data):
    """Loss function based on a soft histogram.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` and ``sig``
    loss_data : sequence of three items
        ``(xhist_target, xbins, ran_key)``: the target histogram, the bin
        edges, and the random key used for the Monte Carlo prediction.

    Returns
    -------
    loss : scalar
        Mean absolute error between the predicted and target histograms.
    """
    xhist_target, xbins, ran_key = loss_data
    xhist_pred = mc_predict_soft_xhist(params, xbins, ran_key)
    return _mae_kern(xhist_pred, xhist_target)
80
81
@jjit
def param_update(params, grads, learning_rate):
    """Update namedtuple params by taking a small step down the gradient.

    Parameters
    ----------
    params : namedtuple
        Current parameter values.
    grads : sequence matching ``params``
        Gradient of the loss with respect to each parameter.
    learning_rate : float
        Step size multiplying the gradient.

    Returns
    -------
    new_params : same namedtuple type as ``params``
        ``params - grads * learning_rate``, field by field.
    """
    # Stack both into flat arrays so the step is one vectorized operation
    current = jnp.array(params)
    step = jnp.array(grads) * learning_rate
    return params._make(current - step)
87
88
# Jitted functions returning the pair (loss, gradient of the loss), with the
# gradient taken with respect to the first argument, ``params``.
hard_edged_xhist_loss_and_grad = jjit(value_and_grad(hard_edged_xhist_loss, argnums=0))
soft_xhist_loss_and_grad = jjit(value_and_grad(soft_xhist_loss, argnums=0))