torch>=2.0.0
numpy>=1.21.0

[all]
triton>=2.0.0
jax>=0.4.0

[amd]
triton>=2.0.0

[cuda]
triton>=2.0.0

[dev]
pytest>=7.0.0
black>=23.0.0
isort>=5.12.0

[directml]
torch-directml>=0.2.0

[intel]
intel-extension-for-pytorch>=2.0.0

[tpu]
jax[tpu]>=0.4.0
