@@ -875,10 +875,12 @@ def _dose_smoothness_loss(self, batch: Dict[str, torch.Tensor], pred: torch.Tens
875875 return pred .new_tensor (0.0 )
876876 return torch .stack (losses ).mean ()
877877
878- def _compute_distribution_loss (self , pred : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
878+ def _compute_distribution_loss (
879+ self , pred : torch .Tensor , target : torch .Tensor , * , allow_discrete : bool = False
880+ ) -> torch .Tensor :
879881 """Apply the primary distributional loss, optionally chunking feature dimensions for SamplesLoss."""
880882
881- if self .discretize :
883+ if self .discretize and not allow_discrete :
882884 raise RuntimeError ("Distributional loss is not used when discretize is enabled." )
883885
884886 if isinstance (self .loss_fn , SamplesLoss ) and self .mmd_num_chunks > 1 :
@@ -905,6 +907,18 @@ def _reshape_logits(self, logits: torch.Tensor, padded: bool) -> torch.Tensor:
905907 reshaped = self ._reshape_sequence (logits , padded , self ._project_out_dim )
906908 return reshaped .view (reshaped .size (0 ), reshaped .size (1 ), self ._prediction_dim , self .num_expression_buckets )
907909
910+ def _expected_expression_from_logits (self , logits : torch .Tensor ) -> torch .Tensor :
911+ """Compute expected expression values by weighing bucket means with predicted probabilities."""
912+
913+ if self .bucket_boundaries is None :
914+ raise RuntimeError ("Bucket boundaries must be initialized to compute expectations from logits." )
915+
916+ means = self ._bucket_means ().to (logits .device )
917+ probs = F .softmax (logits , dim = - 1 )
918+ expanded_means = means .unsqueeze (0 ).unsqueeze (0 )
919+ expectation = (probs * expanded_means ).sum (dim = - 1 )
920+ return expectation
921+
908922 def _get_target_expression (self , batch : Dict [str , torch .Tensor ]) -> torch .Tensor :
909923 if "pert_cell_counts" in batch and batch ["pert_cell_counts" ] is not None :
910924 return batch ["pert_cell_counts" ]
@@ -1008,11 +1022,31 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
10081022 with torch .no_grad ():
10091023 bucket_targets = self ._discretize_expression (target_expression )
10101024 self ._update_bucket_statistics (target_expression , bucket_targets )
1011- per_token_losses = self ._emd_per_token (logits , bucket_targets )
1012- per_set_main_losses = per_token_losses .view (per_token_losses .size (0 ), - 1 ).mean (dim = 1 )
1013- main_loss = per_set_main_losses .mean ()
1014- self .log ("train_loss" , main_loss )
1015- total_loss = main_loss
1025+
1026+ per_token_emd_losses = self ._emd_per_token (logits , bucket_targets )
1027+ emd_per_set = per_token_emd_losses .view (per_token_emd_losses .size (0 ), - 1 ).mean (dim = 1 )
1028+ emd_loss = emd_per_set .mean ()
1029+
1030+ expected_expression = self ._expected_expression_from_logits (logits )
1031+ per_set_main_losses = self ._compute_distribution_loss (
1032+ expected_expression ,
1033+ target_expression ,
1034+ allow_discrete = True ,
1035+ )
1036+ main_loss = torch .nanmean (per_set_main_losses )
1037+
1038+ if hasattr (self .loss_fn , "sinkhorn_loss" ) and hasattr (self .loss_fn , "energy_loss" ):
1039+ sinkhorn_component = self .loss_fn .sinkhorn_loss (expected_expression , target_expression ).nanmean ()
1040+ energy_component = self .loss_fn .energy_loss (expected_expression , target_expression ).nanmean ()
1041+ self .log ("train/sinkhorn_loss" , sinkhorn_component )
1042+ self .log ("train/energy_loss" , energy_component )
1043+
1044+ self .log ("decoder_loss" , emd_loss )
1045+ self .log ("train/emd_loss" , emd_loss )
1046+ self .log ("train/mmd_loss" , main_loss )
1047+
1048+ total_loss = main_loss + emd_loss
1049+ self .log ("train_loss" , total_loss )
10161050 else :
10171051 target = batch ["pert_cell_emb" ]
10181052 pred = self ._reshape_sequence (pred , padded , self .output_dim )
@@ -1159,10 +1193,28 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
11591193 target_expression = self ._reshape_sequence (target_expression , padded = True , feature_dim = self ._prediction_dim )
11601194 with torch .no_grad ():
11611195 bucket_targets = self ._discretize_expression (target_expression )
1162- per_token_losses = self ._emd_per_token (logits , bucket_targets )
1163- per_set_main_losses = per_token_losses .view (per_token_losses .size (0 ), - 1 ).mean (dim = 1 )
1164- loss = per_set_main_losses .mean ()
1165- self .log ("val_loss" , loss )
1196+ per_token_emd_losses = self ._emd_per_token (logits , bucket_targets )
1197+ emd_per_set = per_token_emd_losses .view (per_token_emd_losses .size (0 ), - 1 ).mean (dim = 1 )
1198+ emd_loss = emd_per_set .mean ()
1199+
1200+ expected_expression = self ._expected_expression_from_logits (logits )
1201+ per_set_main_losses = self ._compute_distribution_loss (
1202+ expected_expression ,
1203+ target_expression ,
1204+ allow_discrete = True ,
1205+ )
1206+ main_loss = torch .nanmean (per_set_main_losses )
1207+
1208+ if hasattr (self .loss_fn , "sinkhorn_loss" ) and hasattr (self .loss_fn , "energy_loss" ):
1209+ sinkhorn_component = self .loss_fn .sinkhorn_loss (expected_expression , target_expression ).mean ()
1210+ energy_component = self .loss_fn .energy_loss (expected_expression , target_expression ).mean ()
1211+ self .log ("val/sinkhorn_loss" , sinkhorn_component )
1212+ self .log ("val/energy_loss" , energy_component )
1213+
1214+ total_loss = main_loss + emd_loss
1215+ self .log ("val_loss" , total_loss )
1216+ self .log ("val/decoder_loss" , emd_loss )
1217+ self .log ("val/mmd_loss" , main_loss )
11661218 else :
11671219 pred = self ._reshape_sequence (pred , padded = True , feature_dim = self .output_dim )
11681220 target = batch ["pert_cell_emb" ]
@@ -1232,10 +1284,22 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
12321284 target_expression = self ._reshape_sequence (target_expression , padded = False , feature_dim = self ._prediction_dim )
12331285 with torch .no_grad ():
12341286 bucket_targets = self ._discretize_expression (target_expression )
1235- per_token_losses = self ._emd_per_token (logits , bucket_targets )
1236- per_set_main_losses = per_token_losses .view (per_token_losses .size (0 ), - 1 ).mean (dim = 1 )
1237- loss = per_set_main_losses .mean ()
1238- self .log ("test_loss" , loss )
1287+ per_token_emd_losses = self ._emd_per_token (logits , bucket_targets )
1288+ emd_per_set = per_token_emd_losses .view (per_token_emd_losses .size (0 ), - 1 ).mean (dim = 1 )
1289+ emd_loss = emd_per_set .mean ()
1290+
1291+ expected_expression = self ._expected_expression_from_logits (logits )
1292+ per_set_main_losses = self ._compute_distribution_loss (
1293+ expected_expression ,
1294+ target_expression ,
1295+ allow_discrete = True ,
1296+ )
1297+ main_loss = torch .nanmean (per_set_main_losses )
1298+
1299+ total_loss = main_loss + emd_loss
1300+ self .log ("test_loss" , total_loss )
1301+ self .log ("test/decoder_loss" , emd_loss )
1302+ self .log ("test/mmd_loss" , main_loss )
12391303 else :
12401304 target = batch ["pert_cell_emb" ]
12411305 pred = self ._reshape_sequence (pred , padded = False , feature_dim = self .output_dim )
0 commit comments