
Training with JAX MaxText on AAC

This guide explains how to run JAX MaxText training workloads on AMD Accelerator Cloud (AAC) clusters.

Overview

MaxText is a high-performance, JAX-based LLM training framework. AMD provides prebuilt Docker images (rocm/jax-training) optimized for AMD Instinct GPUs that bundle JAX, XLA, ROCm libraries, and MaxText utilities for training models such as Llama, DeepSeek, and Mixtral.

Prerequisites

  • Access to an AAC cluster (MI325X or MI355X)
  • Hugging Face account and access token
  • Basic familiarity with SLURM commands

Supported hardware

  • AMD Instinct MI355X GPUs (MI355X cluster)
  • AMD Instinct MI325X GPUs (MI325X cluster)
  • ROCm 7.2.0 with JAX

Single-node training

Step 1: Allocate a compute node

# For the MI325X cluster
salloc -p 256C8G1H_MI325X_Ubuntu22 --gres=gpu:8 --mem=0 --exclusive --account=<ACCOUNT_NAME>

# For the MI355X cluster
salloc -p 256C8G1H_MI355X_Ubuntu22 --gres=gpu:8 --mem=0 --exclusive --account=<ACCOUNT_NAME>

Example:

salloc -p 256C8G1H_MI355X_Ubuntu22 --gres=gpu:8 --mem=0 --exclusive --account=myteam
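Because the two clusters differ only in the partition name, a small helper can keep both names in one place. This is a hypothetical sketch (alloc_cmd is not an AAC or Slurm tool): it only prints the salloc command from Step 1 so you can check it before running it, for example with eval "$(alloc_cmd MI355X myteam)".

```bash
# Hypothetical helper (not an AAC tool): print the single-node salloc command
# for a cluster. The partition names are the ones listed in Step 1.
alloc_cmd() {
  local gpu="$1" account="$2" partition
  case "$gpu" in
    MI325X) partition=256C8G1H_MI325X_Ubuntu22 ;;
    MI355X) partition=256C8G1H_MI355X_Ubuntu22 ;;
    *) echo "unknown cluster: $gpu (expected MI325X or MI355X)" >&2; return 1 ;;
  esac
  echo "salloc -p $partition --gres=gpu:8 --mem=0 --exclusive --account=$account"
}
```

Printing instead of executing keeps the helper safe to call from a login shell; nothing is allocated until you run the printed command.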

Step 2: Load ROCm environment

module load rocm/7.2.0

Step 3: Set environment variables

export MAD_SECRETS_HFTOKEN=<YOUR_HUGGING_FACE_TOKEN>
export HF_HOME=/shared/data/hf_cache   # host-side shared cache (same path the multi-node job uses)
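A quick pre-flight check can catch an empty token before you start a long container session. The function name require_env is hypothetical; this sketch only verifies that each named variable is set and non-empty, not that the token is valid on Hugging Face.

```bash
# Hypothetical pre-flight check: return non-zero (and name the culprits)
# if any of the given environment variables is unset or empty.
require_env() {
  local name value missing=0
  for name in "$@"; do
    # Indirect lookup that also works in POSIX sh (no bash-only ${!name}).
    eval "value=\${$name:-}"
    if [ -z "$value" ]; then
      echo "error: $name is not set" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

Typical use is require_env MAD_SECRETS_HFTOKEN HF_HOME || exit 1 at the top of a launch script.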

Step 4: Pull and run the Podman container

podman pull docker.io/rocm/jax-training:maxtext-v26.2

# Host directory that backs /hf_cache inside the container
mkdir -p /shared/data/hf_cache

podman run -it \
    --device /dev/dri \
    --device /dev/kfd \
    --network host \
    --ipc host \
    --group-add video \
    --cap-add SYS_PTRACE \
    --security-opt seccomp=unconfined \
    --privileged \
    -v "$HOME:$HOME" \
    -v "$HOME/.ssh:/root/.ssh" \
    -v /shared/data/hf_cache:/hf_cache \
    -e HF_HOME=/hf_cache \
    -e MAD_SECRETS_HFTOKEN="$MAD_SECRETS_HFTOKEN" \
    --shm-size 64G \
    --name jax_training_env \
    docker.io/rocm/jax-training:maxtext-v26.2
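Exiting the interactive shell stops the container but does not remove it, so running the same podman run --name jax_training_env command again fails because the name is already in use. A minimal sketch, assuming you want to resume the existing container rather than recreate it: enter_container and the PODMAN override are hypothetical, while podman container exists and podman start --attach --interactive are standard Podman commands.

```bash
# Hypothetical helper: re-enter an existing container instead of re-running
# `podman run` with a name that is already taken.
# PODMAN can be overridden (for example, with a stub) for testing.
PODMAN="${PODMAN:-podman}"

enter_container() {
  local name="${1:-jax_training_env}"
  if "$PODMAN" container exists "$name"; then
    # Start the stopped container and attach to its interactive shell.
    "$PODMAN" start --attach --interactive "$name"
  else
    echo "container $name not found; create it with the podman run command in Step 4" >&2
    return 1
  fi
}
```

If you would rather start fresh each time, remove the old container first with podman rm jax_training_env.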

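Step 5: Clone MAD and run setup

Inside the container, clone the MAD repository and run the setup script for the model you want to train:
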
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
./jax-maxtext_benchmark_setup.sh -m Llama-2-7B

Step 6: Run training

Without quantization:

./jax-maxtext_benchmark_report.sh -m Llama-2-7B

With FP8 quantization (MI355X/MI325X):

./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8
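To compare the two modes on the same model, you can run them back to back and keep separate logs. This is a sketch: run_bf16_and_fp8 and the BENCH variable are hypothetical, and it assumes that running without -q uses the default BF16 precision.

```bash
# Hypothetical wrapper: run the MAD report script once without quantization
# (assumed BF16) and once with -q fp8, saving each run to its own log file.
# BENCH can be overridden (for example, with a stub) for testing.
BENCH="${BENCH:-./jax-maxtext_benchmark_report.sh}"

run_bf16_and_fp8() {
  local model="$1"
  "$BENCH" -m "$model" 2>&1 | tee "${model}_bf16.log"
  "$BENCH" -m "$model" -q fp8 2>&1 | tee "${model}_fp8.log"
}
```

Note that the exit status of each pipeline is that of tee, so check the log files rather than the return code to see whether a run failed.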

Multi-node training with SLURM

Step 1: Allocate multiple nodes

# Allocate 4 nodes on MI355X cluster
salloc -N 4 \
  -p 256C8G1H_MI355X_Ubuntu22 \
  --gres=gpu:8 \
  --mem=0 \
  --exclusive \
  --ntasks-per-node=8 \
  --account=<ACCOUNT_NAME>

Example:

salloc -N 4 -p 256C8G1H_MI355X_Ubuntu22 --gres=gpu:8 --mem=0 --exclusive --ntasks-per-node=8 --account=myteam
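With --ntasks-per-node=8 and --gres=gpu:8, the task count per node matches the GPU count, so the 4-node example launches 4 x 8 = 32 tasks over 32 GPUs. The sketch below shows that arithmetic using environment variables Slurm sets inside an allocation; the fallback values are assumptions that mirror the example.

```bash
# Total Slurm tasks (ranks) = nodes x tasks per node.
total_tasks() {
  echo $(( $1 * $2 ))
}

# Inside an allocation Slurm sets SLURM_NNODES, and SLURM_NTASKS_PER_NODE when
# --ntasks-per-node is given. The fallbacks mirror the 4-node example above.
nodes="${SLURM_NNODES:-4}"
per_node="${SLURM_NTASKS_PER_NODE:-8}"
echo "nodes=$nodes tasks_per_node=$per_node total_tasks=$(total_tasks "$nodes" "$per_node")"
```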

Step 2: Create SLURM batch script

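Create a file named train_llama_multinode.sh. Its #SBATCH directives request the allocation themselves, so submitting it with sbatch does not require the salloc session from Step 1: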

#!/bin/bash
#SBATCH -J jax_maxtext_train
#SBATCH -p 256C8G1H_MI355X_Ubuntu22
#SBATCH --gres=gpu:8
#SBATCH --mem=0
#SBATCH -N 4
#SBATCH --ntasks-per-node=8
#SBATCH --exclusive
#SBATCH --account=<ACCOUNT_NAME>

# Load ROCm
module load rocm/7.2.0

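# Set environment variables. sbatch copies the submitting shell's environment
# into the job by default, so export MAD_SECRETS_HFTOKEN before running sbatch.
export MAD_SECRETS_HFTOKEN="${MAD_SECRETS_HFTOKEN:?export MAD_SECRETS_HFTOKEN before running sbatch}"
export HF_HOME=/shared/data/hf_cache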

# Clone MAD if not already done
if [ ! -d "/shared/data/MAD" ]; then
    cd /shared/data
    git clone https://github.com/ROCm/MAD
fi

# Run multi-node training
cd /shared/data/MAD/scripts/jax-maxtext

srun --container-image=docker://rocm/jax-training:maxtext-v26.2 \
  --container-mounts=$HOME:/workspace,/shared/data:/shared/data \
  --container-workdir=/shared/data/MAD/scripts/jax-maxtext \
  --container-env="MAD_SECRETS_HFTOKEN,HF_HOME" \
  ./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8
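Placeholders such as <ACCOUNT_NAME> sit inside #SBATCH comment lines, so bash never complains about them and sbatch receives the literal placeholder. A hypothetical pre-submit check (check_batch_script is not a Slurm tool) can catch leftovers and shell syntax errors without running anything; bash -n only parses the script.

```bash
# Hypothetical pre-submit check for a batch script:
#  1. reject leftover <PLACEHOLDER> values such as <ACCOUNT_NAME>;
#  2. run `bash -n` to catch shell syntax errors without executing anything.
check_batch_script() {
  local script="$1"
  # grep prints any offending lines (to stderr) and succeeds if it finds one.
  if grep -n '<[A-Za-z_]*>' "$script" >&2; then
    echo "error: replace the placeholders above in $script" >&2
    return 1
  fi
  bash -n "$script"
}
```

For example, run check_batch_script train_llama_multinode.sh && sbatch train_llama_multinode.sh.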

Step 3: Submit using sbatch

sbatch train_llama_multinode.sh
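To avoid copying the job ID by hand, you can capture it at submission time. sbatch normally prints "Submitted batch job <id>", and with --parsable it prints just the ID (optionally followed by ";<cluster>"). The helper name job_id_from_sbatch is hypothetical; it accepts either form.

```bash
# Hypothetical helper: extract the job ID from sbatch output.
# Handles "Submitted batch job <id>" as well as --parsable output
# ("<id>" or "<id>;<cluster>").
job_id_from_sbatch() {
  local out="$1"
  case "$out" in
    "Submitted batch job "*) out="${out#Submitted batch job }" ;;
  esac
  echo "${out%%;*}"
}

# Usage on the cluster:
#   jobid=$(job_id_from_sbatch "$(sbatch --parsable train_llama_multinode.sh)")
#   tail -F "slurm-${jobid}.out"   # -F keeps retrying until the log file exists
```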

Step 4: Monitor the job

# Check job status
squeue -u $USER

# View output
tail -f slurm-<job_id>.out

Using Primus CLI (alternative method)

Step 1: Clone Primus repository

git clone https://github.com/AMD-AIG-AIMA/Primus.git
cd Primus
git checkout dev/fuyuajin/maxtext-backend-test
git submodule update --init third_party/maxtext/

Step 2: Run training in direct mode

./primus-cli direct -- train pretrain \
  --config examples/maxtext/configs/MI300X/llama2_7B-pretrain.yaml

For MI355X, update the config path:

./primus-cli direct -- train pretrain \
  --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml
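If you switch between GPU types, you can derive the config path instead of editing it by hand. This sketch is hypothetical: primus_config is not part of Primus, and only the llama2_7B-pretrain.yaml files named above are known to exist; other model names are placeholders, which is why the function checks that the file is present before returning its path.

```bash
# Hypothetical helper: build the Primus example config path for a GPU model
# and fail if the file is missing. Run it from the root of the Primus checkout.
primus_config() {
  local gpu="$1" model="${2:-llama2_7B}"
  local path="examples/maxtext/configs/${gpu}/${model}-pretrain.yaml"
  if [ ! -f "$path" ]; then
    echo "error: $path not found (are you in the Primus checkout?)" >&2
    return 1
  fi
  echo "$path"
}
```

For example: ./primus-cli direct -- train pretrain --config "$(primus_config MI355X)".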

Supported models

JAX MaxText includes optimized configurations for:

  • Llama family: Llama 2 7B, Llama 3 8B, Llama 3.1 8B/70B, Llama 3.3 70B
  • DeepSeek: DeepSeek-V2-Lite
  • Mixtral: Mixtral 8x7B, Mixtral 8x22B
  • Qwen: Qwen 2.5 models

Configuration files are typically located in:

  • MAD/scripts/jax-maxtext/env_scripts/
  • Primus/examples/maxtext/configs/MI355X/ (when using Primus)

Performance optimization

FP8 quantization

To improve training throughput on MI355X and MI325X, enable FP8 quantization:

./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8

Known limitations

Packing issue: losses may become NaN when packing=True is set in the configuration. Workaround:

  • Disable input sequence packing by setting packing: False in your YAML config
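If you maintain several configs, the workaround can be applied with a short script. This is a hypothetical sketch: disable_packing is not a MAD or MaxText tool, it assumes packing is a top-level key written as packing: <value> on its own line, and it uses GNU sed -i as found on Linux clusters.

```bash
# Hypothetical helper: force sequence packing off in a YAML config.
# Rewrites an existing top-level "packing:" line, or appends one if absent.
disable_packing() {
  local cfg="$1"
  if grep -q '^packing:' "$cfg"; then
    sed -i 's/^packing:.*/packing: False/' "$cfg"   # GNU sed in-place edit
  else
    echo 'packing: False' >> "$cfg"
  fi
}
```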

Troubleshooting

Container permissions

If you encounter container permission errors on AAC, do not try to change Docker group membership with sudo. AAC supports container workflows through Pyxis and Enroot in Slurm jobs.

  1. Use the Pyxis/Enroot workflow in your job submission instead of trying to access a local Docker daemon on the cluster.
  2. Import or reference your container image with the supported Enroot or Pyxis process.
  3. See Using Enroot with Pyxis for the supported AAC container workflow.
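One way to follow this workflow is to import the image into a squashfs file once and point srun --container-image at that file, so each job does not pull the image again. This is a sketch under assumptions: import_image_once, the ENROOT override, and the .sqsh path are hypothetical. enroot import --output and passing a squashfs path to --container-image are standard Enroot and Pyxis usage, but confirm the exact AAC workflow in Using Enroot with Pyxis.

```bash
# Hypothetical sketch: import the training image once into shared storage.
# ENROOT can be overridden (for example, with a stub) for testing.
ENROOT="${ENROOT:-enroot}"
IMAGE_URI="docker://rocm/jax-training:maxtext-v26.2"
SQSH="${SQSH:-/shared/data/jax-training_maxtext-v26.2.sqsh}"   # hypothetical path

import_image_once() {
  if [ -f "$SQSH" ]; then
    echo "reusing $SQSH"
  else
    mkdir -p "$(dirname "$SQSH")" || return 1
    "$ENROOT" import --output "$SQSH" "$IMAGE_URI"
  fi
}

# Then, in the batch script, reference the file instead of the registry image:
#   srun --container-image="$SQSH" --container-mounts=/shared/data:/shared/data ...
```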

Hugging Face authentication

For gated models, ensure your token is set:

export MAD_SECRETS_HFTOKEN=<your_token>
export HF_HOME=/shared/data/hf_cache
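Downloads of gated models also fail if the cache directory is missing or read-only. A hypothetical check (check_hf_cache is not a Hugging Face tool) can verify HF_HOME before a job starts; the token test is only a heuristic, since Hugging Face user access tokens normally start with hf_, so it warns rather than fails.

```bash
# Hypothetical pre-flight check: HF_HOME must be set and writable;
# warn if the token does not look like a Hugging Face user access token.
check_hf_cache() {
  if [ -z "${HF_HOME:-}" ]; then
    echo "error: HF_HOME is not set" >&2
    return 1
  fi
  mkdir -p "$HF_HOME" || return 1
  if [ ! -w "$HF_HOME" ]; then
    echo "error: $HF_HOME is not writable" >&2
    return 1
  fi
  case "${MAD_SECRETS_HFTOKEN:-}" in
    hf_*) ;;
    *) echo "warning: MAD_SECRETS_HFTOKEN does not start with hf_" >&2 ;;
  esac
}
```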

Out of memory errors

  1. Use FP8 quantization instead of BF16
  2. Reduce per_device_batch_size in config
  3. Enable activation checkpointing
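When reducing per_device_batch_size, halving it is a common first step. The helper below is hypothetical and deliberately conservative: it only rewrites a top-level integer value of 2 or more and leaves fractional or already-minimal values unchanged. It assumes a plain YAML file with per_device_batch_size: N on its own line; a trailing comment on that line is dropped when the line is rewritten.

```bash
# Hypothetical helper: halve an integer per_device_batch_size (>= 2) in place.
halve_batch_size() {
  local cfg="$1" tmp status=0
  tmp="$(mktemp)" || return 1
  awk '
    /^per_device_batch_size:/ && $2 ~ /^[0-9]+$/ && $2 >= 2 {
      print "per_device_batch_size: " int($2 / 2); next
    }
    { print }
  ' "$cfg" > "$tmp" && cat "$tmp" > "$cfg" || status=1   # cat keeps file permissions
  rm -f "$tmp"
  return "$status"
}
```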

Storage recommendations

  • Store models and datasets in /shared/data for multi-node access
  • Set HF_HOME=/shared/data/hf_cache to avoid re-downloading on each node
  • Keep training outputs in $HOME for easy access

External resources