Commit 68d6c63

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

1 parent 5161d9a commit 68d6c63
File tree: 1 file changed (+61 −48 lines)
delphi/scorers/intervention/surprisal_intervention_scorer.py

Lines changed: 61 additions & 48 deletions
@@ -168,24 +168,23 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
         """
         Calculates the feature's decoder vector, subtracting the decoder bias.
         """
-
-
+
         d_latent = sae.encoder.out_features
         sae_device = sae.encoder.weight.device

         # Create a one-hot activation for our single feature.
         one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device)
-
+
         if feature_id >= d_latent:
-            print(f"DEBUG: ERROR - Feature ID {feature_id} is out of bounds for d_latent {d_latent}")
+            print(
+                f"DEBUG: ERROR - Feature ID {feature_id} is out of bounds for d_latent {d_latent}"
+            )
             return torch.zeros(1)
-
+
         one_hot_activation[0, 0, feature_id] = 1.0

         # Create the corresponding indices needed for the decode method.
-        indices = torch.tensor(
-            [[[feature_id]]], device=sae_device, dtype=torch.long
-        )
+        indices = torch.tensor([[[feature_id]]], device=sae_device, dtype=torch.long)

         with torch.no_grad():
             try:
@@ -196,24 +195,25 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
                 return torch.zeros(1)

         decoder_vector = vector_before_sub - decoded_zero
-
+
         final_norm = decoder_vector.norm().item()
-
+
         # --- MODIFIED DEBUG BLOCK ---
         # Only print if the feature is "decoder-live"
         if final_norm > 1e-6:
             print(f"\n--- DEBUG: 'Decoder-Live' Feature Found: {feature_id} ---")
             print(f"DEBUG: sae.encoder.out_features (d_latent): {d_latent}")
             print(f"DEBUG: sae.encoder.weight.device (sae_device): {sae_device}")
             print(f"DEBUG: Norm of decoded_zero: {decoded_zero.norm().item()}")
-            print(f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}")
+            print(
+                f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}"
+            )
             print(f"DEBUG: Feature {feature_id}, FINAL Vector Norm: {final_norm}")
             print("--- END DEBUG ---\n")
         # --- END MODIFIED BLOCK ---

         return decoder_vector.squeeze()

-
     async def __call__(self, record: LatentRecord) -> ScorerResult:

         record_copy = copy.deepcopy(record)
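Aside: the method reformatted above isolates a feature's decoder direction by decoding a one-hot activation and subtracting the decoded zero vector, which cancels the decoder bias. Below is a minimal standalone sketch of that trick; it assumes an SAE exposing `encoder` and a `decode(activations, indices)` method, which these hunks only imply (the `try` body producing `vector_before_sub` and `decoded_zero` is not shown).

```python
import torch


def decoder_direction(sae, feature_id: int) -> torch.Tensor:
    # Sketch of the trick above: decode(one_hot) - decode(zeros)
    # yields the feature's decoder column with the bias b_dec cancelled.
    d_latent = sae.encoder.out_features
    device = sae.encoder.weight.device

    one_hot = torch.zeros(1, 1, d_latent, device=device)
    one_hot[0, 0, feature_id] = 1.0
    indices = torch.tensor([[[feature_id]]], device=device, dtype=torch.long)

    with torch.no_grad():
        # Assumed decode signature; the hunks only show the call-site setup.
        vector_before_sub = sae.decode(one_hot, indices)  # ~ W_dec[f] + b_dec
        decoded_zero = sae.decode(torch.zeros_like(one_hot), indices)  # ~ b_dec
    return (vector_before_sub - decoded_zero).squeeze()  # ~ W_dec[f]
```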
@@ -240,7 +240,7 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
         sae = self._get_sae_for_hookpoint(hookpoint_str, record_copy)
         if not sae:
             raise ValueError(f"Could not find SAE for hookpoint {hookpoint_str}")
-
+
         intervention_vector = self._get_intervention_vector(sae, record_copy.feature_id)

         tuned_strength, initial_kl = await self._tune_strength(
@@ -253,10 +253,18 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:

         for prompt in truncated_prompts:
             clean_text, clean_logp_dist = await self._generate_with_intervention(
-                prompt, record_copy, strength=0.0, intervention_vector=intervention_vector, get_logp_dist=True
+                prompt,
+                record_copy,
+                strength=0.0,
+                intervention_vector=intervention_vector,
+                get_logp_dist=True,
             )
             int_text, int_logp_dist = await self._generate_with_intervention(
-                prompt, record_copy, strength=tuned_strength, intervention_vector=intervention_vector, get_logp_dist=True
+                prompt,
+                record_copy,
+                strength=tuned_strength,
+                intervention_vector=intervention_vector,
+                get_logp_dist=True,
             )

             logp_clean = await self._score_explanation(
@@ -300,7 +308,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
         )
         return ScorerResult(record=record_copy, score=final_output_list)

-
     async def _get_latent_activations(
         self, prompt: str, record: LatentRecord
     ) -> torch.Tensor:
@@ -339,7 +346,6 @@ def capture_hook(module, inp, out):

         return feature_acts[0, :, record.feature_id].cpu()

-
     async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str:
         """
         Truncates prompt to end just before the first token where latent activates.
@@ -356,17 +362,18 @@ async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str:
         first_activation_idx = all_activation_indices[all_activation_indices > 0]

         if first_activation_idx.numel() > 0:
-            truncation_point = first_activation_idx[0].item()
+            truncation_point = first_activation_idx[0].item()
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0]
-            truncated_ids = input_ids[:truncation_point + 1]
+            truncated_ids = input_ids[: truncation_point + 1]
             return self.tokenizer.decode(truncated_ids, skip_special_tokens=True)

         return prompt

-
     async def _tune_strength(
-        self, prompts: List[str], record: LatentRecord,
-        intervention_vector: torch.Tensor
+        self,
+        prompts: List[str],
+        record: LatentRecord,
+        intervention_vector: torch.Tensor,
     ) -> Tuple[float, float]:
         """
         Performs a binary search to find intervention strength that matches target_kl.
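Aside: the binary search this docstring describes can be sketched independently of the scorer: raise the strength when the measured KL falls below the target and lower it when it overshoots. The bounds, iteration budget, and `avg_kl` callable below are illustrative assumptions, not the scorer's actual values.

```python
def tune_strength(avg_kl, target_kl: float, lo: float = 0.0, hi: float = 100.0) -> float:
    # Assumes avg_kl(strength) grows monotonically with strength.
    best = lo
    for _ in range(20):  # illustrative iteration budget
        mid = (lo + hi) / 2.0
        if avg_kl(mid) < target_kl:
            lo = mid  # intervention too weak: search higher strengths
        else:
            hi = mid  # intervention too strong: search lower strengths
        best = mid
    return best


# Toy check: with avg_kl(s) = 0.01 * s and target_kl = 0.1,
# the search converges near strength 10.0.
```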
@@ -408,22 +415,26 @@ async def _tune_strength(
                 best_strength = mid_strength

         # Return the best found strength and the corresponding KL
-        final_kl = await self._calculate_avg_kl(prompts, record, best_strength, intervention_vector)
+        final_kl = await self._calculate_avg_kl(
+            prompts, record, best_strength, intervention_vector
+        )
         return best_strength, final_kl

-
     async def _calculate_avg_kl(
-        self, prompts: List[str], record: LatentRecord, strength: float,
-        intervention_vector: torch.Tensor
+        self,
+        prompts: List[str],
+        record: LatentRecord,
+        strength: float,
+        intervention_vector: torch.Tensor,
     ) -> float:
         total_kl = 0.0
         n = 0
         for prompt in prompts:
             _, clean_logp = await self._generate_with_intervention(
-                prompt, record, 0.0, intervention_vector,True
+                prompt, record, 0.0, intervention_vector, True
             )
             _, int_logp = await self._generate_with_intervention(
-                prompt, record, strength, intervention_vector,True
+                prompt, record, strength, intervention_vector, True
             )
             p_clean = torch.exp(clean_logp)
             kl_div = F.kl_div(
@@ -433,7 +444,6 @@ async def _calculate_avg_kl(
             n += 1
         return total_kl / n if n > 0 else 0.0

-
     async def _generate_with_intervention(
         self,
         prompt: str,
@@ -473,8 +483,9 @@ def hook_fn(module, inp, out):
             intervention_start_index = prompt_length - 1

             if current_seq_len >= prompt_length:
-                new_hiddens[:, intervention_start_index:, :] += delta.to(original_dtype)
-
+                new_hiddens[:, intervention_start_index:, :] += delta.to(
+                    original_dtype
+                )

             return (
                 (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens
@@ -484,7 +495,7 @@ def hook_fn(module, inp, out):

         try:
             with torch.no_grad():
-                outputs =self.subject_model(input_ids, attention_mask=attention_mask)
+                outputs = self.subject_model(input_ids, attention_mask=attention_mask)
                 next_token_logits = outputs.logits[0, -1, :]
                 log_probs_next_token = (
                     F.log_softmax(next_token_logits, dim=-1) if get_logp_dist else None
@@ -506,10 +517,9 @@ def hook_fn(module, inp, out):
             log_probs_next_token.cpu() if get_logp_dist else torch.empty(0)
         )

-
     async def _score_explanation(self, generated_text: str, explanation: str) -> float:
         """
-        Computes log P(explanation | generated_text) using the paper's
+        Computes log P(explanation | generated_text) using the paper's
         prompt format.
         """
         device = self._get_device()
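Aside: the `hook_fn` hunks above add a scaled feature direction to the hidden states from the last prompt position onward and return the modified output so downstream layers see the steered activations. A self-contained sketch of the same forward-hook pattern on a toy layer (the helper name and toy usage are illustrative, not the scorer's API):

```python
import torch
import torch.nn as nn


def make_intervention_hook(direction: torch.Tensor, strength: float, start: int):
    """Return a forward hook adding strength * direction to positions start..end.

    Mirrors the hook_fn pattern above; `start` plays the role of
    intervention_start_index (prompt_length - 1)."""
    def hook_fn(module, inp, out):
        hidden = out[0] if isinstance(out, tuple) else out
        delta = strength * direction.to(hidden.dtype).to(hidden.device)
        new_hidden = hidden.clone()
        if new_hidden.shape[1] > start:  # ~ current_seq_len >= prompt_length
            new_hidden[:, start:, :] += delta
        # Returning a value from a forward hook replaces the module's output.
        return (new_hidden,) + out[1:] if isinstance(out, tuple) else new_hidden
    return hook_fn


# Toy usage: steer the output of a linear layer on positions 2 onward.
layer = nn.Linear(16, 16)
direction = torch.randn(16)
handle = layer.register_forward_hook(make_intervention_hook(direction, 4.0, start=2))
steered = layer(torch.randn(1, 5, 16))  # [batch, seq, hidden]
handle.remove()
```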
@@ -518,9 +528,9 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> float:
         prompt_template = (
             "<PASSAGE>\n"
             f"{generated_text}\n"
-            "The above passage contains an amplified amount of \""
+            'The above passage contains an amplified amount of "'
         )
-        explanation_suffix = f"{explanation}\""
+        explanation_suffix = f'{explanation}"'

         # Tokenize the parts
         context_enc = self.tokenizer(prompt_template, return_tensors="pt")
@@ -536,7 +546,7 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> float:

         # We only need to score the explanation part
         context_len = context_enc.input_ids.shape[1]
-
+
         # Get logits for positions that predict the explanation tokens
         # Shape: [batch_size, explanation_len, vocab_size]
         explanation_logits = logits[:, context_len - 1 : -1, :]
@@ -548,14 +558,11 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> float:
         log_probs = F.log_softmax(explanation_logits, dim=-1)

         # Gather the log-probabilities of the actual explanation tokens
-        token_log_probs = log_probs.gather(
-            2, target_ids.unsqueeze(-1)
-        ).squeeze(-1)
+        token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)

         # Return the sum of log-probs for the explanation
         return token_log_probs.sum().item()

-
     def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
         """
         Retrieves the correct SAE model, handling the specific functools.partial
@@ -567,31 +574,37 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
             candidate = record.sae
         elif self.explainer_model and isinstance(self.explainer_model, dict):
             full_key = self._get_full_hookpoint_path(hookpoint_str)
-            short_key = ".".join(hookpoint_str.split(".")[-2:]) # e.g., "layers.6.mlp"
+            short_key = ".".join(hookpoint_str.split(".")[-2:])  # e.g., "layers.6.mlp"

             for key in [hookpoint_str, full_key, short_key]:
                 if self.explainer_model.get(key) is not None:
                     candidate = self.explainer_model.get(key)
                     break
-
+
         if candidate is None:
             # This will raise an error if the key isn't found
-            raise ValueError(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}' in self.explainer_model")
+            raise ValueError(
+                f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}' in self.explainer_model"
+            )

         if isinstance(candidate, functools.partial):
             # As shown in load_sparsify.py, the SAE is in the 'sae' keyword.
             if candidate.keywords and "sae" in candidate.keywords:
                 return candidate.keywords["sae"]  # Unwrapped successfully
             else:
                 # This will raise an error if the partial is missing the keyword
-                raise ValueError(f"""ERROR: Found a partial for {hookpoint_str} but could not
+                raise ValueError(
+                    f"""ERROR: Found a partial for {hookpoint_str} but could not
                 find the 'sae' keyword.
                 func: {candidate.func}
                 args: {candidate.args}
-                keywords: {candidate.keywords}""")
-
+                keywords: {candidate.keywords}"""
+                )
+
         # This will raise an error if the candidate isn't a partial
-        raise ValueError(f"ERROR: Candidate for {hookpoint_str} was not a partial object, which was not expected. Type: {type(candidate)}")
+        raise ValueError(
+            f"ERROR: Candidate for {hookpoint_str} was not a partial object, which was not expected. Type: {type(candidate)}"
+        )

     def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor:
         hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None)
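Aside: taken together, the `_score_explanation` hunks implement conditional log-likelihood scoring: append the quoted explanation to the passage prompt, run one forward pass, and sum the log-probabilities of the explanation tokens alone. A condensed sketch of the same computation for a Hugging Face causal LM; the concatenation step and `add_special_tokens=False` are assumptions, since the tokenization details are not fully shown in these hunks.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def score_explanation(model, tokenizer, passage: str, explanation: str) -> float:
    """Sum of log P(explanation tokens | passage prompt), mirroring the hunks above."""
    prompt = f'<PASSAGE>\n{passage}\nThe above passage contains an amplified amount of "'
    suffix = f'{explanation}"'

    context_ids = tokenizer(prompt, return_tensors="pt").input_ids
    target_ids = tokenizer(suffix, return_tensors="pt", add_special_tokens=False).input_ids

    logits = model(torch.cat([context_ids, target_ids], dim=1)).logits
    context_len = context_ids.shape[1]

    # Positions context_len-1 .. -2 are the ones predicting the explanation tokens.
    explanation_logits = logits[:, context_len - 1 : -1, :]
    log_probs = F.log_softmax(explanation_logits, dim=-1)
    token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum().item()


# Example with a placeholder model:
# from transformers import AutoModelForCausalLM, AutoTokenizer
# model = AutoModelForCausalLM.from_pretrained("gpt2")
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# score_explanation(model, tokenizer, "some generated text", "references to cats")
```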
