{ "cells": [ { "cell_type": "markdown", "id": "02646743-5e32-47c9-b942-99d8e303185a", "metadata": {}, "source": [ "# Introduction to soft histograms\n", "\n", "This notebook demonstrates the basic principles of using a soft histogram when fitting a model for a probability distribution. We'll use a one-dimensional Gaussian distribution as our model, and the first thing we'll do is try to fit the model parameters using a loss function based on a standard histogram. As discussed below, computing standard histograms with hard-edged bins is a non-differentiable calculation, and so our first attempt to fit the model will fail. We'll then see how using a soft histogram solves the problem.\n", "\n", "This demo is based on the `single_gaussian.py` module, which implements a Gaussian model. Let's start out by using the module to generate some target data with the `mc_single_gaussian` function, and visually inspect the distributions." ] }, { "cell_type": "code", "execution_count": null, "id": "db9b4c72-fd3b-4e54-bbab-8b3e7c925fac", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from jax import random as jran\n", "from jax import numpy as jnp\n", "\n", "ran_key = jran.key(0)" ] }, { "cell_type": "code", "execution_count": null, "id": "8b84cf21-5d5a-425c-bbe1-83938f79f5d3", "metadata": {}, "outputs": [], "source": [ "import single_gaussian as sg\n", "\n", "NBINS = 50\n", "XBOUNDS = (-10.0, 10.0)\n", "XBINS = np.linspace(*XBOUNDS, NBINS)[:-1]\n", "\n", "\n", "PARAMS_INIT = sg.DEFAULT_PARAMS._replace()\n", "ran_key, init_key = jran.split(ran_key, 2)\n", "XDATA_INIT = sg.mc_single_gaussian(PARAMS_INIT, init_key)\n", "\n", "PARAMS_TARGET = sg.DEFAULT_PARAMS._replace(mu=-2.0, sig=2.0)\n", "ran_key, target_key = jran.split(ran_key, 2)\n", "XDATA_TARGET = sg.mc_single_gaussian(PARAMS_TARGET, target_key)\n", "\n", "fig, ax = plt.subplots(1, 1)\n", "__=ax.hist(XDATA_TARGET, bins=XBINS, \n", " alpha=0.7, label=r'target 
population')\n", "__=ax.hist(XDATA_INIT, bins=XBINS, \n", " alpha=0.7, label=r'initial population')\n", "leg = ax.legend()" ] }, { "cell_type": "markdown", "id": "882f0f64-fd68-4dcb-96b8-dd18c1e4301b", "metadata": {}, "source": [ "### Predicting a histogram from a population\n", "\n", "The `mc_predict_hard_edged_xhist` function is just a wrapper around `mc_single_gaussian` that first generates a Monte Carlo sample from the model, and then uses `jnp.histogram` to bin the sample into a predicted histogram." ] }, { "cell_type": "code", "execution_count": null, "id": "7428dde5-e537-4e31-a73b-6a58fc54c0a9", "metadata": {}, "outputs": [], "source": [ "ran_key, init_key = jran.split(ran_key, 2)\n", "XHIST_INIT = sg.mc_predict_hard_edged_xhist(\n", " PARAMS_INIT, XBINS, init_key)\n", "\n", "ran_key, target_key = jran.split(ran_key, 2)\n", "XHIST_TARGET = sg.mc_predict_hard_edged_xhist(\n", " PARAMS_TARGET, XBINS, target_key)" ] }, { "cell_type": "markdown", "id": "181e6fde-bb21-4b9d-955c-820d8dc7c533", "metadata": {}, "source": [ "### Running gradient descent\n", "\n", "The `hard_edged_xhist_loss_and_grad` function is the loss function we will try to minimize with gradient descent. This loss function uses `mc_predict_hard_edged_xhist` to predict a histogram, and then computes the mean absolute error between the predicted and target histograms. We use `jax.value_and_grad` to return the gradient of the loss along with its value.\n", "\n", "The next cell takes 100 steps of gradient descent; at each step, we update the parameters by moving a small distance down the gradient."
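, "\n", "\n", "Schematically, and assuming that `param_update` implements plain gradient descent (this is an assumption about the module, and the symbols $\\theta$, $\\eta$, and $\\mathcal{L}$ are introduced here only for illustration), each step applies\n", "\n", "$$\\theta_{k+1} = \\theta_{k} - \\eta\\,\\nabla_{\\theta}\\mathcal{L}(\\theta_{k}),$$\n", "\n", "where $\\theta$ denotes the model parameters (here `mu` and `sig`), $\\eta$ is the learning rate `learn_rate`, and $\\mathcal{L}$ is the loss. If the gradient vanishes, the update leaves $\\theta$ unchanged."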
] }, { "cell_type": "code", "execution_count": null, "id": "95275efa-dd29-4f05-bff5-33a16374ba63", "metadata": {}, "outputs": [], "source": [ "ran_key, loss_key = jran.split(ran_key, 2)\n", "loss_data = XHIST_TARGET, XBINS, loss_key\n", "\n", "learn_rate = 0.0005\n", "\n", "nsteps = 100\n", "loss_collector = []\n", "p_best = PARAMS_INIT._replace()\n", "for istep in range(nsteps):\n", " loss, grads = sg.hard_edged_xhist_loss_and_grad(\n", " p_best, loss_data)\n", " p_best = sg.param_update(\n", " p_best, grads, learn_rate)\n", " loss_collector.append(loss)" ] }, { "cell_type": "markdown", "id": "952233c7-fd78-4d42-a6e1-fa6eb65818ef", "metadata": {}, "source": [ "### Inspect the results\n", "\n", "We'll now inspect the results by plotting the loss curve, and comparing the best-fit histogram to the target." ] }, { "cell_type": "code", "execution_count": null, "id": "8ce1adc0-c970-4b89-9d40-e81290138011", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(1, 1)\n", "xlabel = ax.set_xlabel('step')\n", "ylabel = ax.set_ylabel('log10 loss')\n", "__=ax.plot(np.log10(loss_collector))\n", "\n", "ran_key, pred_key = jran.split(ran_key, 2)\n", "xhist_best = sg.mc_predict_hard_edged_xhist(\n", " p_best, XBINS, pred_key)\n", "\n", "fig, ax = plt.subplots(1, 1)\n", "__=ax.plot(XBINS[1:], XHIST_TARGET, \n", " label='target')\n", "__=ax.plot(XBINS[1:], XHIST_INIT, '--', \n", " label='initial guess')\n", "__=ax.plot(XBINS[1:], xhist_best, ':', \n", " label='best fit MC method')\n", "leg = ax.legend()" ] }, { "cell_type": "markdown", "id": "416d9d9b-9d41-4547-8539-979609f03548", "metadata": {}, "source": [ "### That didn't work at all - what happened?\n", "\n", "It looks like our best-fit histogram didn't move from the initial histogram. 
Let's see how the best-fit parameters compare to their initial and target values:" ] }, { "cell_type": "code", "execution_count": null, "id": "077abebd-b557-4215-b722-8895795afbf9", "metadata": {}, "outputs": [], "source": [ "gen = zip(p_best._fields, p_best, PARAMS_INIT, PARAMS_TARGET)\n", "for key, val_best, val_init, val_target in gen:\n", " print(f\"Init {key} = {val_init:.2f}\")\n", " print(f\"Best {key} = {val_best:.2f}\")\n", " print(f\"True {key} = {val_target:.2f}\\n\")" ] }, { "cell_type": "markdown", "id": "1e47b5bc-9285-4c8f-a791-97c050ab8f79", "metadata": {}, "source": [ "### Hmmm, the parameters didn't move at all...\n", "\n", "Let's inspect the gradient." ] }, { "cell_type": "code", "execution_count": null, "id": "5ec7fa72-6280-4de6-ae60-967cef8617c4", "metadata": {}, "outputs": [], "source": [ "loss_best_mc, grads = sg.hard_edged_xhist_loss_and_grad(\n", " p_best, loss_data)\n", "print(grads)" ] }, { "cell_type": "markdown", "id": "d6fca23e-1d12-4af0-aeb0-a6698a7c4f44", "metadata": {}, "source": [ "### All the parameters have zero gradients!\n", "\n", "That's why the parameters did not move from their initial values during our gradient descent. The problem comes from trying to differentiate through a histogram with hard-edged bins.\n", "\n", "#### Why don't hard-edged histograms work with autodiff?\n", "\n", "Let's think about how autodiff works with hard-edged histograms to understand what happened. The way a standard histogram calculation works is that for each bin $i$, we loop over each point $x_{\\rm j}$ in our dataset. If the point falls within the boundaries of bin $i$, we increment the count in that bin by 1; otherwise we increment by 0. Now consider how the histogram changes with an _infinitesimal_ change to the position of each point, ${\\rm d}x$. 
If $x_{\\rm j}$ is within the bin boundaries, then the perturbed position $x_{\\rm j}+{\\rm d}x$ is also within the bin boundaries, because the point and the boundary are some _finite_ distance apart, and we have only perturbed our point by an infinitesimal amount. The same goes for points outside the bin. The only points that could contribute a non-zero gradient are those that happen to fall exactly on the boundary of some bin, which is a set of measure zero. Thus it makes sense that we get zero-valued gradients for predictions made with hard-edged histograms." ] }, { "cell_type": "markdown", "id": "6cb14d0c-3862-4175-aa06-58abad394cc9", "metadata": {}, "source": [ "## Introducing soft histograms\n", "\n", "The solution to this problem is to use soft histograms. In standard histograms, each point in the dataset $x_{\\rm j}$ contributes either 0 or 1 to the result for each bin. In soft histograms, each point contributes a continuously-valued _weight_, $w_{\\rm j},$ to the result for each bin. Each $w_{\\rm j}$ is computed by integrating a Gaussian kernel centered on the point between the two edges of the bin. For histogram bins of width $\\Delta x$ (a finite width, not the infinitesimal ${\\rm d}x$ used above), we typically choose a kernel width $\\sigma\\lesssim\\Delta x.$\n", "\n", "There are several soft histogram calculators in diffsky. All of the calculators are written to support N-dimensional data, and so the `soft_xhist` function in `single_gaussian.py` is just a thin wrapper that reshapes the data and the bins to shape ${\\rm (n, 1)}$."
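, "\n", "\n", "As a rough sketch of that integral, consider a single bin with lower and upper edges $x^{\\rm lo}$ and $x^{\\rm hi}$ (symbols introduced here only for illustration; the kernel used internally by the diffsky calculators may differ in detail). A Gaussian kernel of width $\\sigma$ centered on $x_{\\rm j}$ would then give the weight\n", "\n", "$$w_{\\rm j} = \\frac{1}{2}\\left[{\\rm erf}\\left(\\frac{x^{\\rm hi}-x_{\\rm j}}{\\sqrt{2}\\,\\sigma}\\right) - {\\rm erf}\\left(\\frac{x^{\\rm lo}-x_{\\rm j}}{\\sqrt{2}\\,\\sigma}\\right)\\right].$$\n", "\n", "Under the same assumptions, differentiating with respect to the position of the point gives\n", "\n", "$$\\frac{\\partial w_{\\rm j}}{\\partial x_{\\rm j}} = \\frac{1}{\\sqrt{2\\pi}\\,\\sigma}\\left[\\exp\\left(-\\frac{(x^{\\rm lo}-x_{\\rm j})^{2}}{2\\sigma^{2}}\\right) - \\exp\\left(-\\frac{(x^{\\rm hi}-x_{\\rm j})^{2}}{2\\sigma^{2}}\\right)\\right],$$\n", "\n", "which vanishes only when the point sits exactly at the center of the bin, so a generic point contributes a non-zero gradient."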
] }, { "cell_type": "code", "execution_count": null, "id": "d6ba0bf0-cc8a-4a75-a0e3-07f4ec4e6905", "metadata": {}, "outputs": [], "source": [ "XHIST_TARGET, __ = jnp.histogram(XDATA_TARGET, bins=XBINS)\n", "XHIST_TARGET_SOFT = sg.soft_xhist(XDATA_TARGET, XBINS)\n", "\n", "fig, ax = plt.subplots(1, 1)\n", "__=ax.plot(XBINS[1:], XHIST_TARGET, \n", " label='standard histogram')\n", "__=ax.plot(XBINS[1:], XHIST_TARGET_SOFT,'--', \n", " label='soft histogram')\n", "leg = ax.legend()" ] }, { "cell_type": "markdown", "id": "976682e5-b257-47df-8370-608e26ba0588", "metadata": {}, "source": [ "### Running gradient descent with a soft histogram\n", "\n", "The next cell takes 100 steps of gradient descent, this time with a loss function based on `soft_xhist_loss_and_grad`." ] }, { "cell_type": "code", "execution_count": null, "id": "2b79da38-ec57-4c40-aea9-2e765cc6804c", "metadata": {}, "outputs": [], "source": [ "ran_key, loss_key = jran.split(ran_key, 2)\n", "loss_data = XHIST_TARGET, XBINS, loss_key\n", "\n", "learn_rate = 0.0005\n", "\n", "nsteps = 100\n", "soft_loss_collector = []\n", "p_best_soft = PARAMS_INIT._replace()\n", "for istep in range(nsteps):\n", " loss, grads = sg.soft_xhist_loss_and_grad(\n", " p_best_soft, loss_data)\n", " p_best_soft = sg.param_update(\n", " p_best_soft, grads, learn_rate)\n", " soft_loss_collector.append(loss)" ] }, { "cell_type": "markdown", "id": "e8369b95-3a4b-40b9-b2d4-36c0225c09d3", "metadata": {}, "source": [ "### Inspect the results" ] }, { "cell_type": "code", "execution_count": null, "id": "20fc856d-31d8-4d73-98f9-abedd5ecbb4d", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(1, 1)\n", "xlabel = ax.set_xlabel('step')\n", "ylabel = ax.set_ylabel('log10 loss')\n", "__=ax.plot(np.log10(soft_loss_collector))\n", "\n", "ran_key, pred_key = jran.split(ran_key, 2)\n", "xhist_best_soft = sg.mc_predict_soft_xhist(\n", " p_best_soft, XBINS, pred_key)\n", "\n", "fig, ax = plt.subplots(1, 1)\n", "__=ax.plot(XBINS[1:], 
XHIST_TARGET, \n", " label='target')\n", "__=ax.plot(XBINS[1:], XHIST_INIT, '--', \n", " label='initial guess')\n", "__=ax.plot(XBINS[1:], xhist_best_soft, ':', \n", " label='best fit MC method')\n", "\n", "leg = ax.legend()" ] }, { "cell_type": "code", "execution_count": null, "id": "e5b40c02-2305-4fc5-832f-db09ce68c51c", "metadata": {}, "outputs": [], "source": [ "gen = zip(p_best_soft._fields, p_best_soft, PARAMS_INIT, PARAMS_TARGET)\n", "for key, val_best, val_init, val_target in gen:\n", " print(f\"Init {key} = {val_init:.2f}\")\n", " print(f\"Best {key} = {val_best:.2f}\")\n", " print(f\"True {key} = {val_target:.2f}\\n\")" ] }, { "cell_type": "markdown", "id": "4b6d1777-8fcd-49dd-945d-1cce2e93bc03", "metadata": {}, "source": [ "### It worked! \n", "\n", "With a soft histogram, when we perturb each point by some infinitesimal amount, the weight of each point also changes infinitesimally, and so we get non-zero gradients." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.11" } }, "nbformat": 4, "nbformat_minor": 5 }