jax>=0.4.1
jaxlib>=0.4.1

[dev]
pytest==7.2
