check transformer from Shrek guy
https://github.com/patrick-kidger/jaxtyping
https://github.com/HMUNACHI/nanodl
https://github.com/google/aqt
https://github.com/google/flax
https://github.com/NVIDIA/JAX-Toolbox
https://ekzhang.substack.com/p/how-the-jaxjit-jit-compiler-works
https://sankalp.bearblog.dev/einsum-new/
https://github.com/xjdr-alt/simple_transformer