Supplementary source code for soft histograms

 1"""Demo code to fit a 1d Gaussian model with soft histograms and jax.grad"""
 2
 3from jax import numpy as jnp
 4from jax import jit as jjit
 5from jax import random as jran
 6from jax import value_and_grad
 7from collections import namedtuple
 8from diffsky.signdhist_lomem import nnsig_ndhist
 9
# Parameters of the 1-d Gaussian model: ``mu`` is the mean (location) and
# ``sig`` the standard deviation (scale) used by mc_single_gaussian.
GParams = namedtuple("GParams", "mu sig")

# Reference parameter values: mean -1, standard deviation 1.
DEFAULT_PARAMS = GParams(-1.0, 1.0)

# Number of Monte Carlo draws in each Gaussian realization.
NPTS = 20000
14
15
@jjit
def mc_single_gaussian(params, ran_key):
    """Draw one Monte Carlo realization of a Gaussian.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` (location) and ``sig`` (scale)
    ran_key : JAX PRNG key driving the draws

    Returns
    -------
    xdata : array of shape (NPTS,)
        Unit-normal draws scaled by ``params.sig`` and shifted by ``params.mu``.
    """
    loc, scale = params.mu, params.sig
    unit_draws = jran.normal(ran_key, shape=(NPTS,))
    return loc + scale * unit_draws
21
22
@jjit
def mc_predict_hard_edged_xhist(params, xbins, ran_key):
    """Histogram a Monte Carlo Gaussian realization with hard-edged bins.

    Every sample is assigned to exactly one bin by ``jnp.histogram``. For a
    fixed ``ran_key`` the counts are therefore a step function of ``params``,
    which gives jax.grad no useful signal. This is the contrast case for
    mc_predict_soft_xhist.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` and ``sig``
    xbins : array of bin edges passed to ``jnp.histogram(..., bins=xbins)``
    ran_key : JAX PRNG key

    Returns
    -------
    xhist : array of per-bin counts
    """
    counts, _edges = jnp.histogram(mc_single_gaussian(params, ran_key), bins=xbins)
    return counts
30
31
@jjit
def mc_predict_soft_xhist(params, xbins, ran_key):
    """Histogram a Monte Carlo Gaussian realization with the soft histogram.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` and ``sig``
    xbins : 1-d array of bin edges
    ran_key : JAX PRNG key

    Returns
    -------
    xhist : output of soft_xhist for the sampled data
    """
    samples = mc_single_gaussian(params, ran_key)
    # Lay the 1-d samples out as a column of points, shape (npts, 1).
    samples = samples.reshape((-1, 1))
    return soft_xhist(samples, xbins)
41
42
@jjit
def soft_xhist(xdata, xbins):
    """Soft (smoothed) histogram of 1-d data.

    Thin 1-d wrapper around ``nnsig_ndhist`` (imported from
    ``diffsky.signdhist_lomem``). The wrapper passes the data, the per-bin
    smoothing widths, and the lower and upper bin edges, each as a column
    vector.

    Parameters
    ----------
    xdata : array of samples, reshaped here to (npts, 1)
    xbins : 1-d array of bin edges (length nbins + 1)

    Returns
    -------
    xhist : result of ``nnsig_ndhist``; presumably one smoothed count per
        bin (nbins values). TODO confirm against diffsky.
    """
    points = xdata.reshape((-1, 1))

    n_edges = xbins.shape[0]
    lower_edges = xbins[:-1].reshape((n_edges - 1, 1))
    upper_edges = xbins[1:].reshape((n_edges - 1, 1))

    # Every bin uses the same smoothing width: half of the mean bin width.
    mean_width = jnp.diff(xbins).mean()
    smoothing = jnp.zeros_like(lower_edges) + mean_width / 2

    return nnsig_ndhist(points, smoothing, lower_edges, upper_edges)
55
56
57@jjit
58def _mae_kern(x, y):
59    """Mean absolute error"""
60    abs_diff = jnp.abs(y - x)
61    return jnp.mean(abs_diff)
62
63
@jjit
def hard_edged_xhist_loss(params, loss_data):
    """Mean-absolute-error loss between a target and a hard-edged MC histogram.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` and ``sig``
    loss_data : tuple ``(xhist_target, xbins, ran_key)``

    Returns
    -------
    loss : scalar MAE between predicted and target counts
    """
    target_counts, bin_edges, key = loss_data
    predicted_counts = mc_predict_hard_edged_xhist(params, bin_edges, key)
    return _mae_kern(predicted_counts, target_counts)
71
72
@jjit
def soft_xhist_loss(params, loss_data):
    """Mean-absolute-error loss between a target and a soft MC histogram.

    Parameters
    ----------
    params : namedtuple with fields ``mu`` and ``sig``
    loss_data : tuple ``(xhist_target, xbins, ran_key)``

    Returns
    -------
    loss : scalar MAE between predicted and target histograms
    """
    target_counts, bin_edges, key = loss_data
    predicted_counts = mc_predict_soft_xhist(params, bin_edges, key)
    return _mae_kern(predicted_counts, target_counts)
80
81
@jjit
def param_update(params, grads, learning_rate):
    """Take one gradient-descent step on a namedtuple of parameters.

    Each field is updated as ``p - g * learning_rate``.

    Parameters
    ----------
    params : namedtuple (e.g. GParams)
        Current parameter values. Fields may be scalars or arrays, and
        different fields may have different shapes.
    grads : sequence with the same length and field order as ``params``
        For example, the namedtuple returned by ``jax.value_and_grad``.
    learning_rate : scalar step size

    Returns
    -------
    new_params : namedtuple of the same type as ``params``

    Raises
    ------
    ValueError
        If ``grads`` does not have one entry per field of ``params``.
    """
    # Namedtuple lengths are static, so this check is safe under jit.
    if len(grads) != len(params):
        raise ValueError(
            f"grads has {len(grads)} entries but params has {len(params)} fields"
        )
    # Update field by field rather than stacking everything with jnp.array.
    # Stacking requires every field to share one shape, which fails for
    # array-valued parameters.
    return params._make(p - g * learning_rate for p, g in zip(params, grads))
87
88
# Jit-compiled functions that return ``(loss, grads)``, where ``grads`` is
# the gradient with respect to the first argument, ``params``
# (value_and_grad's default argnums=0).
# The hard-edged loss is built on step-function histogram counts, so its
# gradient is expected to carry no useful signal. The soft-histogram loss is
# the one intended for gradient-based fitting.
hard_edged_xhist_loss_and_grad = jjit(value_and_grad(hard_edged_xhist_loss))
soft_xhist_loss_and_grad = jjit(value_and_grad(soft_xhist_loss))