Training with JAX MaxText on AAC
This guide explains how to run JAX MaxText training workloads on AMD Accelerator Cloud (AAC) clusters.
Overview
MaxText is a high-performance JAX-based LLM training framework optimized for AMD Instinct GPUs. It provides prebuilt Docker images with JAX, XLA, ROCm libraries, and MaxText utilities for training models like Llama, DeepSeek, and Mixtral.
Prerequisites
- Access to an AAC cluster (MI325X or MI355X)
- Hugging Face account and access token
- Basic familiarity with SLURM commands
Supported hardware
- AMD Instinct MI355X GPUs (MI355X cluster)
- AMD Instinct MI325X GPUs (MI325X cluster)
- Software stack: ROCm 7.2.0 with JAX
Single-node training
Step 1: Allocate a compute node
# For the MI325X cluster
salloc -p 256C8G1H_MI325X_Ubuntu22 --gres=gpu:8 --mem=0 --exclusive --account=<ACCOUNT_NAME>
# For the MI355X cluster
salloc -p 256C8G1H_MI355X_Ubuntu22 --gres=gpu:8 --mem=0 --exclusive --account=<ACCOUNT_NAME>
Example:
salloc -p 256C8G1H_MI355X_Ubuntu22 --gres=gpu:8 --mem=0 --exclusive --account=myteam
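Once salloc returns, it can help to confirm that the shell is really inside an allocation before continuing. The sketch below relies on the SLURM_JOB_ID and SLURM_JOB_NODELIST variables that Slurm sets for an allocation; the function name in_allocation is hypothetical and not part of any AAC tooling.

```shell
# Hypothetical helper: succeeds only when the current shell is inside a Slurm allocation.
in_allocation() {
    if [ -z "${SLURM_JOB_ID:-}" ]; then
        echo "Not inside a Slurm allocation (SLURM_JOB_ID is unset)" >&2
        return 1
    fi
    # SLURM_JOB_NODELIST lists the nodes assigned to this allocation.
    echo "Allocation ${SLURM_JOB_ID} on ${SLURM_JOB_NODELIST:-unknown nodes}"
}
```

Calling in_allocation right after salloc prints the job ID and node list, or an error if the allocation was not granted.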
Step 2: Load ROCm environment
module load rocm/7.2.0
Step 3: Set environment variables
export MAD_SECRETS_HFTOKEN=<YOUR_HUGGING_FACE_TOKEN>
export HF_HOME=/hf_cache
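A missing token usually only surfaces later, when a gated model download fails. As a minimal sketch, a pre-flight check along these lines could catch it earlier; check_hf_env is a hypothetical name, and the function deliberately never prints the token itself.

```shell
# Hypothetical pre-flight check for the two variables set in this step.
check_hf_env() {
    if [ -z "${MAD_SECRETS_HFTOKEN:-}" ]; then
        echo "MAD_SECRETS_HFTOKEN is empty or unset" >&2
        return 1
    fi
    if [ -z "${HF_HOME:-}" ]; then
        echo "HF_HOME is empty or unset" >&2
        return 1
    fi
    # Report only the cache location, never the token value.
    echo "Hugging Face token is set; cache directory: ${HF_HOME}"
}
```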
Step 4: Pull and run Podman container
podman pull docker.io/rocm/jax-training:maxtext-v26.2
podman run -it \
--device /dev/dri \
--device /dev/kfd \
--network host \
--ipc host \
--group-add video \
--cap-add SYS_PTRACE \
--security-opt seccomp=unconfined \
--privileged \
-v $HOME:$HOME \
-v $HOME/.ssh:/root/.ssh \
-v /shared/data:/hf_cache \
-e HF_HOME=/hf_cache \
-e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
--shm-size 64G \
--name jax_training_env \
docker.io/rocm/jax-training:maxtext-v26.2
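Inside the container, a quick check that the two device nodes passed with --device are visible can rule out a mount problem before training starts. This is only a sketch: check_gpu_devices is a hypothetical helper, and its optional root-prefix argument exists purely to make the check easy to exercise outside a GPU node.

```shell
# Hypothetical helper: verify that the GPU device nodes passed with --device exist.
# The optional first argument is a root prefix (empty by default).
check_gpu_devices() {
    gpu_root="${1:-}"
    gpu_missing=0
    for gpu_dev in /dev/kfd /dev/dri; do
        if [ ! -e "${gpu_root}${gpu_dev}" ]; then
            echo "missing ${gpu_dev}" >&2
            gpu_missing=1
        fi
    done
    # Return 0 when both nodes exist, 1 otherwise.
    return "$gpu_missing"
}
```

If the device nodes are present, running python3 -c "import jax; print(jax.devices())" in the container is a reasonable next check that JAX can see the GPUs.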
Step 5: Clone MAD and run setup
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
./jax-maxtext_benchmark_setup.sh -m Llama-2-7B
Step 6: Run training
Without quantization:
./jax-maxtext_benchmark_report.sh -m Llama-2-7B
With FP8 quantization (MI355X/MI325X):
./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8
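When switching between the two variants often, a small wrapper can keep the command lines consistent. The sketch below only assembles the command string from the -m and -q flags shown in this step; report_cmd is a hypothetical name, and no other flags of the script are assumed.

```shell
# Hypothetical wrapper: print the benchmark command for a model and an
# optional quantization mode, using only the -m and -q flags from this guide.
report_cmd() {
    rc_model="${1:-}"
    rc_quant="${2:-}"
    if [ -z "$rc_model" ]; then
        echo "usage: report_cmd <model> [quantization]" >&2
        return 1
    fi
    if [ -n "$rc_quant" ]; then
        echo "./jax-maxtext_benchmark_report.sh -m ${rc_model} -q ${rc_quant}"
    else
        echo "./jax-maxtext_benchmark_report.sh -m ${rc_model}"
    fi
}
```

For example, report_cmd Llama-2-7B fp8 prints the FP8 command shown above, which can then be reviewed or logged before it is run.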
Multi-node training with SLURM
Step 1: Allocate multiple nodes
# Allocate 4 nodes on MI355X cluster
salloc -N 4 \
-p 256C8G1H_MI355X_Ubuntu22 \
--gres=gpu:8 \
--mem=0 \
--exclusive \
--ntasks-per-node=8 \
--account=<ACCOUNT_NAME>
Example:
salloc -N 4 -p 256C8G1H_MI355X_Ubuntu22 --gres=gpu:8 --mem=0 --exclusive --ntasks-per-node=8 --account=myteam
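The allocation size follows directly from the flags: the number of nodes times the tasks per node. As a small worked example, assuming one task per GPU (which matches --gres=gpu:8 combined with --ntasks-per-node=8), the hypothetical helper below computes the total.

```shell
# Hypothetical helper: total Slurm tasks for an allocation,
# assuming one task per GPU as in the salloc example above.
total_tasks() {
    tt_nodes="$1"
    tt_per_node="$2"
    echo $(( tt_nodes * tt_per_node ))
}

total_tasks 4 8   # 4 nodes x 8 tasks per node, prints 32
```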
Step 2: Create SLURM batch script
Create a file train_llama_multinode.sh:
#!/bin/bash
#SBATCH -J jax_maxtext_train
#SBATCH -p 256C8G1H_MI355X_Ubuntu22
#SBATCH --gres=gpu:8
#SBATCH --mem=0
#SBATCH -N 4
#SBATCH --ntasks-per-node=8
#SBATCH --account=<ACCOUNT_NAME>
# Load ROCm
module load rocm/7.2.0
# Set environment variables
export MAD_SECRETS_HFTOKEN=<your_token>
export HF_HOME=/shared/data/hf_cache
# Clone MAD if not already done
if [ ! -d "/shared/data/MAD" ]; then
    cd /shared/data
    git clone https://github.com/ROCm/MAD
fi
# Run multi-node training
cd /shared/data/MAD/scripts/jax-maxtext
srun --container-image=docker://rocm/jax-training:maxtext-v26.2 \
--container-mounts=$HOME:/workspace,/shared/data:/shared/data \
--container-workdir=/shared/data/MAD/scripts/jax-maxtext \
--container-env="MAD_SECRETS_HFTOKEN,HF_HOME" \
./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8
Step 3: Submit using sbatch
sbatch train_llama_multinode.sh
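sbatch reports the new job with a line of the form "Submitted batch job <id>". Capturing that ID makes it easier to find the matching slurm-<job_id>.out file later. The helper below is a sketch with a hypothetical name, parse_job_id, that extracts the ID from that message.

```shell
# Hypothetical helper: extract the job ID from sbatch's
# "Submitted batch job <id>" message; prints nothing for other input.
parse_job_id() {
    echo "$1" | awk '/^Submitted batch job/ { print $4 }'
}

# Typical use (requires Slurm, so it is left commented out here):
# job_id="$(parse_job_id "$(sbatch train_llama_multinode.sh)")"
# tail -f "slurm-${job_id}.out"
```

sbatch also offers a --parsable option that prints only the job ID, which avoids the parsing step where that output format is acceptable.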
Step 4: Monitor the job
# Check job status
squeue -u $USER
# View output
tail -f slurm-<job_id>.out
Using Primus CLI (alternative method)
Step 1: Clone Primus repository
git clone https://github.com/AMD-AIG-AIMA/Primus.git
cd Primus
git checkout dev/fuyuajin/maxtext-backend-test
git submodule update --init third_party/maxtext/
Step 2: Run training in direct mode
./primus-cli direct -- train pretrain \
--config examples/maxtext/configs/MI300X/llama2_7B-pretrain.yaml
For MI355X, update the config path:
./primus-cli direct -- train pretrain \
--config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml
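Since the two commands differ only in the config directory, the path can be derived from the GPU type. The sketch below accepts only the two directories that appear in this guide (MI300X and MI355X); primus_config is a hypothetical helper, and no other config directories are assumed to exist.

```shell
# Hypothetical helper: map a GPU type to the Llama 2 7B config path used above.
# Only the two config directories that appear in this guide are accepted.
primus_config() {
    case "${1:-}" in
        MI300X|MI355X)
            echo "examples/maxtext/configs/$1/llama2_7B-pretrain.yaml"
            ;;
        *)
            echo "unsupported GPU type: ${1:-none given}" >&2
            return 1
            ;;
    esac
}

# Typical use from the Primus checkout (left commented out here):
# ./primus-cli direct -- train pretrain --config "$(primus_config MI355X)"
```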
Supported models
JAX MaxText includes optimized configurations for:
- Llama family: Llama 2 7B, Llama 3 8B, Llama 3.1 8B/70B, Llama 3.3 70B
- DeepSeek: DeepSeek-V2-Lite
- Mixtral: Mixtral 8x7B, Mixtral 8x22B
- Qwen: Qwen 2.5 models
Configuration files are typically located in:
- MAD/scripts/jax-maxtext/env_scripts/
- Primus/examples/maxtext/configs/MI355X/ (when using Primus)
Performance optimization
FP8 quantization
For better performance on MI355X/MI325X, use FP8 quantization:
./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8
Known limitations
Packing issue: losses may become NaN when packing=True is set in the configuration. Workaround:
- Disable input sequence packing: set packing=False in your YAML config
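To spot this problem early, a training log can be scanned for NaN values. The sketch below is an assumption-heavy example: the exact MaxText log format is not described in this guide, so the pattern simply looks for "nan" as a standalone token (case-insensitive), and log_has_nan is a hypothetical name.

```shell
# Hypothetical helper: succeed if a log file contains "nan" as a standalone token.
# The log format is an assumption; adjust the pattern to match your logs.
log_has_nan() {
    grep -Eiq '(^|[^[:alnum:]_])nan([^[:alnum:]_]|$)' "$1"
}

# Typical use against a Slurm output file (left commented out here):
# log_has_nan "slurm-<job_id>.out" && echo "NaN found: try packing=False"
```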
Troubleshooting
Container permissions
If you encounter container permission errors on AAC, do not try to change Docker group membership with sudo. AAC supports container workflows through Pyxis and Enroot in Slurm jobs.
- Use the Pyxis/Enroot workflow in your job submission instead of trying to access a local Docker daemon on the cluster.
- Import or reference your container image with the supported Enroot or Pyxis process.
- See Using Enroot with Pyxis for the supported AAC container workflow.
Hugging Face authentication
For gated models, ensure your token is set:
export MAD_SECRETS_HFTOKEN=<your_token>
export HF_HOME=/shared/data/hf_cache
Out of memory errors
- Use FP8 quantization instead of BF16
- Reduce per_device_batch_size in config
- Enable activation checkpointing
Storage recommendations
- Store models and datasets in /shared/data for multi-node access
- Set HF_HOME=/shared/data/hf_cache to avoid re-downloading on each node
- Keep training outputs in $HOME for easy access
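Before pointing HF_HOME at a shared location, it is worth confirming that the directory exists and is writable from the node. The following is a minimal sketch; check_cache_dir is a hypothetical helper, and its default path is the cache location recommended above.

```shell
# Hypothetical helper: confirm a directory exists and is writable
# before using it as the Hugging Face cache.
check_cache_dir() {
    cache_dir="${1:-/shared/data/hf_cache}"
    if [ ! -d "$cache_dir" ]; then
        echo "${cache_dir} does not exist" >&2
        return 1
    fi
    if [ ! -w "$cache_dir" ]; then
        echo "${cache_dir} is not writable" >&2
        return 1
    fi
    echo "${cache_dir} is ready"
}
```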
Related documentation
- AAC Slurm Cluster User Guide
- Using Enroot with Pyxis
- Storage and Shared Filesystems
- Node Reference Guide