A toolkit for training Sparse Autoencoders (SAEs) on transformer models to discover interpretable features inside activations.
## 🧠 What is a Sparse Autoencoder?
A Sparse Autoencoder is a neural network that learns to represent activations using only a few active features at a time.
Think of it like a compression algorithm: it identifies the most important patterns in the data and represents them through a small set of interpretable features.
## ⚙️ How it Works

- Input → Collect activations from chosen transformer layers.
- Encoding → Compress activations into a sparse representation (few active features).
- Decoding → Reconstruct the original activations from this sparse representation.
- Training → Minimize reconstruction error while maintaining sparsity.
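The encode → decode → reconstruct loop above can be sketched in plain NumPy. This is an illustrative top-k SAE, not the actual saetrain implementation; all names and dimensions here are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, expansion_factor, k = 8, 4, 3      # toy versions of --expansion_factor / --k
d_sae = d_model * expansion_factor          # number of latent features

W_enc = rng.normal(size=(d_model, d_sae)) / np.sqrt(d_model)
W_dec = rng.normal(size=(d_sae, d_model)) / np.sqrt(d_sae)
b_enc = np.zeros(d_sae)

def encode(x):
    """Project activations into the latent space, keeping only the top-k."""
    acts = np.maximum(x @ W_enc + b_enc, 0.0)      # ReLU pre-activations
    drop = np.argsort(acts, axis=-1)[:, :-k]       # indices of all but the k largest
    sparse = acts.copy()
    np.put_along_axis(sparse, drop, 0.0, axis=-1)  # zero everything except top-k
    return sparse

def decode(f):
    return f @ W_dec                               # reconstruct activations

x = rng.normal(size=(4, d_model))                  # a batch of 4 activation vectors
f = encode(x)
x_hat = decode(f)
recon_loss = np.mean((x - x_hat) ** 2)             # training minimizes this
```

Each row of `f` has at most `k` nonzero entries, which is exactly the sparsity constraint the training step enforces.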
## 📊 Key Metrics

- Loss → Reconstruction quality.
- Dead Feature % → Percentage of features that never activate.
- L0 Sparsity → Average number of active features per sample.
- Feature Absorption → Overlap between features (lower is better).
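As a concrete illustration, L0 sparsity and dead feature % can be computed from a matrix of feature activations (toy numbers below, not real training output):

```python
import numpy as np

# Toy activation matrix: rows = samples, columns = SAE features.
acts = np.array([
    [0.0, 1.2, 0.0, 0.5],
    [0.0, 0.0, 0.0, 2.0],
    [0.0, 0.3, 0.0, 0.0],
])

l0 = np.count_nonzero(acts, axis=1).mean()       # avg active features per sample
dead_pct = (acts.max(axis=0) == 0).mean() * 100  # % of features that never activate

print(round(l0, 2), dead_pct)  # 1.33 50.0 (features 0 and 2 never fire)
```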
## 📦 Installation

```bash
# Clone repo
git clone <repository-url>
cd saetrain

# Install dependencies
pip install -e .
```
## ⚡ Quick Start

### Option 1: Multi-GPU Training (Recommended)

```bash
bash Train_sae_scripts_multiGPU.sh
```
Benefits:

- 🚀 6–8× faster on 8 GPUs
- 📈 Larger effective batch size (32 vs. 4)
- ⚡ DDP (Distributed Data Parallel)
- 📊 Real-time WandB logging
### Option 2: Single-GPU Training

```bash
bash Train_sae_script.sh
```
Includes:
- Full training pipeline
- Automatic WandB logging
- Post-training health assessment
- SAEBench metric evaluation
### Option 3: Manual Training

```bash
python -m saetrain bert-base-uncased wikitext \
  --layers 6 \
  --max_tokens 1000000 \
  --k 192 \
  --expansion_factor 32
```
## 🔧 Multi-GPU Setup

```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_LAUNCH_BLOCKING=1
export CUBLAS_WORKSPACE_CONFIG=:4096:8
export PYTHONHASHSEED=42
```
### Torchrun Example

```bash
torchrun --nproc_per_node=8 -m saetrain \
  bert-base-uncased \
  jyanimaulik/yahoo_finance_stockmarket_news \
  --layers 6 \
  --batch_size 4 \
  --k 32 \
  --num_latents 200 \
  --grad_acc_steps 8 \
  --ctx_len 512 \
  --save_dir "./output" \
  --lr 0.01 \
  --run_name "bert_layer6_k32_latents200" \
  --log_to_wandb true \
  --wandb_log_frequency 10 \
  --dead_percentage_threshold 0.1
```
## 📋 Command Line Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model_name` | str | - | Transformer model (e.g. `bert-base-uncased`) |
| `dataset_name` | str | - | Dataset name (HuggingFace or custom) |
| `--layers` | int | 0 | Layer(s) to train on |
| `--max_tokens` | int | 1e6 | Training token budget |
| `--batch_size` | int | 4 | Samples per batch |
| `--k` | int | 192 | Active features per sample |
| `--expansion_factor` | int | 32 | Expansion multiplier for SAE size |
| `--num_latents` | int | - | Latent size set directly (overrides `--expansion_factor`) |
| `--grad_acc_steps` | int | 1 | Gradient accumulation steps |
| `--ctx_len` | int | 512 | Context length |
| `--optimizer` | str | "adam" | Optimizer choice |
| `--lr` | float | 0.001 | Learning rate |
| `--dead_percentage_threshold` | float | 0.0005 | Activation threshold below which a feature counts as "dead" |
| `--save_dir` | str | "./output" | Save path |
| `--run_name` | str | - | Name for the run |
| `--log_to_wandb` | bool | false | Enable WandB logging |
## 🩺 Health Metrics

- Loss Recovered ≥ 60–70%
- L0 Sparsity within 20–200 active features/sample
- Dead Features ≤ 10–20%
- Feature Absorption ≤ 0.25
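A post-training health check against these guidelines might look like the following. The metric names and values are hypothetical; saetrain's actual health report format may differ:

```python
# Hypothetical metrics from a finished run.
metrics = {"loss_recovered": 0.72, "l0": 45.0, "dead_pct": 8.0, "absorption": 0.18}

checks = {
    "loss recovered >= 60%": metrics["loss_recovered"] >= 0.60,
    "20 <= L0 <= 200":       20 <= metrics["l0"] <= 200,
    "dead features <= 20%":  metrics["dead_pct"] <= 20,
    "absorption <= 0.25":    metrics["absorption"] <= 0.25,
}
healthy = all(checks.values())
print("healthy" if healthy else "needs attention")  # healthy
```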
## 📚 Example Configurations

### Basic Training

```bash
python -m saetrain bert-base-uncased wikitext --layers 6 --max_tokens 1e6 --k 192 --expansion_factor 32
```

### Multi-GPU Custom

```bash
torchrun --nproc_per_node=8 -m saetrain bert-base-uncased jyanimaulik/yahoo_finance_stockmarket_news \
  --layers 6 --batch_size 4 --k 32 --num_latents 200 --grad_acc_steps 8
```

### Training Multiple Layers

```bash
python -m saetrain bert-base-uncased wikitext --layers 0 1 2 3 4 5 6 7 8 9 10 11 --max_tokens 1e6
```
## 🛠️ Tips

- Start small (`max_tokens=1e6`) to debug quickly.
- Defaults: `expansion_factor=32`, `k=192`.
- Monitor dead feature % → too high means the SAE is too wide.
- Use gradient accumulation (`grad_acc_steps`) to simulate bigger batches.
- Use WandB logging to track metrics live.
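The gradient-accumulation tip works because averaging the gradients of equal-sized micro-batches equals the gradient of the combined batch (for a mean loss). A toy NumPy check with made-up data:

```python
import numpy as np

rng = np.random.default_rng(0)
w = np.zeros(3)
X, y = rng.normal(size=(32, 3)), rng.normal(size=32)

def grad(w, Xb, yb):
    """Gradient of mean squared error for a linear model."""
    return 2 * Xb.T @ (Xb @ w - yb) / len(yb)

# One batch of 32 vs. 8 accumulated micro-batches of 4 (grad_acc_steps=8).
g_full = grad(w, X, y)
g_acc = np.mean([grad(w, X[i:i + 4], y[i:i + 4]) for i in range(0, 32, 4)], axis=0)

print(np.allclose(g_full, g_acc))  # True
```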
## 🐛 Troubleshooting

- CUDA OOM → lower `batch_size` or use `grad_acc_steps`.
- 100% dead features → lower the learning rate, raise the threshold.
- Too many dead features → increase `k`, reduce `expansion_factor`.
- DDP sync issues → check GPU environment variables & torchrun flags.
## 📈 Performance Comparison

| Setup | GPUs | Speed | Batch Size | Notes |
|---|---|---|---|---|
| Single-GPU | 1 | 1× | 4 | Baseline |
| Multi-GPU | 8 | 6–8× | 32 | Standard scaling |
| Multi-GPU + Grad Acc | 8 | 6–8× | 256 | Max efficiency |
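The batch-size column is a simple product, shown here for the table's three rows:

```python
def effective_batch(gpus, per_gpu_batch, grad_acc_steps=1):
    """Effective batch size = GPUs x per-GPU batch x accumulation steps."""
    return gpus * per_gpu_batch * grad_acc_steps

print(effective_batch(1, 4))     # 4   (single-GPU baseline)
print(effective_batch(8, 4))     # 32  (multi-GPU)
print(effective_batch(8, 4, 8))  # 256 (multi-GPU + grad accumulation)
```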
## Guidelines for Training