@@ -168,24 +168,23 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
168168 """
169169 Calculates the feature's decoder vector, subtracting the decoder bias.
170170 """
171-
172-
171+
173172 d_latent = sae .encoder .out_features
174173 sae_device = sae .encoder .weight .device
175174
176175 # Create a one-hot activation for our single feature.
177176 one_hot_activation = torch .zeros (1 , 1 , d_latent , device = sae_device )
178-
177+
179178 if feature_id >= d_latent :
180- print (f"DEBUG: ERROR - Feature ID { feature_id } is out of bounds for d_latent { d_latent } " )
179+ print (
180+ f"DEBUG: ERROR - Feature ID { feature_id } is out of bounds for d_latent { d_latent } "
181+ )
181182 return torch .zeros (1 )
182-
183+
183184 one_hot_activation [0 , 0 , feature_id ] = 1.0
184185
185186 # Create the corresponding indices needed for the decode method.
186- indices = torch .tensor (
187- [[[feature_id ]]], device = sae_device , dtype = torch .long
188- )
187+ indices = torch .tensor ([[[feature_id ]]], device = sae_device , dtype = torch .long )
189188
190189 with torch .no_grad ():
191190 try :
@@ -196,24 +195,25 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
                 return torch.zeros(1)
 
         decoder_vector = vector_before_sub - decoded_zero
-
+
         final_norm = decoder_vector.norm().item()
-
+
         # --- MODIFIED DEBUG BLOCK ---
         # Only print if the feature is "decoder-live"
         if final_norm > 1e-6:
             print(f"\n--- DEBUG: 'Decoder-Live' Feature Found: {feature_id} ---")
             print(f"DEBUG: sae.encoder.out_features (d_latent): {d_latent}")
             print(f"DEBUG: sae.encoder.weight.device (sae_device): {sae_device}")
             print(f"DEBUG: Norm of decoded_zero: {decoded_zero.norm().item()}")
-            print(f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}")
+            print(
+                f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}"
+            )
             print(f"DEBUG: Feature {feature_id}, FINAL Vector Norm: {final_norm}")
             print("--- END DEBUG ---\n")
         # --- END MODIFIED BLOCK ---
 
         return decoder_vector.squeeze()
 
-
     async def __call__(self, record: LatentRecord) -> ScorerResult:
 
         record_copy = copy.deepcopy(record)
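Note on the two hunks above: `_get_intervention_vector` recovers a single feature's decoder direction by decoding a one-hot latent and subtracting the decode of an all-zeros latent, which cancels the decoder bias. A minimal sketch of that identity with a toy SAE (the class and names here are illustrative, not this project's API):

    import torch
    import torch.nn as nn

    class ToySAE(nn.Module):
        def __init__(self, d_model: int, d_latent: int):
            super().__init__()
            self.encoder = nn.Linear(d_model, d_latent)
            self.decoder = nn.Linear(d_latent, d_model)

        def decode(self, acts: torch.Tensor) -> torch.Tensor:
            return self.decoder(acts)

    sae = ToySAE(d_model=16, d_latent=64)
    feature_id = 3
    one_hot = torch.zeros(1, 1, 64)
    one_hot[0, 0, feature_id] = 1.0

    # decode(one_hot) = W_dec[:, feature_id] + b_dec and decode(0) = b_dec,
    # so the difference is exactly the feature's decoder column.
    direction = sae.decode(one_hot) - sae.decode(torch.zeros(1, 1, 64))
    assert torch.allclose(direction.squeeze(), sae.decoder.weight[:, feature_id], atol=1e-6)

Going through decode() rather than indexing the weights directly, which is presumably the point of the one-hot trick, keeps the extraction valid for SAE variants whose decode is more than a plain linear map.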
@@ -240,7 +240,7 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
         sae = self._get_sae_for_hookpoint(hookpoint_str, record_copy)
         if not sae:
             raise ValueError(f"Could not find SAE for hookpoint {hookpoint_str}")
-
+
         intervention_vector = self._get_intervention_vector(sae, record_copy.feature_id)
 
         tuned_strength, initial_kl = await self._tune_strength(
@@ -253,10 +253,18 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
 
         for prompt in truncated_prompts:
             clean_text, clean_logp_dist = await self._generate_with_intervention(
-                prompt, record_copy, strength=0.0, intervention_vector=intervention_vector, get_logp_dist=True
+                prompt,
+                record_copy,
+                strength=0.0,
+                intervention_vector=intervention_vector,
+                get_logp_dist=True,
             )
             int_text, int_logp_dist = await self._generate_with_intervention(
-                prompt, record_copy, strength=tuned_strength, intervention_vector=intervention_vector, get_logp_dist=True
+                prompt,
+                record_copy,
+                strength=tuned_strength,
+                intervention_vector=intervention_vector,
+                get_logp_dist=True,
             )
 
             logp_clean = await self._score_explanation(
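For orientation: the loop above produces, for each truncated prompt, a clean continuation (strength 0.0, so the hook runs but adds nothing) and an amplified continuation at the tuned strength, then scores the explanation against each. One plausible way the two scores combine downstream, sketched with hypothetical stand-in callables rather than the exact formula used here:

    # `generate` and `score` are stand-ins for the two awaited helpers above.
    async def surprisal_gap(prompt, explanation, tuned_strength, generate, score):
        clean_text, _ = await generate(prompt, strength=0.0)
        int_text, _ = await generate(prompt, strength=tuned_strength)
        logp_clean = await score(clean_text, explanation)
        logp_int = await score(int_text, explanation)
        # A positive gap means the explanation fits the amplified text better.
        return logp_int - logp_clean

Running both generations through the same hooked code path, with only the strength changed, keeps the baseline comparable: any numerical effect of the hook itself is present in both.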
@@ -300,7 +308,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
         )
         return ScorerResult(record=record_copy, score=final_output_list)
 
-
     async def _get_latent_activations(
         self, prompt: str, record: LatentRecord
     ) -> torch.Tensor:
@@ -339,7 +346,6 @@ def capture_hook(module, inp, out):
 
         return feature_acts[0, :, record.feature_id].cpu()
 
-
     async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str:
         """
         Truncates prompt to end just before the first token where latent activates.
@@ -356,17 +362,18 @@ async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str:
         first_activation_idx = all_activation_indices[all_activation_indices > 0]
 
         if first_activation_idx.numel() > 0:
-            truncation_point = first_activation_idx[0].item()
+            truncation_point = first_activation_idx[0].item()
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0]
-            truncated_ids = input_ids[:truncation_point + 1]
+            truncated_ids = input_ids[: truncation_point + 1]
             return self.tokenizer.decode(truncated_ids, skip_special_tokens=True)
 
         return prompt
 
-
     async def _tune_strength(
-        self, prompts: List[str], record: LatentRecord,
-        intervention_vector: torch.Tensor
+        self,
+        prompts: List[str],
+        record: LatentRecord,
+        intervention_vector: torch.Tensor,
     ) -> Tuple[float, float]:
         """
         Performs a binary search to find intervention strength that matches target_kl.
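The docstring above is the contract for the tuning loop that continues in the next hunk. A self-contained sketch of that style of search, assuming the KL grows monotonically with strength; `avg_kl` is a stand-in for the method's model queries, and the bounds and iteration count here are invented for illustration:

    from typing import Callable, Tuple

    def tune_strength(
        avg_kl: Callable[[float], float],
        target_kl: float,
        lo: float = 0.0,
        hi: float = 100.0,
        iters: int = 10,
    ) -> Tuple[float, float]:
        """Binary-search a strength whose average KL best matches target_kl."""
        best = lo
        for _ in range(iters):
            mid = (lo + hi) / 2
            if avg_kl(mid) < target_kl:
                lo = mid  # intervention too weak, push strength up
            else:
                hi = mid  # too strong, pull it back down
            best = mid
        return best, avg_kl(best)

    strength, kl = tune_strength(lambda s: 0.05 * s, target_kl=1.0)  # converges near 20.0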
@@ -408,22 +415,26 @@ async def _tune_strength(
                 best_strength = mid_strength
 
         # Return the best found strength and the corresponding KL
-        final_kl = await self._calculate_avg_kl(prompts, record, best_strength, intervention_vector)
+        final_kl = await self._calculate_avg_kl(
+            prompts, record, best_strength, intervention_vector
+        )
         return best_strength, final_kl
 
-
     async def _calculate_avg_kl(
-        self, prompts: List[str], record: LatentRecord, strength: float,
-        intervention_vector: torch.Tensor
+        self,
+        prompts: List[str],
+        record: LatentRecord,
+        strength: float,
+        intervention_vector: torch.Tensor,
     ) -> float:
         total_kl = 0.0
         n = 0
         for prompt in prompts:
             _, clean_logp = await self._generate_with_intervention(
-                prompt, record, 0.0, intervention_vector,True
+                prompt, record, 0.0, intervention_vector, True
             )
             _, int_logp = await self._generate_with_intervention(
-                prompt, record, strength, intervention_vector,True
+                prompt, record, strength, intervention_vector, True
             )
             p_clean = torch.exp(clean_logp)
             kl_div = F.kl_div(
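The hunk cuts off inside the `F.kl_div` call, but the line `p_clean = torch.exp(clean_logp)` pins down the convention: `F.kl_div` expects its first argument in log-space and its target as probabilities. A sketch of that standard usage on toy distributions (the argument order follows PyTorch's API; the reduction actually used by the method is not visible in this hunk):

    import torch
    import torch.nn.functional as F

    clean_logp = F.log_softmax(torch.randn(10), dim=-1)  # clean next-token log-probs
    int_logp = F.log_softmax(torch.randn(10), dim=-1)    # intervened log-probs

    p_clean = torch.exp(clean_logp)
    kl = F.kl_div(int_logp, p_clean, reduction="sum")    # KL(p_clean || p_int)

    # The same quantity by hand: sum_i p_clean_i * (log p_clean_i - log p_int_i)
    manual = (p_clean * (clean_logp - int_logp)).sum()
    assert torch.allclose(kl, manual, atol=1e-5)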
@@ -433,7 +444,6 @@ async def _calculate_avg_kl(
             n += 1
         return total_kl / n if n > 0 else 0.0
 
-
     async def _generate_with_intervention(
         self,
         prompt: str,
@@ -473,8 +483,9 @@ def hook_fn(module, inp, out):
             intervention_start_index = prompt_length - 1
 
             if current_seq_len >= prompt_length:
-                new_hiddens[:, intervention_start_index:, :] += delta.to(original_dtype)
-
+                new_hiddens[:, intervention_start_index:, :] += delta.to(
+                    original_dtype
+                )
 
             return (
                 (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens
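The hook above is the intervention mechanism: from the last prompt token onward it adds the scaled feature direction to the block's hidden states, cast back to the original dtype, and preserves tuple outputs. A stripped-down standalone version of the same pattern, where an `nn.Linear` stands in for a transformer block:

    import torch
    import torch.nn as nn

    layer = nn.Linear(8, 8)                    # stand-in for a transformer block
    direction = torch.randn(8)
    direction = direction / direction.norm()   # unit-norm intervention vector
    strength, start = 5.0, 2                   # steer from token index 2 onward

    def hook_fn(module, inp, out):
        hiddens = out[0] if isinstance(out, tuple) else out
        new_hiddens = hiddens.clone()
        new_hiddens[:, start:, :] += strength * direction.to(hiddens.dtype)
        return (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens

    handle = layer.register_forward_hook(hook_fn)
    try:
        steered = layer(torch.randn(1, 4, 8))  # [batch, seq, hidden]
    finally:
        handle.remove()                        # always detach the hook

Cloning before mutating avoids editing a tensor other modules may still hold, and removing the handle in a finally block mirrors the try block the method itself appears to wrap around the forward pass.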
@@ -484,7 +495,7 @@ def hook_fn(module, inp, out):
484495
485496 try :
486497 with torch .no_grad ():
487- outputs = self .subject_model (input_ids , attention_mask = attention_mask )
498+ outputs = self .subject_model (input_ids , attention_mask = attention_mask )
488499 next_token_logits = outputs .logits [0 , - 1 , :]
489500 log_probs_next_token = (
490501 F .log_softmax (next_token_logits , dim = - 1 ) if get_logp_dist else None
@@ -506,10 +517,9 @@ def hook_fn(module, inp, out):
             log_probs_next_token.cpu() if get_logp_dist else torch.empty(0)
         )
 
-
     async def _score_explanation(self, generated_text: str, explanation: str) -> float:
         """
-        Computes log P(explanation | generated_text) using the paper's
+        Computes log P(explanation | generated_text) using the paper's
         prompt format.
         """
         device = self._get_device()
@@ -518,9 +528,9 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo
         prompt_template = (
             "<PASSAGE>\n"
             f"{generated_text}\n"
-            " The above passage contains an amplified amount of \""
+            ' The above passage contains an amplified amount of "'
         )
-        explanation_suffix = f" {explanation}\""
+        explanation_suffix = f' {explanation}"'
 
         # Tokenize the parts
         context_enc = self.tokenizer(prompt_template, return_tensors="pt")
@@ -536,7 +546,7 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo
 
         # We only need to score the explanation part
         context_len = context_enc.input_ids.shape[1]
-
+
         # Get logits for positions that predict the explanation tokens
         # Shape: [batch_size, explanation_len, vocab_size]
         explanation_logits = logits[:, context_len - 1 : -1, :]
@@ -548,14 +558,11 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo
         log_probs = F.log_softmax(explanation_logits, dim=-1)
 
         # Gather the log-probabilities of the actual explanation tokens
-        token_log_probs = log_probs.gather(
-            2, target_ids.unsqueeze(-1)
-        ).squeeze(-1)
+        token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
 
         # Return the sum of log-probs for the explanation
         return token_log_probs.sum().item()
 
-
     def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
         """
         Retrieves the correct SAE model, handling the specific functools.partial
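The gather in the hunk above is the core of the scoring: because position i predicts token i+1, the logits that score an explanation of length L start at context_len - 1 and stop one before the end. A shape-checked sketch with toy tensors:

    import torch
    import torch.nn.functional as F

    context_len, expl_len, vocab = 5, 3, 50
    logits = torch.randn(1, context_len + expl_len, vocab)  # model output
    target_ids = torch.randint(vocab, (1, expl_len))        # explanation token ids

    explanation_logits = logits[:, context_len - 1 : -1, :]  # [1, 3, 50]
    log_probs = F.log_softmax(explanation_logits, dim=-1)
    token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
    score = token_log_probs.sum().item()  # log P(explanation | passage prompt)

Summing rather than averaging the token log-probs makes the score a joint log-likelihood, which is what the sum().item() in the hunk returns.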
@@ -567,31 +574,37 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
             candidate = record.sae
         elif self.explainer_model and isinstance(self.explainer_model, dict):
             full_key = self._get_full_hookpoint_path(hookpoint_str)
-            short_key = ".".join(hookpoint_str.split(".")[-2:]) # e.g., "layers.6.mlp"
+            short_key = ".".join(hookpoint_str.split(".")[-2:])  # e.g., "layers.6.mlp"
 
             for key in [hookpoint_str, full_key, short_key]:
                 if self.explainer_model.get(key) is not None:
                     candidate = self.explainer_model.get(key)
                     break
-
+
         if candidate is None:
             # This will raise an error if the key isn't found
-            raise ValueError(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}' in self.explainer_model")
+            raise ValueError(
+                f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}' in self.explainer_model"
+            )
 
         if isinstance(candidate, functools.partial):
             # As shown in load_sparsify.py, the SAE is in the 'sae' keyword.
             if candidate.keywords and "sae" in candidate.keywords:
                 return candidate.keywords["sae"]  # Unwrapped successfully
             else:
                 # This will raise an error if the partial is missing the keyword
-                raise ValueError(f"""ERROR: Found a partial for {hookpoint_str} but could not
+                raise ValueError(
+                    f"""ERROR: Found a partial for {hookpoint_str} but could not
                 find the 'sae' keyword.
                 func: {candidate.func}
                 args: {candidate.args}
-                keywords: {candidate.keywords}""")
-
+                keywords: {candidate.keywords}"""
+                )
+
         # This will raise an error if the candidate isn't a partial
-        raise ValueError(f"ERROR: Candidate for {hookpoint_str} was not a partial object, which was not expected. Type: {type(candidate)}")
+        raise ValueError(
+            f"ERROR: Candidate for {hookpoint_str} was not a partial object, which was not expected. Type: {type(candidate)}"
+        )
 
     def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor:
         hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None)
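Last, the partial unwrapping above depends only on standard library behavior: functools.partial exposes its bound arguments via .keywords and .args, so an SAE bound as a keyword (as the comment says load_sparsify.py does) can be recovered without calling the partial. A minimal illustration with a placeholder object:

    import functools

    def hook(module, inp, out, sae=None):
        ...

    bound = functools.partial(hook, sae="placeholder-sae")

    assert bound.keywords["sae"] == "placeholder-sae"  # how the wrapped SAE is recovered
    assert bound.args == ()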