Find the most relevant SAE features for a given prompt or text—optionally constrained to a model + SAE + feature list—and return the top-X features with scores, examples, and quick proto-labels. Works for both the input prompt and the model’s generated text.

🎯 What this does

  • You provide:
    • a prompt (or raw text),
    • a model (base or fine-tuned),
    • an SAE (layer(s) & decoder weights),
    • optional candidate feature list and top-K token controls.
  • The tool:
    • runs the model on your prompt and (optionally) generates a continuation,
    • captures token-level activations,
    • projects onto SAE features,
    • scores & ranks features,
    • returns the top-X most relevant features with evidence.
Use it to quickly answer: “Which internal concepts did the model use here?” Then jump straight to steering or labeling.

🧩 Inputs & Options

| Name | Type | Default | Description |
|------|------|---------|-------------|
| model_id | string | (required) | Target model (e.g., cxllin/Llama2-7b-Finance) |
| sae_id | string | (required) | Trained SAE for a specific layer/stream |
| text | string | (none) | Raw text to analyze instead of generating |
| prompt | string | (none) | Prompt to feed to the model (if generating) |
| generate | bool | true | If true, generate a continuation of up to gen_tokens and analyze prompt+gen; else analyze text only |
| gen_tokens | int | 128 | Max new tokens to generate for analysis |
| layers | list[int] | [22] | Layers to analyze (supports multiple) |
| streams | list[str] | ["resid_pre"] | Activation streams (e.g., resid_pre, mlp_out) |
| restrict_features | list[int] | [] | Optional candidate feature list (IDs) to constrain the search |
| topk_tokens_per_seq | int | 64 | Only score the top-K most "informative" tokens per sequence (entropy/gradient proxy) |
| return_top_x | int | 25 | How many features to return after ranking |
| score_mode | enum | "mean×coverage" | Scoring: "mean", "max", "mean×coverage", "selectivity" |
| min_coverage_pct | float | 0.5 | Minimum coverage (fraction of analyzed tokens with the feature active) required to pass the gate |
| merge_across_layers | bool | true | Aggregate scores across chosen layers |
| attach_examples | bool | true | Return top activating text spans per feature |
| attach_proto_labels | bool | true | Attach proto-labels if AutoInterp Lite/Full catalogs exist |
| seed | int | 42 | Seed for reproducible generation |

🧠 How it works (Algorithm)

  1. Run & Capture
      • If generate=true: run the model on prompt → capture activations for prompt tokens and generated tokens (up to gen_tokens).
      • Else: tokenize text and pass it through the model to capture activations.
  2. Project to SAE
      • For each selected layer/stream, project activations onto the SAE dictionary to obtain feature activations (sparse codes).
  3. Token Pre-Filter
      • Keep only the top-K informative tokens per sequence (topk_tokens_per_seq) using a proxy (e.g., high logit entropy or large activation norm) to save time and focus on salient positions.
  4. Aggregate & Score
      • For each feature f, compute:
        • mean_act(f) = mean activation across kept tokens
        • max_act(f); coverage(f) = % of kept tokens where act > τ
        • selectivity(f) = mean_act(domain tokens) − mean_act(background window)
      • Score using score_mode (default: mean×coverage).
  5. Restrict (Optional)
      • If restrict_features is provided, score only those IDs (useful when analysts shortlist candidates).
  6. Rank & Return Top-X
      • Sort by score descending, apply the min_coverage_pct gate, and return return_top_x features.
  7. Attach Evidence
      • For each returned feature: add top activating spans (token windows), layer/stream, score breakdown, and proto-label if available.
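The aggregate-and-score step above can be sketched in a few lines of NumPy. The array shapes, the zero τ in the toy example, and the function name are illustrative assumptions; the mean/max/coverage definitions and the mean×coverage default follow this section.

```python
import numpy as np

def score_features(feature_acts: np.ndarray, tau: np.ndarray,
                   score_mode: str = "mean×coverage") -> np.ndarray:
    """Score SAE features from sparse codes.

    feature_acts: [n_kept_tokens, n_features] activations for the top-K
                  informative tokens that survived the pre-filter.
    tau:          [n_features] per-feature "active" threshold.
    """
    mean_act = feature_acts.mean(axis=0)          # mean_act(f)
    max_act = feature_acts.max(axis=0)            # max_act(f)
    coverage = (feature_acts > tau).mean(axis=0)  # fraction of active tokens
    if score_mode == "mean":
        return mean_act
    if score_mode == "max":
        return max_act
    return mean_act * coverage                    # default: mean×coverage

# Toy example: 4 kept tokens, 3 candidate features.
acts = np.array([[1.0, 0.0, 2.0],
                 [1.0, 0.0, 0.0],
                 [1.0, 0.5, 0.0],
                 [1.0, 0.0, 0.0]])
tau = np.zeros(3)
scores = score_features(acts, tau)
ranking = np.argsort(scores)[::-1]  # best feature first
```

Note how the consistently active feature 0 outranks the single-spike feature 2 even though the latter has a larger peak activation; that is exactly the behavior mean×coverage is chosen for.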

📦 API (proposed)

POST /v1/sae/features/discover

Body

```json
{
  "model_id": "cxllin/Llama2-7b-Finance",
  "sae_id": "sae_l22_resid_pre_v3",
  "prompt": "Summarize Q2 earnings of ACME Corp and key KPIs.",
  "generate": true,
  "gen_tokens": 128,
  "layers": [22, 28],
  "streams": ["resid_pre"],
  "restrict_features": [159, 258, 345, 375, 116],
  "topk_tokens_per_seq": 64,
  "return_top_x": 25,
  "score_mode": "mean×coverage",
  "min_coverage_pct": 0.5,
  "merge_across_layers": true,
  "attach_examples": true,
  "attach_proto_labels": true,
  "seed": 42
}
```
Response (truncated)

```json
{
  "model_id": "cxllin/Llama2-7b-Finance",
  "sae_id": "sae_l22_resid_pre_v3",
  "analyzed": { "mode": "prompt+gen", "tokens": 231, "layers": [22, 28] },
  "features": [
    {
      "feature_id": 159,
      "layer": 22,
      "stream": "resid_pre",
      "score": 0.91,
      "scores": { "mean": 0.78, "coverage": 0.83, "max": 2.35, "selectivity": 0.41 },
      "proto_label": "Financial performance & growth",
      "top_spans": [
        { "text": "revenue up 18% YoY; EPS beat...", "position": [78, 102], "act": 2.35 },
        { "text": "margin expansion; guidance raised", "position": [131, 155], "act": 2.11 }
      ]
    },
    {
      "feature_id": 258,
      "layer": 22,
      "stream": "resid_pre",
      "score": 0.82,
      "proto_label": "Market indicators & metrics",
      "top_spans": [ ... ]
    }
  ]
}
```
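Assuming the proposed endpoint above is served locally, a client request could look like this sketch using only the standard library. The host/port and the helper name are assumptions, not part of the tool.

```python
import json
import urllib.request

DISCOVER_URL = "http://localhost:8000/v1/sae/features/discover"  # assumed host

def build_request(prompt: str, **options) -> urllib.request.Request:
    """Build a POST request for the proposed discovery endpoint."""
    body = {
        "model_id": "cxllin/Llama2-7b-Finance",
        "sae_id": "sae_l22_resid_pre_v3",
        "prompt": prompt,
        "generate": True,
        "gen_tokens": 128,
        **options,  # any other fields from the table above
    }
    return urllib.request.Request(
        DISCOVER_URL,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )

req = build_request("Summarize Q2 earnings of ACME Corp and key KPIs.",
                    layers=[22, 28], return_top_x=25)
# To send:
#   with urllib.request.urlopen(req) as resp:
#       result = json.load(resp)
```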

🖥️ CLI (single-prompt discovery)

```bash
python -m sae_discover \
  --model cxllin/Llama2-7b-Finance \
  --sae sae_l22_resid_pre_v3 \
  --prompt "Stock market analysis indicates" \
  --generate true --gen_tokens 128 \
  --layers 22 28 --streams resid_pre \
  --restrict_features 159 258 345 375 116 \
  --topk_tokens_per_seq 64 \
  --return_top_x 25 \
  --score_mode "mean×coverage" \
  --min_coverage_pct 0.5 \
  --merge_across_layers true \
  --attach_examples true --attach_proto_labels true \
  --seed 42 \
  --out results/discover_single.json
```

Batch mode (file of prompts)

```bash
python -m sae_discover \
  --model cxllin/Llama2-7b-Finance \
  --sae sae_l22_resid_pre_v3 \
  --prompts_file data/finance_prompts.jsonl \
  --generate true --gen_tokens 128 \
  --layers 22 \
  --return_top_x 10 \
  --out results/discover_batch.jsonl
```

📊 Scoring Details

Default score (mean×coverage):
```text
score(f) = mean_act(f) × coverage(f)   # coverage ∈ [0, 1]
```
  • Favors features that are consistently active (not just single spikes).
  • Good for prompt+generation analyses.
Alternatives
  • mean: raw mean activation (simple & fast).
  • max: highlights peaky features (good for “signature” detectors).
  • selectivity: difference between in-span vs background windows (helps avoid generic features).
Thresholds
  • A token is “active” for coverage when act > τ (τ defaults to feature-wise P95 of base activations; configurable).
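The feature-wise P95 default for τ can be computed once over a background corpus of base activations, as in this sketch (the corpus array and the function name are assumed inputs, not part of the tool):

```python
import numpy as np

def feature_thresholds(base_acts: np.ndarray, pct: float = 95.0) -> np.ndarray:
    """τ per feature: the P95 of each feature's activations on a
    background corpus. base_acts is [n_tokens, n_features]."""
    return np.percentile(base_acts, pct, axis=0)

# Coverage gate: a token counts as "active" for feature f when act > τ_f.
base = np.linspace(0.0, 1.0, 101).reshape(101, 1)  # toy one-feature corpus
tau = feature_thresholds(base)                     # ≈ 0.95 for this corpus
coverage = (base > tau).mean(axis=0)               # fraction of tokens above τ
```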

🧪 Output Schema (JSON & CSV)

JSON (per feature)
```json
{
  "feature_id": 159,
  "layer": 22,
  "stream": "resid_pre",
  "score": 0.91,
  "scores": { "mean": 0.78, "coverage": 0.83, "max": 2.35, "selectivity": 0.41 },
  "proto_label": "Financial performance & growth",
  "top_spans": [
    { "text": "...", "position": [start, end], "act": 2.35 }
  ]
}
```
CSV columns
```text
feature_id, layer, stream, score, mean, coverage, max, selectivity, proto_label, top_span_text
```
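Flattening the per-feature JSON into these CSV columns is a small transformation; this sketch assumes a list of feature dicts shaped like the JSON schema above, taking only the first top span per feature.

```python
import csv
import io

COLUMNS = ["feature_id", "layer", "stream", "score", "mean", "coverage",
           "max", "selectivity", "proto_label", "top_span_text"]

def to_csv(features: list[dict]) -> str:
    """Render discovery output rows in the documented column order."""
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=COLUMNS)
    writer.writeheader()
    for f in features:
        spans = f.get("top_spans", [])
        writer.writerow({
            "feature_id": f["feature_id"],
            "layer": f["layer"],
            "stream": f["stream"],
            "score": f["score"],
            **f.get("scores", {}),  # mean / coverage / max / selectivity
            "proto_label": f.get("proto_label", ""),
            "top_span_text": spans[0]["text"] if spans else "",
        })
    return buf.getvalue()
```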

🔗 Integrations

  • Labeling: Pipe feature_id list to AutoInterp Full for validated labels (F1/precision/recall).
  • Steering: Send selected features to SAE Steering Tool for interactive amplification/suppression.
  • Alignment: Compare feature ranks across base vs fine-tuned models to flag drift.

⚡ Performance & Caching

  • Use topk_tokens_per_seq to keep runs fast on long generations.
  • Cache tokenization, activations, and projection per prompt hash.
  • When restricting to a feature list, only gather codes for those features for speed.
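A per-prompt cache keyed by a content hash might look like this sketch; the in-memory dict and the compute callback are hypothetical stand-ins for the real activation store and forward pass.

```python
import hashlib

_cache: dict[str, object] = {}

def prompt_key(model_id: str, sae_id: str, prompt: str) -> str:
    """Stable cache key over everything that affects the activations."""
    h = hashlib.sha256()
    for part in (model_id, sae_id, prompt):
        h.update(part.encode("utf-8"))
        h.update(b"\x00")  # separator so ("ab","c") differs from ("a","bc")
    return h.hexdigest()

def cached_activations(model_id, sae_id, prompt, compute):
    """Return cached activations, running the forward pass only on a miss."""
    key = prompt_key(model_id, sae_id, prompt)
    if key not in _cache:
        _cache[key] = compute(prompt)  # expensive: tokenize + forward + project
    return _cache[key]
```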

✅ Best Practices

  1. Prompt+Gen usually reveals more features than text-only.
  2. Start with one strong layer (e.g., 22), then add a second (e.g., 28) and compare.
  3. Use restrict_features for analyst shortlists (e.g., from AutoInterp Lite).
  4. Prefer mean×coverage for ranking; fall back to max to discover sharp detectors.
  5. Always inspect the top spans; they make the semantics obvious.

🧯 Troubleshooting

  • Too many generic features at top: switch score_mode to selectivity, raise min_coverage_pct.
  • Sparse outputs, few features passing gates: lower activation threshold τ or increase topk_tokens_per_seq.
  • No visible layer effect: try later layers (22/28) or mlp_out stream.
  • Runtime high: disable generation (generate=false) for quick prompt-only scans, or restrict features.

✨ Example (end-to-end)

Goal: For an earnings prompt, find the top 15 features active in the response and the prompt, limited to a candidate list.
```bash
python -m sae_discover \
  --model cxllin/Llama2-7b-Finance \
  --sae sae_l22_resid_pre_v3 \
  --prompt "Summarize Q2 earnings for ACME: revenue, EPS, guidance, and margin drivers." \
  --generate true --gen_tokens 160 \
  --layers 22 28 --streams resid_pre \
  --restrict_features 159 258 345 375 116 402 417 423 518 612 705 711 802 809 900 \
  --return_top_x 15 \
  --score_mode "mean×coverage" \
  --attach_examples true --attach_proto_labels true \
  --out results/earnings_feature_discovery.json
```
This returns a ranked list of features (with spans & proto-labels) that you can immediately hand to Steering or AutoInterp Full.
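Forwarding that shortlist is a small filter over the output JSON; the 0.8 cutoff and the helper name in this sketch are illustrative, not part of the tool.

```python
import json

def select_for_steering(result: dict, min_score: float = 0.8) -> list[int]:
    """Pick feature IDs whose ranked score clears a cutoff."""
    return [f["feature_id"] for f in result["features"]
            if f["score"] >= min_score]

# Typical use: load the discovery output, then hand the IDs to steering.
# with open("results/earnings_feature_discovery.json") as fh:
#     ids = select_for_steering(json.load(fh))
```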