{ "cells": [ { "cell_type": "markdown", "id": "494ee05f-bd9b-45eb-9816-1bacc370ce7c", "metadata": {}, "source": [ "# Fitting a double Gaussian with soft histograms\n", "\n", "This notebook shows how to implement a double Gaussian model in JAX and demonstrates how to optimize the parameters of the model by fitting to soft histograms with gradient descent." ] }, { "cell_type": "code", "execution_count": null, "id": "981d72a6-0042-4f7a-9397-2f7e64429469", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from jax import random as jran\n", "from jax import numpy as jnp\n", "\n", "ran_key = jran.key(0)" ] }, { "cell_type": "markdown", "id": "bc19e244-4499-4b6c-a4a3-19723b21986c", "metadata": {}, "source": [ "## Stochastic Monte Carlo predictions\n", "\n", "The `mc_double_gaussian` function generates a sample of 1D data using standard Monte Carlo methods:\n", "1. Draw $N$ points from the first Gaussian, $\\{\\mu_0, \\sigma_0\\}$\n", "2. Draw $N$ points from the second Gaussian, $\\{\\mu_1, \\sigma_1\\}$\n", "3. Draw $N$ uniform random numbers, $u$\n", "4. If $f$ is the model parameter controlling the relative height of the two Gaussians, then for points with $u