From c050b5a3b2ff7d464a6ecf1516aed57a05d27cac Mon Sep 17 00:00:00 2001 From: mathreader Date: Fri, 8 Aug 2025 02:02:16 +0800 Subject: [PATCH 1/5] add sae-dashboard visualization --- delphi/config.py | 3 + delphi/latents/constructors.py | 2 + delphi/latents/latents.py | 3 + examples/convert_sae_dashboard.py | 304 ++++++++++++++++++++++++++++++ pyproject.toml | 5 +- 5 files changed, 316 insertions(+), 1 deletion(-) create mode 100644 examples/convert_sae_dashboard.py diff --git a/delphi/config.py b/delphi/config.py index 6e49b09d..68411322 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -69,6 +69,9 @@ class ConstructorConfig(Serializable): ] = "co-occurrence" """Type of neighbours to use. Only used if non_activating_source is 'neighbours'.""" + save_activation_data: bool = False + """Whether to keep the original activation data in the record.""" + @dataclass class CacheConfig(Serializable): diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index a65bcb32..96a44c36 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -329,6 +329,8 @@ def constructor( else: raise ValueError(f"Invalid non-activating source: {source_non_activating}") record.not_active = non_activating_examples + if constructor_cfg.save_activation_data: + record.activation_data = activation_data return record diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index 0f4ff94d..dba3a595 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -150,6 +150,9 @@ class LatentRecord: extra_examples: Optional[list[Example]] = None """Extra examples to include in the record.""" + activation_data: Optional[ActivationData] = None + """Activation data for the latent, if it was saved.""" + per_token_frequency: float = 0.0 """Frequency of the latent.
Number of activations per total number of tokens.""" diff --git a/examples/convert_sae_dashboard.py b/examples/convert_sae_dashboard.py new file mode 100644 index 00000000..ab64c44a --- /dev/null +++ b/examples/convert_sae_dashboard.py @@ -0,0 +1,304 @@ +# %% +from sae_dashboard.components import ( + ActsHistogramData, + DecoderWeightsDistribution, + FeatureTablesData, + LogitsHistogramData, + SequenceData, + SequenceGroupData, + SequenceMultiGroupData, +) +from sae_dashboard.data_parsing_fns import get_logits_table_data +from sae_dashboard.data_writing_fns import save_feature_centric_vis +from sae_dashboard.feature_data import FeatureData +from sae_dashboard.sae_vis_data import SaeVisConfig, SaeVisData +from sae_dashboard.utils_fns import FeatureStatistics + +try: + from itertools import batched +except ImportError: + from more_itertools import chunked as batched + # Fallback for older Python + # Using more_itertools for compatibility +import gc +from argparse import Namespace +from dataclasses import replace +from pathlib import Path + +import numpy as np +import torch +from sae_dashboard.utils_fns import ASYMMETRIC_RANGES_AND_PRECISIONS +from simple_parsing import ArgumentParser +from tqdm.auto import tqdm +from transformers import AutoConfig, AutoModelForCausalLM + +from delphi.config import ConstructorConfig, SamplerConfig +from delphi.latents import LatentDataset +from delphi.sparse_coders.sparse_model import load_sparsify_sparse_coders + +torch.set_grad_enabled(False) + +parser = ArgumentParser(description="Convert SAE data for dashboard visualization") +parser.add_argument( + "--module", type=str, default="model.layers.9", help="Model module to analyze" +) +parser.add_argument( + "--latents", type=int, default=5, help="Number of latents to process" +) +parser.add_argument( + "--cache-path", + type=str, + default="../results/sae_pkm/baseline", + help="Path to cached activations", +) +parser.add_argument( + "--sae-path", + type=str, + default="../halutsae/sae-pkm/smollm/baseline", + help="Path to trained SAE model", +) +parser.add_argument( + "--out-path", + type=str, + default="results/latent_dashboard.html", + help="Path to save the dashboard", +) +parser.add_arguments(ConstructorConfig, dest="constructor_cfg") +parser.add_arguments( + SamplerConfig, + dest="sampler_cfg", + default=SamplerConfig(n_examples_train=40, n_quantiles=10, train_type="quantiles"), +) +args = parser.parse_args() +constructor_cfg = args.constructor_cfg +sampler_cfg = args.sampler_cfg +out_path = args.out_path + +# %% +module = args.module +n_latents = args.latents +start_latent = 0 +latent_dict = {f"{module}": torch.arange(start_latent, start_latent + n_latents)} +kwargs = dict( + raw_dir=args.cache_path, + modules=[module], + latents=latent_dict, + sampler_cfg=sampler_cfg, + constructor_cfg=constructor_cfg, +) + + +# Utility function to set activation buffer on a record (not used in main loop) +# def set_record_buffer(record, *, latent_data): +# record.buffer = latent_data.activation_data +# return record + + +raw_loader = LatentDataset( + **( + kwargs + | {"constructor_cfg": replace(constructor_cfg, save_activation_data=True)} + ) # type: ignore +) +# %% +print("Loading model") +model_name = raw_loader.cache_config["model_name"] +cache_lm = AutoModelForCausalLM.from_pretrained( + model_name, trust_remote_code=True, device_map="cpu" +) +# %% + +# Try to construct lm_head for different HuggingFace architectures +if hasattr(cache_lm, "model") and hasattr(cache_lm.model, "norm") and hasattr(cache_lm.model, 
"lm_head"): + # llama models + lm_head = torch.nn.Sequential(cache_lm.model.norm, cache_lm.model.lm_head) +elif hasattr(cache_lm, "gpt_neox") and hasattr(cache_lm, "embed_out"): + # pythia models + lm_head = torch.nn.Sequential(cache_lm.gpt_neox.final_layer_norm, cache_lm.embed_out) +else: + raise ValueError("Unknown model architecture for extracting lm_head.") + +# %% +lm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) +# %% +# sae_path = Path(args.sae_path) +# # heuristic to find the right layer +# module_raw = next(iter(sae_path.glob("*" + args.module.partition(".")[2] + "*"))).name +# hookpoint_to_sparse_model = load_sparsify_sparse_coders( +# sae_path, +# [module_raw], +# "cuda", +# compile=False, +# ) +# transcoder = hookpoint_to_sparse_model[module_raw] +# w_dec = transcoder.W_dec.data +# latent_to_resid = w_dec + +# sae_path = Path(args.sae_path) +# heuristic to find the right layer +# module_raw = next(iter(sae_path.glob("*" + args.module.partition(".")[2] + "*"))).name +hookpoint_to_sparse_model = load_sparsify_sparse_coders( + args.sae_path, + [args.module], + "cuda", + compile=False, +) +transcoder = hookpoint_to_sparse_model[args.module] +w_dec = transcoder.W_dec.data +latent_to_resid = w_dec +# %% +del cache_lm +gc.collect() +# %% +tokens = raw_loader.buffers[0].load()[-1] +n_sequences, max_seq_len = tokens.shape +# %% + +cfg = SaeVisConfig( + hook_point=args.module, + minibatch_size_tokens=raw_loader.cache_config["cache_ctx_len"], + features=[], + # batch_size=dataset.cache_config["batch_size"], +) +layout = cfg.feature_centric_layout + +ranges_and_precisions = ASYMMETRIC_RANGES_AND_PRECISIONS +quantiles = [] +for r, p in ranges_and_precisions: + start, end = r + step = 10**-p + quantiles.extend(np.arange(start, end - 0.5 * step, step)) +quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32) + +# %% +latent_data_dict = {} + +latent_stats = FeatureStatistics() +# supposed_latent = 0 +bar = tqdm(total=args.latents) +i = -1 +for record in raw_loader: + i += 1 + + # https://github.com/jbloomAus/SAEDashboard/blob/main/sae_dashboard/utils_fns.py + assert record.activation_data is not None + latent_id = record.activation_data.locations[0, 2].item() + decoder_resid = latent_to_resid[latent_id].to( + record.activation_data.activations.device + ) + logit_vector = lm_head(decoder_resid) + + buffer = record.activation_data + activations, locations = buffer.activations, buffer.locations + _max = activations.max() + nonzero_mask = activations.abs() > 1e-6 + nonzero_acts = activations[nonzero_mask] + frac_nonzero = nonzero_mask.sum() / (n_sequences * max_seq_len) + quantile_data = torch.quantile(activations.float(), quantiles_tensor) + skew = torch.mean((activations - activations.mean()) ** 3) / ( + activations.std() ** 3 + ) + kurtosis = torch.mean((activations - activations.mean()) ** 4) / ( + activations.std() ** 4 + ) + latent_stats.update( + FeatureStatistics( + max=[_max.item()], + frac_nonzero=[frac_nonzero.item()], + skew=[skew.item()], + kurtosis=[kurtosis.item()], + quantile_data=[quantile_data.unsqueeze(0).tolist()], + quantiles=quantiles + [1.0], + ranges_and_precisions=ranges_and_precisions, + ) + ) + + latent_data = FeatureData() + latent_data.feature_tables_data = FeatureTablesData() + latent_data.logits_histogram_data = LogitsHistogramData.from_data( + data=logit_vector.to(torch.float32), # need this otherwise fails on MPS + n_bins=layout.logits_hist_cfg.n_bins, # type: ignore + tickmode="5 ticks", + title=None, + ) + 
latent_data.acts_histogram_data = ActsHistogramData.from_data( + data=nonzero_acts.to(torch.float32), + n_bins=layout.act_hist_cfg.n_bins, + tickmode="5 ticks", + title=f"ACTIVATIONS<br>
DENSITY = {frac_nonzero:.3%}", + ) + latent_data.logits_table_data = get_logits_table_data( + logit_vector=logit_vector, + n_rows=layout.logits_table_cfg.n_rows, # type: ignore + ) + latent_data.decoder_weights_data = DecoderWeightsDistribution( + len(decoder_resid), decoder_resid.tolist() + ) + latent_data_dict[latent_id] = latent_data + # supposed_latent += 1 + bar.update(1) + bar.refresh() +bar.close() + +latent_list = latent_dict[module].tolist() +cfg.features = latent_list +# %% +n_quantiles = sampler_cfg.n_quantiles +sequence_loader = LatentDataset( + **kwargs | dict(sampler_cfg=sampler_cfg) # type: ignore +) +bar = tqdm(total=args.latents) +for record in sequence_loader: + groups = [] + for quantile_index, quantile_data in enumerate( + list(batched(record.train, len(record.train) // n_quantiles))[::-1] + ): + group = [] + for example in quantile_data: + default_list = [0.0] * len(example.tokens) + logit_list = [[0.0]] * len(default_list) + token_list = [[0]] * len(default_list) + default_attrs = dict( + loss_contribution=default_list, + token_logits=default_list, + top_token_ids=token_list, + top_logits=logit_list, + bottom_token_ids=token_list, + bottom_logits=logit_list, + ) + group.append( + SequenceData( + token_ids=example.tokens.tolist(), + feat_acts=example.activations.tolist(), + **default_attrs, + ) + ) + groups.append( + SequenceGroupData( + title=f"Quantile {quantile_index/n_quantiles:1%}" + f"-{(quantile_index+1)/n_quantiles:1%}", + seq_data=group, + ) + ) + latent_data_dict[record.latent.latent_index].sequence_data = SequenceMultiGroupData( + seq_group_data=groups + ) + bar.update(1) + bar.refresh() +bar.close() +# %% +latent_list = list(latent_data_dict.keys()) +tokenizer = raw_loader.tokenizer +model = Namespace( + tokenizer=tokenizer, +) + +sae_vis_data = SaeVisData( + cfg=cfg, + feature_data_dict=latent_data_dict, + feature_stats=latent_stats, + model=model, +) +print("Saving dashboard to", out_path) +save_feature_centric_vis(sae_vis_data=sae_vis_data, filename=out_path) +# %% \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0c20b748..c70414e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,10 @@ visualize = [ "kaleido==0.2.1", "plotly>=5.0.0rc2", "pandas", - "ipywidgets" + "ipywidgets", + "transformer-lens<=2.15.4", + "sae-dashboard==0.6.11", + "more-itertools>=10.6.0" ] [tool.pyright] From 3921030c79be05ea88fe11099b3beb2c11b7b2e2 Mon Sep 17 00:00:00 2001 From: mathreader Date: Wed, 13 Aug 2025 01:34:15 +0800 Subject: [PATCH 2/5] update visualization code for readability --- examples/convert_sae_dashboard.py | 582 ++++++++++++++++-------------- 1 file changed, 314 insertions(+), 268 deletions(-) diff --git a/examples/convert_sae_dashboard.py b/examples/convert_sae_dashboard.py index ab64c44a..2aa4d4de 100644 --- a/examples/convert_sae_dashboard.py +++ b/examples/convert_sae_dashboard.py @@ -1,4 +1,14 @@ -# %% +import gc +from typing import Tuple, List, Dict, Any +from argparse import Namespace +from dataclasses import replace +from pathlib import Path +import numpy as np +import torch +from tqdm.auto import tqdm +from simple_parsing import ArgumentParser +from transformers import AutoModelForCausalLM + from sae_dashboard.components import ( ActsHistogramData, DecoderWeightsDistribution, @@ -12,293 +22,329 @@ from sae_dashboard.data_writing_fns import save_feature_centric_vis from sae_dashboard.feature_data import FeatureData from sae_dashboard.sae_vis_data import SaeVisConfig, SaeVisData -from sae_dashboard.utils_fns 
import FeatureStatistics +from sae_dashboard.utils_fns import FeatureStatistics, ASYMMETRIC_RANGES_AND_PRECISIONS +from delphi.config import ConstructorConfig, SamplerConfig +from delphi.latents import LatentDataset +from delphi.sparse_coders.sparse_model import load_sparsify_sparse_coders try: from itertools import batched except ImportError: from more_itertools import chunked as batched - # Fallback for older Python - # Using more_itertools for compatibility -import gc -from argparse import Namespace -from dataclasses import replace -from pathlib import Path -import numpy as np -import torch -from sae_dashboard.utils_fns import ASYMMETRIC_RANGES_AND_PRECISIONS -from simple_parsing import ArgumentParser -from tqdm.auto import tqdm -from transformers import AutoConfig, AutoModelForCausalLM -from delphi.config import ConstructorConfig, SamplerConfig -from delphi.latents import LatentDataset -from delphi.sparse_coders.sparse_model import load_sparsify_sparse_coders +def parse_args() -> Namespace: + """ + Parse command-line arguments. -torch.set_grad_enabled(False) + Returns: + Namespace: Parsed arguments with all configuration options. + """ + parser = ArgumentParser(description="Convert SAE data for dashboard visualization") + parser.add_argument( + "--module", type=str, default="model.layers.9", help="Model module to analyze" + ) + parser.add_argument( + "--latents", type=int, default=5, help="Number of latents to process" + ) + parser.add_argument( + "--cache-path", + type=str, + default="../results/sae_pkm/baseline", + help="Path to cached activations", + ) + parser.add_argument( + "--sae-path", + type=str, + default="../halutsae/sae-pkm/smollm/baseline", + help="Path to trained SAE model", + ) + parser.add_argument( + "--out-path", + type=str, + default="results/latent_dashboard.html", + help="Path to save the dashboard", + ) + parser.add_arguments(ConstructorConfig, dest="constructor_cfg") + parser.add_arguments( + SamplerConfig, + dest="sampler_cfg", + default=SamplerConfig(n_examples_train=25, n_quantiles=5, train_type="quantiles"), + ) + return parser.parse_args() -parser = ArgumentParser(description="Convert SAE data for dashboard visualization") -parser.add_argument( - "--module", type=str, default="model.layers.9", help="Model module to analyze" -) -parser.add_argument( - "--latents", type=int, default=5, help="Number of latents to process" -) -parser.add_argument( - "--cache-path", - type=str, - default="../results/sae_pkm/baseline", - help="Path to cached activations", -) -parser.add_argument( - "--sae-path", - type=str, - default="../halutsae/sae-pkm/smollm/baseline", - help="Path to trained SAE model", -) -parser.add_argument( - "--out-path", - type=str, - default="results/latent_dashboard.html", - help="Path to save the dashboard", -) -parser.add_arguments(ConstructorConfig, dest="constructor_cfg") -parser.add_arguments( - SamplerConfig, - dest="sampler_cfg", - default=SamplerConfig(n_examples_train=40, n_quantiles=10, train_type="quantiles"), -) -args = parser.parse_args() -constructor_cfg = args.constructor_cfg -sampler_cfg = args.sampler_cfg -out_path = args.out_path - -# %% -module = args.module -n_latents = args.latents -start_latent = 0 -latent_dict = {f"{module}": torch.arange(start_latent, start_latent + n_latents)} -kwargs = dict( - raw_dir=args.cache_path, - modules=[module], - latents=latent_dict, - sampler_cfg=sampler_cfg, - constructor_cfg=constructor_cfg, -) +def get_lm_head(cache_lm: AutoModelForCausalLM) -> torch.nn.Sequential: + """ + Get the LM head for 
different HuggingFace architectures. -# Utility function to set activation buffer on a record (not used in main loop) -# def set_record_buffer(record, *, latent_data): -# record.buffer = latent_data.activation_data -# return record + Args: + cache_lm (AutoModelForCausalLM): Loaded language model. + Returns: + torch.nn.Sequential: The LM head module for logits computation. + """ + if hasattr(cache_lm, "model") and hasattr(cache_lm.model, "norm") and hasattr(cache_lm, "lm_head"): + # llama models + return torch.nn.Sequential(cache_lm.model.norm, cache_lm.lm_head) + elif hasattr(cache_lm, "gpt_neox") and hasattr(cache_lm, "embed_out"): + # pythia models + return torch.nn.Sequential(cache_lm.gpt_neox.final_layer_norm, cache_lm.embed_out) + else: + raise ValueError("Unknown model architecture for extracting lm_head.") -raw_loader = LatentDataset( - **( - kwargs - | {"constructor_cfg": replace(constructor_cfg, save_activation_data=True)} - ) # type: ignore -) -# %% -print("Loading model") -model_name = raw_loader.cache_config["model_name"] -cache_lm = AutoModelForCausalLM.from_pretrained( - model_name, trust_remote_code=True, device_map="cpu" -) -# %% - -# Try to construct lm_head for different HuggingFace architectures -if hasattr(cache_lm, "model") and hasattr(cache_lm.model, "norm") and hasattr(cache_lm.model, "lm_head"): - # llama models - lm_head = torch.nn.Sequential(cache_lm.model.norm, cache_lm.model.lm_head) -elif hasattr(cache_lm, "gpt_neox") and hasattr(cache_lm, "embed_out"): - # pythia models - lm_head = torch.nn.Sequential(cache_lm.gpt_neox.final_layer_norm, cache_lm.embed_out) -else: - raise ValueError("Unknown model architecture for extracting lm_head.") - -# %% -lm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) -# %% -# sae_path = Path(args.sae_path) -# # heuristic to find the right layer -# module_raw = next(iter(sae_path.glob("*" + args.module.partition(".")[2] + "*"))).name -# hookpoint_to_sparse_model = load_sparsify_sparse_coders( -# sae_path, -# [module_raw], -# "cuda", -# compile=False, -# ) -# transcoder = hookpoint_to_sparse_model[module_raw] -# w_dec = transcoder.W_dec.data -# latent_to_resid = w_dec - -# sae_path = Path(args.sae_path) -# heuristic to find the right layer -# module_raw = next(iter(sae_path.glob("*" + args.module.partition(".")[2] + "*"))).name -hookpoint_to_sparse_model = load_sparsify_sparse_coders( - args.sae_path, - [args.module], - "cuda", - compile=False, -) -transcoder = hookpoint_to_sparse_model[args.module] -w_dec = transcoder.W_dec.data -latent_to_resid = w_dec -# %% -del cache_lm -gc.collect() -# %% -tokens = raw_loader.buffers[0].load()[-1] -n_sequences, max_seq_len = tokens.shape -# %% - -cfg = SaeVisConfig( - hook_point=args.module, - minibatch_size_tokens=raw_loader.cache_config["cache_ctx_len"], - features=[], - # batch_size=dataset.cache_config["batch_size"], -) -layout = cfg.feature_centric_layout - -ranges_and_precisions = ASYMMETRIC_RANGES_AND_PRECISIONS -quantiles = [] -for r, p in ranges_and_precisions: - start, end = r - step = 10**-p - quantiles.extend(np.arange(start, end - 0.5 * step, step)) -quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32) - -# %% -latent_data_dict = {} - -latent_stats = FeatureStatistics() -# supposed_latent = 0 -bar = tqdm(total=args.latents) -i = -1 -for record in raw_loader: - i += 1 - - # https://github.com/jbloomAus/SAEDashboard/blob/main/sae_dashboard/utils_fns.py - assert record.activation_data is not None - latent_id = record.activation_data.locations[0, 
2].item() - decoder_resid = latent_to_resid[latent_id].to( - record.activation_data.activations.device - ) - logit_vector = lm_head(decoder_resid) - - buffer = record.activation_data - activations, locations = buffer.activations, buffer.locations - _max = activations.max() - nonzero_mask = activations.abs() > 1e-6 - nonzero_acts = activations[nonzero_mask] - frac_nonzero = nonzero_mask.sum() / (n_sequences * max_seq_len) - quantile_data = torch.quantile(activations.float(), quantiles_tensor) - skew = torch.mean((activations - activations.mean()) ** 3) / ( - activations.std() ** 3 - ) - kurtosis = torch.mean((activations - activations.mean()) ** 4) / ( - activations.std() ** 4 + +def compute_quantiles() -> Tuple[List[float], torch.Tensor, List[Any]]: + """ + Compute quantile bins for activation statistics. + + Returns: + tuple[list[float], torch.Tensor, list[tuple]]: (quantiles, quantiles_tensor, ranges_and_precisions) + """ + ranges_and_precisions = ASYMMETRIC_RANGES_AND_PRECISIONS + quantiles = [] + for r, p in ranges_and_precisions: + start, end = r + step = 10**-p + quantiles.extend(np.arange(start, end - 0.5 * step, step)) + quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32) + return quantiles, quantiles_tensor, ranges_and_precisions + + +def process_latents( + raw_loader: LatentDataset, + lm_head: torch.nn.Module, + latent_to_resid: torch.Tensor, + layout: Any, + quantiles: List[float], + quantiles_tensor: torch.Tensor, + ranges_and_precisions: List[Any], + n_sequences: int, + max_seq_len: int, +) -> Tuple[Dict[int, FeatureData], FeatureStatistics]: + """ + Process each latent, compute statistics and visualization data. + + Args: + raw_loader (LatentDataset): Loader for latent activations. + lm_head (torch.nn.Module): LM head for logits computation. + latent_to_resid (torch.Tensor): Decoder weights. + layout: Dashboard layout config. + quantiles (list[float]): Quantile bins. + quantiles_tensor (torch.Tensor): Quantile bins as tensor. + ranges_and_precisions (list[tuple]): Ranges and precisions for quantiles. + n_sequences (int): Number of sequences. + max_seq_len (int): Maximum sequence length. 
+ + Returns: + tuple[dict[int, FeatureData], FeatureStatistics]: + (latent_data_dict, latent_stats) + """ + latent_data_dict = {} + latent_stats = FeatureStatistics() + bar = tqdm(total=len(raw_loader.buffers)) + for _, record in enumerate(raw_loader): + assert record.activation_data is not None + latent_id = record.activation_data.locations[0, 2].item() + decoder_resid = latent_to_resid[latent_id].to(record.activation_data.activations.device) + logit_vector = lm_head(decoder_resid) + + activations = record.activation_data.activations + _max = activations.max() + nonzero_mask = activations.abs() > 1e-6 + nonzero_acts = activations[nonzero_mask] + frac_nonzero = nonzero_mask.sum() / (n_sequences * max_seq_len) + quantile_data = torch.quantile(activations.float(), quantiles_tensor) + skew = torch.mean((activations - activations.mean()) ** 3) / (activations.std() ** 3) + kurtosis = torch.mean((activations - activations.mean()) ** 4) / (activations.std() ** 4) + latent_stats.update( + FeatureStatistics( + max=[_max.item()], + frac_nonzero=[frac_nonzero.item()], + skew=[skew.item()], + kurtosis=[kurtosis.item()], + quantile_data=[quantile_data.unsqueeze(0).tolist()], + quantiles=quantiles + [1.0], + ranges_and_precisions=ranges_and_precisions, + ) + ) + + latent_data = FeatureData() + latent_data.feature_tables_data = FeatureTablesData() + latent_data.logits_histogram_data = LogitsHistogramData.from_data( + data=logit_vector.to(torch.float32), + n_bins=layout.logits_hist_cfg.n_bins, + tickmode="5 ticks", + title=None, + ) + latent_data.acts_histogram_data = ActsHistogramData.from_data( + data=nonzero_acts.to(torch.float32), + n_bins=layout.act_hist_cfg.n_bins, + tickmode="5 ticks", + title=f"ACTIVATIONS<br>
DENSITY = {frac_nonzero:.3%}", + ) + latent_data.logits_table_data = get_logits_table_data( + logit_vector=logit_vector, + n_rows=layout.logits_table_cfg.n_rows, + ) + latent_data.decoder_weights_data = DecoderWeightsDistribution( + len(decoder_resid), decoder_resid.tolist() + ) + latent_data_dict[latent_id] = latent_data + bar.update(1) + bar.refresh() + bar.close() + return latent_data_dict, latent_stats + + +def process_sequences( + sequence_loader: LatentDataset, + latent_data_dict: Dict[int, FeatureData], + n_quantiles: int, +) -> None: + """ + Group sequence examples by quantile for each latent. + + Args: + sequence_loader (LatentDataset): Loader for sequence data. + latent_data_dict (dict[int, FeatureData]): Dictionary of latent feature data. + n_quantiles (int): Number of quantile groups. + + Returns: + None + """ + bar = tqdm(total=len(sequence_loader.buffers)) + for record in sequence_loader: + groups = [] + for quantile_index, quantile_data in enumerate( + list(batched(record.train, len(record.train) // n_quantiles))[::-1] + ): + group = [] + for example in quantile_data: + default_list = [0.0] * len(example.tokens) + logit_list = [[0.0]] * len(default_list) + token_list = [[0]] * len(default_list) + default_attrs = dict( + loss_contribution=default_list, + token_logits=default_list, + top_token_ids=token_list, + top_logits=logit_list, + bottom_token_ids=token_list, + bottom_logits=logit_list, + ) + group.append( + SequenceData( + token_ids=example.tokens.tolist(), + feat_acts=example.activations.tolist(), + **default_attrs, + ) + ) + groups.append( + SequenceGroupData( + title=f"Quantile {quantile_index/n_quantiles:1%}-{(quantile_index+1)/n_quantiles:1%}", + seq_data=group, + ) + ) + latent_data_dict[record.latent.latent_index].sequence_data = SequenceMultiGroupData( + seq_group_data=groups + ) + bar.update(1) + bar.refresh() + bar.close() + + +def main() -> None: + """ + Main function to run SAE dashboard conversion. 
+ + Returns: + None + """ + # Setup configurations and arguments + torch.set_grad_enabled(False) + args = parse_args() + constructor_cfg = args.constructor_cfg + sampler_cfg = args.sampler_cfg + out_path = args.out_path + module = args.module + n_latents = args.latents + start_latent = 0 + latent_dict = {f"{module}": torch.arange(start_latent, start_latent + n_latents)} + kwargs = dict( + raw_dir=args.cache_path, + modules=[module], + latents=latent_dict, + sampler_cfg=sampler_cfg, + constructor_cfg=constructor_cfg, ) - latent_stats.update( - FeatureStatistics( - max=[_max.item()], - frac_nonzero=[frac_nonzero.item()], - skew=[skew.item()], - kurtosis=[kurtosis.item()], - quantile_data=[quantile_data.unsqueeze(0).tolist()], - quantiles=quantiles + [1.0], - ranges_and_precisions=ranges_and_precisions, + + # Load latent dataset + raw_loader = LatentDataset( + **( + kwargs + | {"constructor_cfg": replace(constructor_cfg, save_activation_data=True)} ) ) + print("Loading model") + model_name = raw_loader.cache_config["model_name"] + cache_lm = AutoModelForCausalLM.from_pretrained( + model_name, trust_remote_code=True, device_map="cpu" + ) + lm_head = get_lm_head(cache_lm) - latent_data = FeatureData() - latent_data.feature_tables_data = FeatureTablesData() - latent_data.logits_histogram_data = LogitsHistogramData.from_data( - data=logit_vector.to(torch.float32), # need this otherwise fails on MPS - n_bins=layout.logits_hist_cfg.n_bins, # type: ignore - tickmode="5 ticks", - title=None, + # Load SAE model + hookpoint_to_sparse_model = load_sparsify_sparse_coders( + args.sae_path, + [args.module], + "cuda", + compile=False, ) - latent_data.acts_histogram_data = ActsHistogramData.from_data( - data=nonzero_acts.to(torch.float32), - n_bins=layout.act_hist_cfg.n_bins, - tickmode="5 ticks", - title=f"ACTIVATIONS
DENSITY = {frac_nonzero:.3%}", + transcoder = hookpoint_to_sparse_model[args.module] + w_dec = transcoder.W_dec.data + latent_to_resid = w_dec + + del cache_lm + gc.collect() + + tokens = raw_loader.buffers[0].load()[-1] + n_sequences, max_seq_len = tokens.shape + + cfg = SaeVisConfig( + hook_point=args.module, + minibatch_size_tokens=raw_loader.cache_config["cache_ctx_len"], + features=[], ) - latent_data.logits_table_data = get_logits_table_data( - logit_vector=logit_vector, - n_rows=layout.logits_table_cfg.n_rows, # type: ignore + layout = cfg.feature_centric_layout + + quantiles, quantiles_tensor, ranges_and_precisions = compute_quantiles() + + # Process latents + latent_data_dict, latent_stats = process_latents( + raw_loader, lm_head, latent_to_resid, layout, quantiles, quantiles_tensor, ranges_and_precisions, n_sequences, max_seq_len ) - latent_data.decoder_weights_data = DecoderWeightsDistribution( - len(decoder_resid), decoder_resid.tolist() + latent_list = latent_dict[module].tolist() + cfg.features = latent_list + + # Process sequences + n_quantiles = sampler_cfg.n_quantiles + sequence_loader = LatentDataset( + **kwargs | dict(sampler_cfg=sampler_cfg) ) - latent_data_dict[latent_id] = latent_data - # supposed_latent += 1 - bar.update(1) - bar.refresh() -bar.close() - -latent_list = latent_dict[module].tolist() -cfg.features = latent_list -# %% -n_quantiles = sampler_cfg.n_quantiles -sequence_loader = LatentDataset( - **kwargs | dict(sampler_cfg=sampler_cfg) # type: ignore -) -bar = tqdm(total=args.latents) -for record in sequence_loader: - groups = [] - for quantile_index, quantile_data in enumerate( - list(batched(record.train, len(record.train) // n_quantiles))[::-1] - ): - group = [] - for example in quantile_data: - default_list = [0.0] * len(example.tokens) - logit_list = [[0.0]] * len(default_list) - token_list = [[0]] * len(default_list) - default_attrs = dict( - loss_contribution=default_list, - token_logits=default_list, - top_token_ids=token_list, - top_logits=logit_list, - bottom_token_ids=token_list, - bottom_logits=logit_list, - ) - group.append( - SequenceData( - token_ids=example.tokens.tolist(), - feat_acts=example.activations.tolist(), - **default_attrs, - ) - ) - groups.append( - SequenceGroupData( - title=f"Quantile {quantile_index/n_quantiles:1%}" - f"-{(quantile_index+1)/n_quantiles:1%}", - seq_data=group, - ) - ) - latent_data_dict[record.latent.latent_index].sequence_data = SequenceMultiGroupData( - seq_group_data=groups + process_sequences(sequence_loader, latent_data_dict, n_quantiles) + + # Save dashboard + latent_list = list(latent_data_dict.keys()) + tokenizer = raw_loader.tokenizer + model = Namespace(tokenizer=tokenizer) + sae_vis_data = SaeVisData( + cfg=cfg, + feature_data_dict=latent_data_dict, + feature_stats=latent_stats, + model=model, ) - bar.update(1) - bar.refresh() -bar.close() -# %% -latent_list = list(latent_data_dict.keys()) -tokenizer = raw_loader.tokenizer -model = Namespace( - tokenizer=tokenizer, -) + print(f"Saving dashboard to {out_path}") + save_feature_centric_vis(sae_vis_data=sae_vis_data, filename=out_path) -sae_vis_data = SaeVisData( - cfg=cfg, - feature_data_dict=latent_data_dict, - feature_stats=latent_stats, - model=model, -) -print("Saving dashboard to", out_path) -save_feature_centric_vis(sae_vis_data=sae_vis_data, filename=out_path) -# %% \ No newline at end of file + +if __name__ == "__main__": + main() \ No newline at end of file From 5dec56fd417af4cc4a48fe30e006914811a6d3ce Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Aug 2025 17:46:11 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/convert_sae_dashboard.py | 60 ++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/examples/convert_sae_dashboard.py b/examples/convert_sae_dashboard.py index 2aa4d4de..225b5082 100644 --- a/examples/convert_sae_dashboard.py +++ b/examples/convert_sae_dashboard.py @@ -1,14 +1,10 @@ import gc -from typing import Tuple, List, Dict, Any from argparse import Namespace from dataclasses import replace -from pathlib import Path +from typing import Any, Dict, List, Tuple + import numpy as np import torch -from tqdm.auto import tqdm -from simple_parsing import ArgumentParser -from transformers import AutoModelForCausalLM - from sae_dashboard.components import ( ActsHistogramData, DecoderWeightsDistribution, @@ -22,7 +18,11 @@ from sae_dashboard.data_writing_fns import save_feature_centric_vis from sae_dashboard.feature_data import FeatureData from sae_dashboard.sae_vis_data import SaeVisConfig, SaeVisData -from sae_dashboard.utils_fns import FeatureStatistics, ASYMMETRIC_RANGES_AND_PRECISIONS +from sae_dashboard.utils_fns import ASYMMETRIC_RANGES_AND_PRECISIONS, FeatureStatistics +from simple_parsing import ArgumentParser +from tqdm.auto import tqdm +from transformers import AutoModelForCausalLM + from delphi.config import ConstructorConfig, SamplerConfig from delphi.latents import LatentDataset from delphi.sparse_coders.sparse_model import load_sparsify_sparse_coders @@ -69,7 +69,9 @@ def parse_args() -> Namespace: parser.add_arguments( SamplerConfig, dest="sampler_cfg", - default=SamplerConfig(n_examples_train=25, n_quantiles=5, train_type="quantiles"), + default=SamplerConfig( + n_examples_train=25, n_quantiles=5, train_type="quantiles" + ), ) return parser.parse_args() @@ -84,12 +86,18 @@ def get_lm_head(cache_lm: AutoModelForCausalLM) -> torch.nn.Sequential: Returns: torch.nn.Sequential: The LM head module for logits computation. 
""" - if hasattr(cache_lm, "model") and hasattr(cache_lm.model, "norm") and hasattr(cache_lm, "lm_head"): + if ( + hasattr(cache_lm, "model") + and hasattr(cache_lm.model, "norm") + and hasattr(cache_lm, "lm_head") + ): # llama models return torch.nn.Sequential(cache_lm.model.norm, cache_lm.lm_head) elif hasattr(cache_lm, "gpt_neox") and hasattr(cache_lm, "embed_out"): # pythia models - return torch.nn.Sequential(cache_lm.gpt_neox.final_layer_norm, cache_lm.embed_out) + return torch.nn.Sequential( + cache_lm.gpt_neox.final_layer_norm, cache_lm.embed_out + ) else: raise ValueError("Unknown model architecture for extracting lm_head.") @@ -146,7 +154,9 @@ def process_latents( for _, record in enumerate(raw_loader): assert record.activation_data is not None latent_id = record.activation_data.locations[0, 2].item() - decoder_resid = latent_to_resid[latent_id].to(record.activation_data.activations.device) + decoder_resid = latent_to_resid[latent_id].to( + record.activation_data.activations.device + ) logit_vector = lm_head(decoder_resid) activations = record.activation_data.activations @@ -155,8 +165,12 @@ def process_latents( nonzero_acts = activations[nonzero_mask] frac_nonzero = nonzero_mask.sum() / (n_sequences * max_seq_len) quantile_data = torch.quantile(activations.float(), quantiles_tensor) - skew = torch.mean((activations - activations.mean()) ** 3) / (activations.std() ** 3) - kurtosis = torch.mean((activations - activations.mean()) ** 4) / (activations.std() ** 4) + skew = torch.mean((activations - activations.mean()) ** 3) / ( + activations.std() ** 3 + ) + kurtosis = torch.mean((activations - activations.mean()) ** 4) / ( + activations.std() ** 4 + ) latent_stats.update( FeatureStatistics( max=[_max.item()], @@ -245,8 +259,8 @@ def process_sequences( seq_data=group, ) ) - latent_data_dict[record.latent.latent_index].sequence_data = SequenceMultiGroupData( - seq_group_data=groups + latent_data_dict[record.latent.latent_index].sequence_data = ( + SequenceMultiGroupData(seq_group_data=groups) ) bar.update(1) bar.refresh() @@ -320,16 +334,22 @@ def main() -> None: # Process latents latent_data_dict, latent_stats = process_latents( - raw_loader, lm_head, latent_to_resid, layout, quantiles, quantiles_tensor, ranges_and_precisions, n_sequences, max_seq_len + raw_loader, + lm_head, + latent_to_resid, + layout, + quantiles, + quantiles_tensor, + ranges_and_precisions, + n_sequences, + max_seq_len, ) latent_list = latent_dict[module].tolist() cfg.features = latent_list # Process sequences n_quantiles = sampler_cfg.n_quantiles - sequence_loader = LatentDataset( - **kwargs | dict(sampler_cfg=sampler_cfg) - ) + sequence_loader = LatentDataset(**kwargs | dict(sampler_cfg=sampler_cfg)) process_sequences(sequence_loader, latent_data_dict, n_quantiles) # Save dashboard @@ -347,4 +367,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() From 1b43927a41dd864c9b018437abab12d67cf540ff Mon Sep 17 00:00:00 2001 From: mathreader Date: Wed, 13 Aug 2025 02:05:24 +0800 Subject: [PATCH 4/5] fix lines that are too long --- examples/convert_sae_dashboard.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/convert_sae_dashboard.py b/examples/convert_sae_dashboard.py index 225b5082..40f1168c 100644 --- a/examples/convert_sae_dashboard.py +++ b/examples/convert_sae_dashboard.py @@ -107,7 +107,8 @@ def compute_quantiles() -> Tuple[List[float], torch.Tensor, List[Any]]: Compute quantile bins for activation statistics. 
Returns: - tuple[list[float], torch.Tensor, list[tuple]]: (quantiles, quantiles_tensor, ranges_and_precisions) + tuple[list[float], torch.Tensor, list[tuple]]: quantiles, + quantiles_tensor, ranges_and_precisions """ ranges_and_precisions = ASYMMETRIC_RANGES_AND_PRECISIONS quantiles = [] @@ -253,9 +254,12 @@ def process_sequences( **default_attrs, ) ) + quantile_start = quantile_index / n_quantiles + quantile_end = (quantile_index + 1) / n_quantiles + title = f"Quantile {quantile_start:1%}-{quantile_end:1%}" groups.append( SequenceGroupData( - title=f"Quantile {quantile_index/n_quantiles:1%}-{(quantile_index+1)/n_quantiles:1%}", + title=title, seq_data=group, ) ) From 2cb7084e99263b2aef5a3cac14be63505639e041 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Aug 2025 18:05:36 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/convert_sae_dashboard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/convert_sae_dashboard.py b/examples/convert_sae_dashboard.py index 40f1168c..385d4662 100644 --- a/examples/convert_sae_dashboard.py +++ b/examples/convert_sae_dashboard.py @@ -107,7 +107,7 @@ def compute_quantiles() -> Tuple[List[float], torch.Tensor, List[Any]]: Compute quantile bins for activation statistics. Returns: - tuple[list[float], torch.Tensor, list[tuple]]: quantiles, + tuple[list[float], torch.Tensor, list[tuple]]: quantiles, quantiles_tensor, ranges_and_precisions """ ranges_and_precisions = ASYMMETRIC_RANGES_AND_PRECISIONS