Skip to content

Commit 6848678

Browse files
committed
Merge pull request #104 from Qianruipku:hotfix
PiperOrigin-RevId: 808807612 Change-Id: Iab503acc70b075e87e2a1f3fd4ff452547db265b
2 parents ce6b3c8 + 48931c6 commit 6848678

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ferminet/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def step(
378378
damping=shared_damping,
379379
)
380380

381-
if reset_if_nan and jnp.isnan(stats['loss']):
381+
if reset_if_nan and jnp.any(jnp.isnan(stats['loss'])):
382382
new_params = old_params
383383
new_state = old_state
384384
return data, new_params, new_state, stats['loss'], stats['aux'], pmove

0 commit comments

Comments
 (0)