
Investigation: JAX GPU Performance in Jupyter Notebook Execution #262

@mmcky

Description


Summary

During the implementation of RunsOn GPU support for lecture-python-programming.myst (PR QuantEcon/lecture-python-programming.myst#437), we discovered unexpected performance differences between GPU and CPU runners.

Key Finding: A lecture that took 15.56 seconds on the CPU runner took 176.67 seconds on the GPU runner (~11x slower) - despite the GPU being correctly detected and working.

Hardware Benchmark Results

We ran identical benchmarks on both GitHub Actions (CPU) and RunsOn (GPU) to diagnose the performance differences.
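Before timing anything, it is worth confirming which backend JAX actually selected on each runner. A quick check of this kind (not part of the original benchmark script):

```python
import jax

# Report the active JAX backend and visible devices; on the RunsOn GPU
# runner this should list the Tesla T4, on GitHub Actions only CPU devices.
print(jax.default_backend())   # e.g. "cpu" or "gpu"
print(jax.devices())
```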

System Information

| Metric | GitHub Actions (CPU) | RunsOn (GPU) |
| --- | --- | --- |
| Platform | Linux (Azure) | Linux (AWS) |
| CPU | AMD EPYC 7763 64-Core @ 3281 MHz | Intel Xeon Platinum 8259CL @ 2500 MHz |
| CPU cores | 4 | 8 |
| GPU | None | Tesla T4 (15360 MiB) |

CPU Performance (Pure Python)

| Benchmark | GitHub Actions | RunsOn | Winner |
| --- | --- | --- | --- |
| Integer sum (10M) | 0.599 sec | 0.881 sec | GitHub Actions (1.5x faster) |
| Float sqrt (1M) | 0.077 sec | 0.107 sec | GitHub Actions (1.4x faster) |

CPU Performance (NumPy)

| Benchmark | GitHub Actions | RunsOn | Winner |
| --- | --- | --- | --- |
| Matrix multiply (3000x3000) | 0.642 sec | 0.224 sec | RunsOn (2.9x faster) |
| Element-wise (50M) | 1.686 sec | 1.768 sec | ~Same |

CPU Performance (Numba)

| Benchmark | GitHub Actions | RunsOn | Winner |
| --- | --- | --- | --- |
| Integer sum warm-up | 0.320 sec | 0.300 sec | ~Same |
| Integer sum compiled | 0.000 sec | 0.000 sec | Same |
| Parallel sum warm-up | 0.344 sec | 0.382 sec | ~Same |
| Parallel sum compiled | 0.012 sec | 0.010 sec | ~Same |

JAX Performance (Direct Script Execution)

| Benchmark | GitHub Actions (CPU) | RunsOn (GPU) | Winner |
| --- | --- | --- | --- |
| 1000x1000 warm-up | 0.030 sec | 0.079 sec | GitHub Actions |
| 1000x1000 compiled | 0.011 sec | 0.001 sec | GPU (11x faster) |
| 3000x3000 warm-up | 0.426 sec | 0.645 sec | GitHub Actions |
| 3000x3000 compiled | 0.276 sec | 0.009 sec | GPU (30x faster) |
| 50M element-wise warm-up | 0.816 sec | 0.118 sec | GPU (7x faster) |
| 50M element-wise compiled | 0.381 sec | 0.002 sec | GPU (190x faster) |

Key Findings

  1. Pure Python is slower on RunsOn - The Intel Xeon @ 2.5 GHz is slower than the AMD EPYC @ 3.3 GHz for single-threaded Python (1.4-1.5x slower).

  2. NumPy matrix ops are faster on RunsOn - Likely due to 8 cores vs 4 cores for BLAS parallelization (2.9x faster).

  3. GPU (JAX compiled) is massively faster in direct execution - 30-190x faster for compiled operations!

  4. JIT compilation overhead is higher on GPU - Warm-up times are longer on GPU due to CUDA kernel compilation.
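The warm-up vs compiled split above can be reproduced with a small timing harness. Note that JAX dispatches work asynchronously, so `block_until_ready()` is needed for honest timings; the matrix size here is illustrative:

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return a @ b

a = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

# First call: trace + compile (XLA on CPU, CUDA kernels on GPU).
# This compile step is why warm-up times are longer on the GPU runner.
t0 = time.perf_counter()
matmul(a, a).block_until_ready()
warmup = time.perf_counter() - t0

# Second call with the same shapes: reuses the compiled kernel.
t0 = time.perf_counter()
matmul(a, a).block_until_ready()
steady = time.perf_counter() - t0

print(f"warm-up: {warmup:.3f}s, compiled: {steady:.3f}s")
```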

The Mystery: Why is Jupyter Execution So Much Slower?

The benchmark script (direct Python execution) shows the GPU is working correctly, with massive speedups. However, when the same JAX code runs through Jupyter Book / Jupyter kernel execution, the total execution time is ~11x longer than on the CPU runner.

Possible Causes to Investigate

  1. Jupyter kernel overhead - Each cell execution may trigger additional overhead
  2. JIT cache not persisting - JAX compiled kernels may not persist between cells
  3. Multiple recompilations - Different array sizes in lectures trigger recompilation
  4. Memory management - Jupyter may handle GPU memory differently
  5. Cell isolation - Each cell may create new JAX contexts
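Cause 3 is easy to check in isolation: `jax.jit` specializes compiled code on input shapes and dtypes, so every new array size in a lecture pays a fresh trace-and-compile cost. A minimal sketch:

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(x * x)

def timed(x):
    t0 = time.perf_counter()
    f(x).block_until_ready()
    return time.perf_counter() - t0

first   = timed(jnp.ones(1_000))  # new shape: trace + compile
cached  = timed(jnp.ones(1_000))  # same shape: reuses compiled kernel
resized = timed(jnp.ones(2_000))  # new shape again: recompiles
```

A lecture sweeping over many array sizes hits the `first`/`resized` path repeatedly, which is far more expensive on GPU than on CPU.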

Relevant Lectures

  • numpy_vs_numba_vs_jax.md - 176.67 sec (GPU) vs 15.56 sec (CPU)
  • jax_intro.md - Similar pattern expected

Benchmark Script

The benchmark script is available at:

Related PRs

Next Steps

  1. Investigate Jupyter kernel execution overhead with JAX
  2. Consider if JAX JIT cache can be preserved across cells
  3. Test if running JAX lectures as scripts (vs notebooks) shows different performance
  4. Evaluate if this affects other lecture series using JAX
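For step 2, JAX ships a persistent compilation cache that can be pointed at a directory so compiled kernels survive process (and kernel) restarts. A hedged sketch - the cache directory here is a throwaway placeholder, and the option names assume a reasonably recent JAX release:

```python
import tempfile

import jax

# Write compiled executables to disk so later kernel sessions can reuse them.
cache_dir = tempfile.mkdtemp()
jax.config.update("jax_compilation_cache_dir", cache_dir)

# By default only slow compiles are cached; cache everything instead.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)
```

This would not help with recompilation across different array shapes within a single run, but it could amortize compile costs across repeated CI builds.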

/cc @mmcky
