Skip to content

Commit ad1b8ba

Browse files
committed
another discretization fix
1 parent 57bf2d5 commit ad1b8ba

File tree

1 file changed

+79
-15
lines changed

1 file changed

+79
-15
lines changed

src/state/tx/models/state_transition.py

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -875,10 +875,12 @@ def _dose_smoothness_loss(self, batch: Dict[str, torch.Tensor], pred: torch.Tens
875875
return pred.new_tensor(0.0)
876876
return torch.stack(losses).mean()
877877

878-
def _compute_distribution_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
878+
def _compute_distribution_loss(
879+
self, pred: torch.Tensor, target: torch.Tensor, *, allow_discrete: bool = False
880+
) -> torch.Tensor:
879881
"""Apply the primary distributional loss, optionally chunking feature dimensions for SamplesLoss."""
880882

881-
if self.discretize:
883+
if self.discretize and not allow_discrete:
882884
raise RuntimeError("Distributional loss is not used when discretize is enabled.")
883885

884886
if isinstance(self.loss_fn, SamplesLoss) and self.mmd_num_chunks > 1:
@@ -905,6 +907,18 @@ def _reshape_logits(self, logits: torch.Tensor, padded: bool) -> torch.Tensor:
905907
reshaped = self._reshape_sequence(logits, padded, self._project_out_dim)
906908
return reshaped.view(reshaped.size(0), reshaped.size(1), self._prediction_dim, self.num_expression_buckets)
907909

def _expected_expression_from_logits(self, logits: torch.Tensor) -> torch.Tensor:
    """Convert per-bucket logits into expected (continuous) expression values.

    Softmaxes the logits over the bucket dimension and takes the
    probability-weighted average of the bucket mean values, yielding a
    differentiable continuous estimate of expression per gene.

    Args:
        logits: Tensor whose last dimension indexes the expression buckets
            (presumably (batch, set, genes, num_buckets) given the callers'
            reshaping — confirm against ``_reshape_logits``).

    Returns:
        Tensor with the bucket dimension reduced away; same leading shape
        as ``logits``.

    Raises:
        RuntimeError: If bucket boundaries have not been initialized, since
            bucket means cannot be computed without them.
    """
    if self.bucket_boundaries is None:
        raise RuntimeError("Bucket boundaries must be initialized to compute expectations from logits.")

    # Keep the bucket means on the same device as the logits so the
    # elementwise product below does not trigger a cross-device error.
    means = self._bucket_means().to(logits.device)
    probs = F.softmax(logits, dim=-1)
    # `means` is 1-D over buckets; PyTorch broadcasting aligns it with the
    # trailing (bucket) dimension of `probs` for any logits rank, so the
    # original's explicit unsqueeze(0).unsqueeze(0) was redundant and only
    # obscured the intent (and silently assumed a fixed 4-D layout).
    return (probs * means).sum(dim=-1)
908922
def _get_target_expression(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
909923
if "pert_cell_counts" in batch and batch["pert_cell_counts"] is not None:
910924
return batch["pert_cell_counts"]
@@ -1008,11 +1022,31 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
10081022
with torch.no_grad():
10091023
bucket_targets = self._discretize_expression(target_expression)
10101024
self._update_bucket_statistics(target_expression, bucket_targets)
1011-
per_token_losses = self._emd_per_token(logits, bucket_targets)
1012-
per_set_main_losses = per_token_losses.view(per_token_losses.size(0), -1).mean(dim=1)
1013-
main_loss = per_set_main_losses.mean()
1014-
self.log("train_loss", main_loss)
1015-
total_loss = main_loss
1025+
1026+
per_token_emd_losses = self._emd_per_token(logits, bucket_targets)
1027+
emd_per_set = per_token_emd_losses.view(per_token_emd_losses.size(0), -1).mean(dim=1)
1028+
emd_loss = emd_per_set.mean()
1029+
1030+
expected_expression = self._expected_expression_from_logits(logits)
1031+
per_set_main_losses = self._compute_distribution_loss(
1032+
expected_expression,
1033+
target_expression,
1034+
allow_discrete=True,
1035+
)
1036+
main_loss = torch.nanmean(per_set_main_losses)
1037+
1038+
if hasattr(self.loss_fn, "sinkhorn_loss") and hasattr(self.loss_fn, "energy_loss"):
1039+
sinkhorn_component = self.loss_fn.sinkhorn_loss(expected_expression, target_expression).nanmean()
1040+
energy_component = self.loss_fn.energy_loss(expected_expression, target_expression).nanmean()
1041+
self.log("train/sinkhorn_loss", sinkhorn_component)
1042+
self.log("train/energy_loss", energy_component)
1043+
1044+
self.log("decoder_loss", emd_loss)
1045+
self.log("train/emd_loss", emd_loss)
1046+
self.log("train/mmd_loss", main_loss)
1047+
1048+
total_loss = main_loss + emd_loss
1049+
self.log("train_loss", total_loss)
10161050
else:
10171051
target = batch["pert_cell_emb"]
10181052
pred = self._reshape_sequence(pred, padded, self.output_dim)
@@ -1159,10 +1193,28 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
11591193
target_expression = self._reshape_sequence(target_expression, padded=True, feature_dim=self._prediction_dim)
11601194
with torch.no_grad():
11611195
bucket_targets = self._discretize_expression(target_expression)
1162-
per_token_losses = self._emd_per_token(logits, bucket_targets)
1163-
per_set_main_losses = per_token_losses.view(per_token_losses.size(0), -1).mean(dim=1)
1164-
loss = per_set_main_losses.mean()
1165-
self.log("val_loss", loss)
1196+
per_token_emd_losses = self._emd_per_token(logits, bucket_targets)
1197+
emd_per_set = per_token_emd_losses.view(per_token_emd_losses.size(0), -1).mean(dim=1)
1198+
emd_loss = emd_per_set.mean()
1199+
1200+
expected_expression = self._expected_expression_from_logits(logits)
1201+
per_set_main_losses = self._compute_distribution_loss(
1202+
expected_expression,
1203+
target_expression,
1204+
allow_discrete=True,
1205+
)
1206+
main_loss = torch.nanmean(per_set_main_losses)
1207+
1208+
if hasattr(self.loss_fn, "sinkhorn_loss") and hasattr(self.loss_fn, "energy_loss"):
1209+
sinkhorn_component = self.loss_fn.sinkhorn_loss(expected_expression, target_expression).mean()
1210+
energy_component = self.loss_fn.energy_loss(expected_expression, target_expression).mean()
1211+
self.log("val/sinkhorn_loss", sinkhorn_component)
1212+
self.log("val/energy_loss", energy_component)
1213+
1214+
total_loss = main_loss + emd_loss
1215+
self.log("val_loss", total_loss)
1216+
self.log("val/decoder_loss", emd_loss)
1217+
self.log("val/mmd_loss", main_loss)
11661218
else:
11671219
pred = self._reshape_sequence(pred, padded=True, feature_dim=self.output_dim)
11681220
target = batch["pert_cell_emb"]
@@ -1232,10 +1284,22 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
12321284
target_expression = self._reshape_sequence(target_expression, padded=False, feature_dim=self._prediction_dim)
12331285
with torch.no_grad():
12341286
bucket_targets = self._discretize_expression(target_expression)
1235-
per_token_losses = self._emd_per_token(logits, bucket_targets)
1236-
per_set_main_losses = per_token_losses.view(per_token_losses.size(0), -1).mean(dim=1)
1237-
loss = per_set_main_losses.mean()
1238-
self.log("test_loss", loss)
1287+
per_token_emd_losses = self._emd_per_token(logits, bucket_targets)
1288+
emd_per_set = per_token_emd_losses.view(per_token_emd_losses.size(0), -1).mean(dim=1)
1289+
emd_loss = emd_per_set.mean()
1290+
1291+
expected_expression = self._expected_expression_from_logits(logits)
1292+
per_set_main_losses = self._compute_distribution_loss(
1293+
expected_expression,
1294+
target_expression,
1295+
allow_discrete=True,
1296+
)
1297+
main_loss = torch.nanmean(per_set_main_losses)
1298+
1299+
total_loss = main_loss + emd_loss
1300+
self.log("test_loss", total_loss)
1301+
self.log("test/decoder_loss", emd_loss)
1302+
self.log("test/mmd_loss", main_loss)
12391303
else:
12401304
target = batch["pert_cell_emb"]
12411305
pred = self._reshape_sequence(pred, padded=False, feature_dim=self.output_dim)

0 commit comments

Comments
 (0)