@@ -51,17 +51,21 @@ def _compute_rspace(
5151 charges : torch .Tensor ,
5252 neighbor_indices : torch .Tensor ,
5353 neighbor_distances : torch .Tensor ,
54+ node_mask : torch .Tensor | None = None ,
55+ pair_mask : torch .Tensor | None = None ,
5456 ) -> torch .Tensor :
5557 # Compute the pair potential terms V(r_ij) for each pair of atoms (i,j)
5658 # contained in the neighbor list
5759 with profiler .record_function ("compute bare potential" ):
5860 if self .potential .smearing is None :
5961 if self .potential .exclusion_radius is None :
60- potentials_bare = self .potential .from_dist (neighbor_distances )
61- else :
62- potentials_bare = self .potential .from_dist (neighbor_distances ) * (
63- 1 - self .potential .f_cutoff (neighbor_distances )
62+ potentials_bare = self .potential .from_dist (
63+ neighbor_distances , pair_mask = pair_mask
6464 )
65+ else :
66+ potentials_bare = self .potential .from_dist (
67+ neighbor_distances , pair_mask = pair_mask
68+ ) * (1 - self .potential .f_cutoff (neighbor_distances ))
6569 else :
6670 potentials_bare = self .potential .sr_from_dist (neighbor_distances )
6771
@@ -109,6 +113,8 @@ def forward(
109113 neighbor_indices : torch .Tensor ,
110114 neighbor_distances : torch .Tensor ,
111115 periodic : Optional [torch .Tensor ] = None ,
116+ node_mask : torch .Tensor | None = None ,
117+ pair_mask : torch .Tensor | None = None ,
112118 ):
113119 r"""
114120 Compute the potential "energy".
@@ -161,6 +167,8 @@ def forward(
161167 charges = charges ,
162168 neighbor_indices = neighbor_indices ,
163169 neighbor_distances = neighbor_distances ,
170+ node_mask = node_mask ,
171+ pair_mask = pair_mask ,
164172 )
165173
166174 if self .potential .smearing is None :
0 commit comments