{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cb3896a1",
   "metadata": {},
   "source": [
    "\n",
    "<a id='parallel'></a>\n",
    "<div id=\"qe-notebook-header\" align=\"right\" style=\"text-align:right;\">\n",
    "        <a href=\"https://quantecon.org/\" title=\"quantecon.org\">\n",
    "                <img style=\"width:250px;display:inline;\" width=\"250px\" src=\"https://assets.quantecon.org/img/qe-menubar-logo.svg\" alt=\"QuantEcon\">\n",
    "        </a>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dc31a68",
   "metadata": {},
   "source": [
    "# NumPy vs Numba vs JAX\n",
    "\n",
    "In the preceding lectures, we’ve discussed three core libraries for scientific\n",
    "and numerical computing:\n",
    "\n",
    "- [NumPy](https://python-programming.quantecon.org/numpy.html)  \n",
    "- [Numba](https://python-programming.quantecon.org/numba.html)  \n",
    "- [JAX](https://python-programming.quantecon.org/jax_intro.html)  \n",
    "\n",
    "\n",
    "Which one should we use in any given situation?\n",
    "\n",
    "This lecture addresses that question, at least partially, by discussing some use cases.\n",
    "\n",
    "Before getting started, we note that the first two are a natural pair: NumPy and\n",
    "Numba play well together.\n",
    "\n",
    "JAX, on the other hand, stands alone.\n",
    "\n",
    "When considering each approach, we will consider not just efficiency and memory\n",
    "footprint but also clarity and ease of use.\n",
    "\n",
    "In addition to what’s in Anaconda, this lecture will need the following libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1e768a7",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "!pip install quantecon jax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1a4c761",
   "metadata": {},
   "source": [
    "# GPU\n",
    "\n",
    "This lecture is accelerated via [hardware](https://python-programming.quantecon.org/status.html#status-machine-details) that has access to a GPU and target JAX for GPU programming.\n",
    "\n",
    "Free GPUs are available on Google Colab.\n",
    "To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.\n",
    "\n",
    "Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.\n",
    "If you would like to install JAX running on the `cpu` only you can use `pip install jax[cpu]`\n",
    "\n",
    "We will use the following imports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcb8d99e",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import quantecon as qe\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d.axes3d import Axes3D\n",
    "from matplotlib import cm\n",
    "import jax\n",
    "import jax.numpy as jnp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28e9cfa5",
   "metadata": {},
   "source": [
    "## Vectorized operations\n",
    "\n",
    "Some operations can be perfectly vectorized — all loops are easily eliminated\n",
    "and numerical operations are reduced to calculations on arrays.\n",
    "\n",
    "In this case, which approach is best?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "729b8be9",
   "metadata": {},
   "source": [
    "### Problem Statement\n",
    "\n",
    "Consider the problem of maximizing a function $ f $ of two variables $ (x,y) $ over\n",
    "the square $ [-a, a] \\times [-a, a] $.\n",
    "\n",
    "For $ f $ and $ a $ let’s choose\n",
    "\n",
    "$$\n",
    "f(x,y) = \\frac{\\cos(x^2 + y^2)}{1 + x^2 + y^2}\n",
    "\\quad \\text{and} \\quad\n",
    "a = 3\n",
    "$$\n",
    "\n",
    "Here’s a plot of $ f $"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c38494a7",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "def f(x, y):\n",
    "    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)\n",
    "\n",
    "xgrid = np.linspace(-3, 3, 50)\n",
    "ygrid = xgrid\n",
    "x, y = np.meshgrid(xgrid, ygrid)\n",
    "\n",
    "fig = plt.figure(figsize=(10, 8))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "ax.plot_surface(x,\n",
    "                y,\n",
    "                f(x, y),\n",
    "                rstride=2, cstride=2,\n",
    "                cmap=cm.jet,\n",
    "                alpha=0.7,\n",
    "                linewidth=0.25)\n",
    "ax.set_zlim(-0.5, 1.0)\n",
    "ax.set_xlabel('$x$', fontsize=14)\n",
    "ax.set_ylabel('$y$', fontsize=14)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77926054",
   "metadata": {},
   "source": [
    "For the sake of this exercise, we’re going to use brute force for the\n",
    "maximization.\n",
    "\n",
    "1. Evaluate $ f $ for all $ (x,y) $ in a grid on the square.  \n",
    "1. Return the maximum of observed values.  \n",
    "\n",
    "\n",
    "Just to illustrate the idea, here’s a non-vectorized version that uses Python loops."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ea54d32",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "grid = np.linspace(-3, 3, 50)\n",
    "m = -np.inf\n",
    "for x in grid:\n",
    "    for y in grid:\n",
    "        z = f(x, y)\n",
    "        if z > m:\n",
    "            m = z"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "701631df",
   "metadata": {},
   "source": [
    "### NumPy vectorization\n",
    "\n",
    "If we switch to NumPy-style vectorization we can use a much larger grid and the\n",
    "code executes relatively quickly.\n",
    "\n",
    "Here we use `np.meshgrid` to create two-dimensional input grids `x` and `y` such\n",
    "that `f(x, y)` generates all evaluations on the product grid.\n",
    "\n",
    "(This strategy dates back to Matlab.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1de02a09",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "grid = np.linspace(-3, 3, 3_000)\n",
    "x, y = np.meshgrid(grid, grid)\n",
    "\n",
    "with qe.Timer(precision=8):\n",
    "    z_max_numpy = np.max(f(x, y))\n",
    "\n",
    "print(f\"NumPy result: {z_max_numpy}\")"
   ]
  },
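  {
   "cell_type": "markdown",
   "id": "a7c1e0d2",
   "metadata": {},
   "source": [
    "To make the `meshgrid` step concrete, here is a minimal sketch on a three-point grid.\n",
    "\n",
    "The names `small_grid`, `xs` and `ys` are ours and chosen only for illustration.\n",
    "\n",
    "Each position `(i, j)` in the pair `xs`, `ys` holds one point of the product grid, so elementwise arithmetic evaluates an expression at every pair."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7c1e0d3",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "small_grid = np.array([0.0, 1.0, 2.0])\n",
    "xs, ys = np.meshgrid(small_grid, small_grid)\n",
    "\n",
    "print(xs)  # x values vary along each row\n",
    "print(ys)  # y values vary down each column\n",
    "\n",
    "# Elementwise arithmetic evaluates x + 10 * y at all 9 pairs\n",
    "print(xs + 10 * ys)"
   ]
  },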
  {
   "cell_type": "markdown",
   "id": "bd3cce3f",
   "metadata": {},
   "source": [
    "In the vectorized version, all the looping takes place in compiled code.\n",
    "\n",
    "Moreover, NumPy uses implicit multithreading, so that at least some parallelization occurs.\n",
    "\n",
    ">**Note**\n",
    ">\n",
    ">If you have a system monitor such as htop (Linux/Mac) or perfmon\n",
    "(Windows), then try running this and then observing the load on your CPUs.\n",
    "\n",
    "(You will probably need to bump up the grid size to see large effects.)\n",
    "\n",
    "The output typically shows that the operation is successfully distributed across multiple threads.\n",
    "\n",
    "(The parallelization cannot be highly efficient because the binary is compiled\n",
    "before it sees the size of the arrays `x` and `y`.)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e07f6855",
   "metadata": {},
   "source": [
    "### A Comparison with Numba\n",
    "\n",
    "Now let’s see if we can achieve better performance using Numba with a simple loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42e5fb70",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import numba\n",
    "\n",
    "@numba.jit\n",
    "def compute_max_numba(grid):\n",
    "    m = -np.inf\n",
    "    for x in grid:\n",
    "        for y in grid:\n",
    "            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)\n",
    "            if z > m:\n",
    "                m = z\n",
    "    return m\n",
    "\n",
    "grid = np.linspace(-3, 3, 3_000)\n",
    "\n",
    "with qe.Timer(precision=8):\n",
    "    compute_max_numba(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f3370f1",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    compute_max_numba(grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "756b27d2",
   "metadata": {},
   "source": [
    "Depending on your machine, the Numba version can be a bit slower or a bit faster\n",
    "than NumPy.\n",
    "\n",
    "On one hand, NumPy combines efficient arithmetic (like Numba) with some\n",
    "multithreading (unlike this Numba code), which provides an advantage.\n",
    "\n",
    "On the other hand, the Numba routine uses much less memory, since we are only\n",
    "working with a single one-dimensional grid."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdec0f7c",
   "metadata": {},
   "source": [
    "### Parallelized Numba\n",
    "\n",
    "Now let’s try parallelization with Numba using `prange`:\n",
    "\n",
    "Here’s a naive and *incorrect* attempt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19854a2c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "@numba.jit(parallel=True)\n",
    "def compute_max_numba_parallel(grid):\n",
    "    n = len(grid)\n",
    "    m = -np.inf\n",
    "    for i in numba.prange(n):\n",
    "        for j in range(n):\n",
    "            x = grid[i]\n",
    "            y = grid[j]\n",
    "            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)\n",
    "            if z > m:\n",
    "                m = z\n",
    "    return m"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c3c6579",
   "metadata": {},
   "source": [
    "Usually this returns an incorrect result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30fd0957",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "z_max_parallel_incorrect = compute_max_numba_parallel(grid)\n",
    "print(f\"Incorrect parallel Numba result: {z_max_parallel_incorrect}\")\n",
    "print(f\"NumPy result: {z_max_numpy}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c6be910",
   "metadata": {},
   "source": [
    "The incorrect parallel implementation typically returns `-inf` (the initial value of `m`) instead of the correct maximum value of approximately `0.9999979986680024`.\n",
    "\n",
    "The reason is that the variable $ m $ is shared across threads and not properly controlled.\n",
    "\n",
    "When multiple threads try to read and write `m` simultaneously, they interfere with each other, causing a race condition.\n",
    "\n",
    "This results in lost updates—threads read stale values of `m` or overwrite each other’s updates—and the variable often never gets updated from its initial value of `-inf`.\n",
    "\n",
    "Here’s a more carefully written version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39c8c977",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "@numba.jit(parallel=True)\n",
    "def compute_max_numba_parallel(grid):\n",
    "    n = len(grid)\n",
    "    row_maxes = np.empty(n)\n",
    "    for i in numba.prange(n):\n",
    "        row_max = -np.inf\n",
    "        for j in range(n):\n",
    "            x = grid[i]\n",
    "            y = grid[j]\n",
    "            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)\n",
    "            if z > row_max:\n",
    "                row_max = z\n",
    "        row_maxes[i] = row_max\n",
    "    return np.max(row_maxes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13f9b724",
   "metadata": {},
   "source": [
    "Now the code block that `for i in numba.prange(n)` acts over is independent\n",
    "across `i`.\n",
    "\n",
    "Each thread writes to a separate element of the array `row_maxes`.\n",
    "\n",
    "Hence the parallelization is safe.\n",
    "\n",
    "Here’s the timings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c2b3b30",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    compute_max_numba_parallel(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e093b9bb",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    compute_max_numba_parallel(grid)"
   ]
  },
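  {
   "cell_type": "markdown",
   "id": "b3d94f10",
   "metadata": {},
   "source": [
    "As a sanity check on the row-by-row reduction, the sketch below compares a serial loop with a `prange` version on a smaller grid.\n",
    "\n",
    "The function names `serial_max` and `parallel_max` and the array `check_grid` are hypothetical, and we assume the installed Numba supports `parallel=True` on this machine.\n",
    "\n",
    "Both functions perform the same floating point operations on the same inputs, so the two maxima should agree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3d94f11",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numba\n",
    "\n",
    "@numba.njit\n",
    "def serial_max(grid):\n",
    "    # Plain nested loop: one running maximum, one thread\n",
    "    m = -np.inf\n",
    "    for x in grid:\n",
    "        for y in grid:\n",
    "            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)\n",
    "            if z > m:\n",
    "                m = z\n",
    "    return m\n",
    "\n",
    "@numba.njit(parallel=True)\n",
    "def parallel_max(grid):\n",
    "    n = len(grid)\n",
    "    row_maxes = np.empty(n)\n",
    "    for i in numba.prange(n):\n",
    "        # row_max is local to iteration i, so no state is shared\n",
    "        row_max = -np.inf\n",
    "        for j in range(n):\n",
    "            x = grid[i]\n",
    "            y = grid[j]\n",
    "            z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)\n",
    "            if z > row_max:\n",
    "                row_max = z\n",
    "        row_maxes[i] = row_max\n",
    "    return np.max(row_maxes)\n",
    "\n",
    "check_grid = np.linspace(-3, 3, 500)\n",
    "print(np.isclose(serial_max(check_grid), parallel_max(check_grid)))"
   ]
  },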
  {
   "cell_type": "markdown",
   "id": "dfbc6227",
   "metadata": {},
   "source": [
    "If you have multiple cores, you should see at least some benefits from parallelization here.\n",
    "\n",
    "For more powerful machines and larger grid sizes, parallelization can generate major speed gains, even on the CPU."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "485ea3a8",
   "metadata": {},
   "source": [
    "### Vectorized code with JAX\n",
    "\n",
    "In most ways, vectorization is the same in JAX as it is in NumPy.\n",
    "\n",
    "But there are also some differences, which we highlight here.\n",
    "\n",
    "Let’s start with the function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d163021c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def f(x, y):\n",
    "    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eae0835c",
   "metadata": {},
   "source": [
    "As with NumPy, to get the right shape and the correct nested `for` loop\n",
    "calculation, we can use a `meshgrid` operation designed for this purpose:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e066d7",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "grid = jnp.linspace(-3, 3, 3_000)\n",
    "x_mesh, y_mesh = np.meshgrid(grid, grid)\n",
    "\n",
    "with qe.Timer(precision=8):\n",
    "    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4115b5f9",
   "metadata": {},
   "source": [
    "Let’s run again to eliminate compile time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8870576d",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68019900",
   "metadata": {},
   "source": [
    "Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.\n",
    "\n",
    "The compilation overhead is a one-time cost that pays off when the function is called repeatedly."
   ]
  },
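  {
   "cell_type": "markdown",
   "id": "c58e2a40",
   "metadata": {},
   "source": [
    "How large the gain is depends on which device JAX is actually using.\n",
    "\n",
    "A quick way to check, assuming a standard JAX installation, is to query the default backend and the list of visible devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c58e2a41",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "# Name of the platform JAX will use by default\n",
    "print(jax.default_backend())\n",
    "# Devices visible to JAX on that platform\n",
    "print(jax.devices())"
   ]
  },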
  {
   "cell_type": "markdown",
   "id": "68285ad3",
   "metadata": {},
   "source": [
    "### JAX plus vmap\n",
    "\n",
    "There is one problem with both the NumPy code and the JAX code:\n",
    "\n",
    "While the flat arrays are low-memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611de79c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "grid.nbytes "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ddc6278",
   "metadata": {},
   "source": [
    "the mesh grids are memory intensive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e813ec59",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "x_mesh.nbytes + y_mesh.nbytes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29b596ba",
   "metadata": {},
   "source": [
    "This extra memory usage can be a big problem in actual research calculations.\n",
    "\n",
    "Fortunately, JAX admits a different approach\n",
    "using [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b6242da",
   "metadata": {},
   "source": [
    "#### Version 1\n",
    "\n",
    "Here’s one way we can apply `vmap`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "799c8222",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Set up f to compute f(x, y) at every x for any given y\n",
    "f_vec_x = lambda y: f(grid, y)\n",
    "# Create a second function that vectorizes this operation over all y\n",
    "f_vec = jax.vmap(f_vec_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78704700",
   "metadata": {},
   "source": [
    "Now `f_vec` will compute `f(x,y)` at every `x,y` when called with the flat array `grid`.\n",
    "\n",
    "Let’s see the timing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf036787",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    z_max = jnp.max(f_vec(grid))\n",
    "    z_max.block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeef0d2b",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    z_max = jnp.max(f_vec(grid))\n",
    "    z_max.block_until_ready()"
   ]
  },
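  {
   "cell_type": "markdown",
   "id": "d0416b70",
   "metadata": {},
   "source": [
    "To confirm that the `vmap` construction reproduces the meshgrid evaluation, here is a small sketch on a five-point grid.\n",
    "\n",
    "The names `g`, `tiny`, `g_vec_x` and `g_vec` are ours and mirror `f`, `grid`, `f_vec_x` and `f_vec` above, so that nothing defined earlier is overwritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0416b71",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def g(x, y):\n",
    "    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)\n",
    "\n",
    "tiny = jnp.linspace(-3, 3, 5)\n",
    "\n",
    "# Evaluate g at every x in tiny for one given y\n",
    "g_vec_x = lambda y: g(tiny, y)\n",
    "# Map that function over every y in its argument\n",
    "g_vec = jax.vmap(g_vec_x)\n",
    "\n",
    "# Reference evaluation on explicit two-dimensional grids\n",
    "x_m, y_m = jnp.meshgrid(tiny, tiny)\n",
    "\n",
    "print(g_vec(tiny).shape)\n",
    "print(jnp.allclose(g_vec(tiny), g(x_m, y_m)))"
   ]
  },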
  {
   "cell_type": "markdown",
   "id": "63d1e77e",
   "metadata": {},
   "source": [
    "By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version uses far less memory.\n",
    "\n",
    "When run on a CPU, its runtime is similar to that of the meshgrid version.\n",
    "\n",
    "When run on a GPU, it is usually significantly faster.\n",
    "\n",
    "In fact, using `vmap` has another advantage: It allows us to break vectorization up into stages.\n",
    "\n",
    "This leads to code that is often easier to comprehend than traditional vectorized code.\n",
    "\n",
    "We will investigate these ideas more when we tackle larger problems."
   ]
  },
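  {
   "cell_type": "markdown",
   "id": "e92b7c30",
   "metadata": {},
   "source": [
    "One way to read \"stages\" is as nested calls to `vmap`, each one vectorizing over a single argument.\n",
    "\n",
    "The sketch below starts from a function of two scalars and uses the `in_axes` argument of `jax.vmap` to say which argument is mapped at each stage.\n",
    "\n",
    "The names `g_scalar`, `g_over_x`, `g_over_xy` and `pts` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e92b7c31",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def g_scalar(x, y):\n",
    "    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)\n",
    "\n",
    "# Stage 1: map over x (axis 0 of the first argument), holding y fixed\n",
    "g_over_x = jax.vmap(g_scalar, in_axes=(0, None))\n",
    "# Stage 2: map over y (axis 0 of the second argument), holding the x array fixed\n",
    "g_over_xy = jax.vmap(g_over_x, in_axes=(None, 0))\n",
    "\n",
    "pts = jnp.linspace(-3, 3, 4)\n",
    "out = g_over_xy(pts, pts)\n",
    "\n",
    "# Entry [i, j] holds g_scalar(pts[j], pts[i])\n",
    "print(out.shape)"
   ]
  },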
  {
   "cell_type": "markdown",
   "id": "9b6c845b",
   "metadata": {},
   "source": [
    "### vmap version 2\n",
    "\n",
    "We can be still more memory efficient using vmap.\n",
    "\n",
    "While we avoid large input arrays in the preceding version,\n",
    "we still create the large output array `f(x,y)` before we compute the max.\n",
    "\n",
    "Let’s try a slightly different approach that takes the max to the inside.\n",
    "\n",
    "Because of this change, we never compute the two-dimensional array `f(x,y)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "555677eb",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def compute_max_vmap_v2(grid):\n",
    "    # Construct a function that takes the max along each row\n",
    "    f_vec_x_max = lambda y: jnp.max(f(grid, y))\n",
    "    # Vectorize the function so we can call on all rows simultaneously\n",
    "    f_vec_max = jax.vmap(f_vec_x_max)\n",
    "    # Call the vectorized function and take the max\n",
    "    return jnp.max(f_vec_max(grid))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a403210",
   "metadata": {},
   "source": [
    "Here\n",
    "\n",
    "- `f_vec_x_max` computes the max along any given row  \n",
    "- `f_vec_max` is a vectorized version that can compute the max of all rows in parallel.  \n",
    "\n",
    "\n",
    "We apply this function to all rows and then take the max of the row maxes.\n",
    "\n",
    "Let’s try it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb558790",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    z_max = compute_max_vmap_v2(grid).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b34a62b",
   "metadata": {},
   "source": [
    "Let’s run it again to eliminate compilation time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24f70070",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    z_max = compute_max_vmap_v2(grid).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cb6cf15",
   "metadata": {},
   "source": [
    "If you are running this on a GPU, as we are, you should see another nontrivial speed gain."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cd46000",
   "metadata": {},
   "source": [
    "### Summary\n",
    "\n",
    "In our view, JAX is the winner for vectorized operations.\n",
    "\n",
    "It dominates NumPy both in terms of speed (via JIT-compilation and parallelization) and memory efficiency (via vmap).\n",
    "\n",
    "Moreover, the `vmap` approach can sometimes lead to significantly clearer code.\n",
    "\n",
    "While Numba is impressive, the beauty of JAX is that, with fully vectorized\n",
    "operations, we can run exactly the\n",
    "same code on machines with hardware accelerators and reap all the benefits\n",
    "without extra effort.\n",
    "\n",
    "Moreover, JAX already knows how to effectively parallelize many common array\n",
    "operations, which is key to fast execution.\n",
    "\n",
    "For almost all cases encountered in economics, econometrics, and finance, it is\n",
    "far better to hand over to the JAX compiler for efficient parallelization than to\n",
    "try to hand code these routines ourselves."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0bc7604",
   "metadata": {},
   "source": [
    "## Sequential operations\n",
    "\n",
    "Some operations are inherently sequential – and hence difficult or impossible\n",
    "to vectorize.\n",
    "\n",
    "In this case NumPy is a poor option and we are left with the choice of Numba or\n",
    "JAX.\n",
    "\n",
    "To compare these choices, we will revisit the problem of iterating on the\n",
    "quadratic map that we saw in our [Numba lecture](https://python-programming.quantecon.org/numba.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ace86de9",
   "metadata": {},
   "source": [
    "### Numba Version\n",
    "\n",
    "Here’s the Numba version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe8ab917",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "@numba.jit\n",
    "def qm(x0, n, α=4.0):\n",
    "    x = np.empty(n+1)\n",
    "    x[0] = x0\n",
    "    for t in range(n):\n",
    "      x[t+1] = α * x[t] * (1 - x[t])\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc0f4d22",
   "metadata": {},
   "source": [
    "Let’s generate a time series of length 10,000,000 and time the execution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10bc377f",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "n = 10_000_000\n",
    "\n",
    "with qe.Timer(precision=8):\n",
    "    x = qm(0.1, n)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9946aa4d",
   "metadata": {},
   "source": [
    "Let’s run it again to eliminate compilation time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b17d176",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    x = qm(0.1, n)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55849dc2",
   "metadata": {},
   "source": [
    "Numba handles this sequential operation very efficiently.\n",
    "\n",
    "Notice that the second run is significantly faster after JIT compilation completes.\n",
    "\n",
    "Numba’s compilation is typically quite fast, and the resulting code performance is excellent for sequential operations like this one."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8d5b898",
   "metadata": {},
   "source": [
    "### JAX Version\n",
    "\n",
    "Now let’s create a JAX version using `lax.scan`:\n",
    "\n",
    "(We’ll hold `n` static because it affects array size and hence JAX wants to specialize on its value in the compiled code.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be979f60",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "from jax import lax\n",
    "from functools import partial\n",
    "\n",
    "cpu = jax.devices(\"cpu\")[0]\n",
    "\n",
    "@partial(jax.jit, static_argnums=(1,), device=cpu)\n",
    "def qm_jax(x0, n, α=4.0):\n",
    "    def update(x, t):\n",
    "        x_new = α * x * (1 - x)\n",
    "        return x_new, x_new\n",
    "\n",
    "    _, x = lax.scan(update, x0, jnp.arange(n))\n",
    "    return jnp.concatenate([jnp.array([x0]), x])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb38f2b6",
   "metadata": {},
   "source": [
    "This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array.\n",
    "\n",
    ">**Note**\n",
    ">\n",
    ">Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.\n",
    "\n",
    "The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.\n",
    "\n",
    "As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.\n",
    "\n",
    "Curious readers can try removing this option to see how performance changes.\n",
    "\n",
    "Let’s time it with the same parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ca50bc6",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    x_jax = qm_jax(0.1, n).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a01b297",
   "metadata": {},
   "source": [
    "Let’s run it again to eliminate compilation overhead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e3b2f9a",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "with qe.Timer(precision=8):\n",
    "    x_jax = qm_jax(0.1, n).block_until_ready()"
   ]
  },
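  {
   "cell_type": "markdown",
   "id": "f6a03d50",
   "metadata": {},
   "source": [
    "To see what `lax.scan` is doing, the following sketch runs the same recursion for five steps and compares the result with an explicit Python loop.\n",
    "\n",
    "The names `step`, `path` and `loop_path` are ours, and the loose tolerance allows for the 32-bit floats that JAX uses by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a03d51",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from jax import lax\n",
    "\n",
    "def step(x, _):\n",
    "    x_new = 4.0 * x * (1 - x)\n",
    "    # First value is carried to the next call, second is stacked\n",
    "    return x_new, x_new\n",
    "\n",
    "final, path = lax.scan(step, jnp.float32(0.1), jnp.arange(5))\n",
    "\n",
    "# The same recursion written as an ordinary Python loop\n",
    "x_loop, loop_path = 0.1, []\n",
    "for _ in range(5):\n",
    "    x_loop = 4.0 * x_loop * (1 - x_loop)\n",
    "    loop_path.append(x_loop)\n",
    "\n",
    "print(path.shape)\n",
    "print(jnp.allclose(path, jnp.array(loop_path), rtol=1e-3))"
   ]
  },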
  {
   "cell_type": "markdown",
   "id": "8f81fe38",
   "metadata": {},
   "source": [
    "JAX is also efficient for this sequential operation.\n",
    "\n",
    "Both JAX and Numba deliver strong performance after compilation, with Numba\n",
    "typically (but not always) offering slightly better speeds on purely sequential\n",
    "operations."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a60894ab",
   "metadata": {},
   "source": [
    "### Summary\n",
    "\n",
    "While both Numba and JAX deliver strong performance for sequential operations,\n",
    "there are significant differences in code readability and ease of use.\n",
    "\n",
    "The Numba version is straightforward and natural to read: we simply allocate an\n",
    "array and fill it element by element using a standard Python loop.\n",
    "\n",
    "This is exactly how most programmers think about the algorithm.\n",
    "\n",
    "The JAX version, on the other hand, requires using `lax.scan`, which is significantly less intuitive.\n",
    "\n",
    "Additionally, JAX’s immutable arrays mean we cannot simply update array elements in place, making it hard to directly replicate the algorithm used by Numba.\n",
    "\n",
    "For this type of sequential operation, Numba is the clear winner in terms of\n",
    "code clarity and ease of implementation, as well as high performance."
   ]
  }
 ],
 "metadata": {
  "date": 1764394126.5387592,
  "filename": "numpy_vs_numba_vs_jax.md",
  "kernelspec": {
   "display_name": "Python",
   "language": "python3",
   "name": "python3"
  },
  "title": "NumPy vs Numba vs JAX"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}