A toolkit for training Sparse Autoencoders (SAEs) on transformer models to discover interpretable features inside activations.

📘 What is a Sparse Autoencoder?

A Sparse Autoencoder is a neural network that learns to represent activations using only a few active features at a time.
Think of it like a compression algorithm: it identifies the most important patterns in the data and represents them through a small set of interpretable features.

βš™οΈ How it Works

  1. Input → Collect activations from chosen transformer layers.
  2. Encoding → Compress activations into a sparse representation (few active features).
  3. Decoding → Reconstruct the original activations from this sparse representation.
  4. Training → Minimize reconstruction error while maintaining sparsity.
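The four steps above can be sketched in a few lines. This is a minimal illustration of one forward pass of a top-k sparse autoencoder, not the toolkit's actual implementation; the function name, shapes, and weights are all illustrative.

```python
import numpy as np

def topk_sae_step(x, W_enc, b_enc, W_dec, k):
    """One forward pass of a top-k sparse autoencoder.

    x: (d,) activation vector from a transformer layer.
    W_enc: (n_latents, d) encoder, W_dec: (d, n_latents) decoder.
    Only the k strongest latent activations are kept (the sparsity constraint).
    """
    pre = W_enc @ x + b_enc                  # 1-2. input -> encode
    z = np.zeros_like(pre)
    top = np.argsort(pre)[-k:]               # indices of the k strongest features
    z[top] = np.maximum(pre[top], 0.0)       # keep at most k features, ReLU the rest away
    x_hat = W_dec @ z                        # 3. decode / reconstruct
    loss = float(np.mean((x - x_hat) ** 2))  # 4. reconstruction error to minimize
    return z, x_hat, loss

rng = np.random.default_rng(0)
d, n_latents, k = 16, 64, 4
x = rng.normal(size=d)
W_enc = rng.normal(size=(n_latents, d)) * 0.1
W_dec = rng.normal(size=(d, n_latents)) * 0.1
z, x_hat, loss = topk_sae_step(x, W_enc, np.zeros(n_latents), W_dec, k)
print(int((z != 0).sum()))  # at most k active features
```

In real training the encoder/decoder weights are then updated by gradient descent on this loss; here they are frozen random matrices just to show the data flow.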

📊 Key Metrics

  • Loss → Reconstruction quality.
  • Dead Feature % → % of features that never activate.
  • L0 Sparsity → Avg. active features per sample.
  • Feature Absorption → Overlap between features (lower = better).
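The L0 sparsity and dead-feature metrics can be computed directly from a batch of SAE latent codes. A minimal sketch; the function name and array layout are assumptions, not the toolkit's API:

```python
import numpy as np

def sae_metrics(latents):
    """Compute L0 sparsity and dead-feature fraction from latent codes.

    latents: (n_samples, n_features) array of SAE feature activations.
    """
    active = latents != 0
    l0 = float(active.sum(axis=1).mean())           # avg. active features per sample
    dead_pct = float((~active.any(axis=0)).mean())  # fraction of features that never fire
    return l0, dead_pct

latents = np.array([
    [0.0, 1.2, 0.0, 0.4],
    [0.0, 0.0, 0.0, 0.9],
    [0.0, 0.7, 0.0, 0.0],
])
l0, dead_pct = sae_metrics(latents)
print(l0, dead_pct)  # 1.333..., 0.5 (features 0 and 2 never activate)
```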

🚀 Installation

```bash
# Clone repo
git clone <repository-url>
cd saetrain

# Install dependencies
pip install -e .
```

⚡ Quick Start

Option 1: Multi-GPU Training (Recommended)

```bash
bash Train_sae_scripts_multiGPU.sh
```

Benefits:
  • 🚀 6–8× faster on 8 GPUs
  • 📊 Larger effective batch size (32 vs 4)
  • ⚡ DDP (Distributed Data Parallel)
  • 📈 Real-time WandB logging

Option 2: Single-GPU Training

```bash
bash Train_sae_script.sh
```

Includes:
  • Full training pipeline
  • Automatic WandB logging
  • Post-training health assessment
  • SAEBench metric evaluation

Option 3: Manual Training

```bash
python -m saetrain bert-base-uncased wikitext \
  --layers 6 \
  --max_tokens 1000000 \
  --k 192 \
  --expansion_factor 32
```

🔌 Multi-GPU Setup

```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_LAUNCH_BLOCKING=1
export CUBLAS_WORKSPACE_CONFIG=:4096:8
export PYTHONHASHSEED=42
```

Torchrun Example

```bash
torchrun --nproc_per_node=8 -m saetrain \
  bert-base-uncased \
  jyanimaulik/yahoo_finance_stockmarket_news \
  --layers 6 \
  --batch_size 4 \
  --k 32 \
  --num_latents 200 \
  --grad_acc_steps 8 \
  --ctx_len 512 \
  --save_dir "./output" \
  --lr 0.01 \
  --run_name "bert_layer6_k32_latents200" \
  --log_to_wandb true \
  --wandb_log_frequency 10 \
  --dead_percentage_threshold 0.1
```

📋 Command Line Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| model_name | str | - | Transformer model (e.g. bert-base-uncased) |
| dataset_name | str | - | Dataset name (HuggingFace or custom) |
| --layers | int | 0 | Layer(s) to train on |
| --max_tokens | int | 1e6 | Training token budget |
| --batch_size | int | 4 | Samples per batch |
| --k | int | 192 | Active features per sample |
| --expansion_factor | int | 32 | Expansion multiplier for SAE size |
| --num_latents | int | - | Direct latent size (overrides expansion_factor) |
| --grad_acc_steps | int | 1 | Gradient accumulation steps |
| --ctx_len | int | 512 | Context length |
| --optimizer | str | "adam" | Optimizer choice |
| --lr | float | 0.001 | Learning rate |
| --dead_percentage_threshold | float | 0.0005 | Activation threshold for "dead" features |
| --save_dir | str | "./output" | Save path |
| --run_name | str | - | Name for run |
| --log_to_wandb | bool | false | Enable WandB logging |

📈 Health Metrics

  • Loss Recovered ≥ 60–70%
  • L0 Sparsity within 20–200 active features/sample
  • Dead Features ≤ 10–20%
  • Feature Absorption ≤ 0.25
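As an illustration, these targets can be checked mechanically after training. The helper below is hypothetical (not part of saetrain) and uses the lenient end of each range above:

```python
def sae_health_check(loss_recovered, l0, dead_pct, absorption):
    """Check a trained SAE against the README's target ranges.

    Thresholds are the lenient ends of the ranges listed above;
    the function name and signature are illustrative.
    """
    return {
        "loss_recovered": loss_recovered >= 0.60,  # >= 60%
        "l0_in_range": 20 <= l0 <= 200,            # active features/sample
        "dead_features": dead_pct <= 0.20,         # <= 20% dead
        "absorption": absorption <= 0.25,
    }

checks = sae_health_check(loss_recovered=0.72, l0=48, dead_pct=0.08, absorption=0.12)
print(all(checks.values()))  # True: this run passes every target
```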

πŸ” Example Configurations

Basic Training

```bash
python -m saetrain bert-base-uncased wikitext --layers 6 --max_tokens 1e6 --k 192 --expansion_factor 32
```

Multi-GPU Custom

```bash
torchrun --nproc_per_node=8 -m saetrain bert-base-uncased jyanimaulik/yahoo_finance_stockmarket_news \
  --layers 6 --batch_size 4 --k 32 --num_latents 200 --grad_acc_steps 8
```

Training Multiple Layers

```bash
python -m saetrain bert-base-uncased wikitext --layers 0 1 2 3 4 5 6 7 8 9 10 11 --max_tokens 1e6
```

πŸ› οΈ Tips

  1. Start small (max_tokens=1e6) to debug quickly.
  2. Stick with the defaults (expansion_factor=32, k=192) until training is stable.
  3. Monitor dead feature %; a high value means the SAE is too wide for the data.
  4. Use gradient accumulation (grad_acc_steps) to simulate bigger batches.
  5. Use WandB logging to track metrics live.
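Tip 4 in practice: gradient accumulation sums gradients over several micro-batches and applies one averaged update, so the optimizer behaves as if the batch were batch_size × grad_acc_steps without the memory cost. A minimal sketch (names are illustrative; saetrain exposes this through --grad_acc_steps):

```python
import numpy as np

def accumulate(micro_batch_grads, grad_acc_steps):
    """Average gradients over grad_acc_steps micro-batches.

    Yields one averaged gradient per optimizer step, simulating a
    batch grad_acc_steps times larger than each micro-batch.
    """
    buf = None
    for i, g in enumerate(micro_batch_grads, start=1):
        buf = g if buf is None else buf + g
        if i % grad_acc_steps == 0:
            yield buf / grad_acc_steps  # one optimizer step's worth
            buf = None

# 8 micro-batch gradients -> 2 optimizer steps with grad_acc_steps=4
grads = [np.full(3, float(i)) for i in range(1, 9)]
updates = list(accumulate(grads, grad_acc_steps=4))
print(len(updates))  # 2
```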

🐞 Troubleshooting

  • CUDA OOM → lower batch_size or raise grad_acc_steps.
  • 100% dead features → lower the learning rate, raise dead_percentage_threshold.
  • Too many dead features → increase k or reduce the expansion factor.
  • DDP sync issues → check the GPU environment variables and torchrun flags above.

📊 Performance Comparison

| Setup | GPUs | Speed | Batch Size | Notes |
|---|---|---|---|---|
| Single-GPU | 1 | 1× | 4 | Baseline |
| Multi-GPU | 8 | 6–8× | 32 | Standard scaling |
| Multi-GPU + Grad Acc | 8 | 6–8× | 256 | Max efficiency |
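The batch-size column follows the standard DDP formula: effective batch = per-GPU batch × number of GPUs × gradient-accumulation steps. A quick check of the table rows (the formula is the usual DDP convention, assumed here rather than quoted from the toolkit):

```python
def effective_batch(per_gpu_batch, n_gpus, grad_acc_steps=1):
    """Effective batch size under DDP with gradient accumulation."""
    return per_gpu_batch * n_gpus * grad_acc_steps

print(effective_batch(4, 1))     # 4   -> single-GPU baseline
print(effective_batch(4, 8))     # 32  -> multi-GPU
print(effective_batch(4, 8, 8))  # 256 -> multi-GPU + grad acc
```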

Guidelines for training