JAX

transformer from g guy

Patrick C Toulme
Last week I traced Splash Attention through the TPU compiler stack. But before that, I had to figure out how JAX actually compiles to TPU machine code. Google's TPU compiler is closed-source. The IRs are undocumented. So I rented a v6e for $1 and traced every layer: JAX → HLO
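
For anyone wanting to poke at the first hop (JAX → HLO) themselves, here is a minimal sketch using JAX's public ahead-of-time API. It is not the method from the post above: the toy function `f` and the input shape are made up for illustration, and on a machine without a TPU the compiled output will be for whatever backend is available (usually CPU). The TPU-specific IRs below HLO are not exposed this way.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy function (hypothetical): a matmul followed by a nonlinearity.
    return jnp.tanh(x @ x.T)


x = jnp.ones((4, 8), dtype=jnp.float32)

# Stage 1: trace the Python function into a jaxpr, JAX's own IR.
jaxpr = jax.make_jaxpr(f)(x)

# Stage 2: lower the traced function to the MLIR text handed to XLA.
lowered = jax.jit(f).lower(x)
stablehlo_text = lowered.as_text()

# Stage 3: compile for the local backend and read back the optimized HLO.
compiled = lowered.compile()
optimized_hlo_text = compiled.as_text()

print(jaxpr)
print(stablehlo_text)
print(optimized_hlo_text)
```

Diffing the stage 2 and stage 3 text is a quick way to see which fusions and layout changes XLA applied for the backend in use.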

https://github.com/MizuhoAOKI/jax_generative_models

https://ekzhang.substack.com/p/how-the-jaxjit-jit-compiler-works

Links

bilal
lots of fun quick tpu/performance engineering interview questions in here :D https://t.co/DLS93zWXNi