|
| 1 | +from typing import Optional |
| 2 | + |
1 | 3 | import torch |
2 | | -from torch import profiler |
3 | 4 |
|
4 | 5 | from ..lib.kspace_filter import KSpaceFilter |
5 | 6 | from ..lib.kvectors import get_ns_mesh |
@@ -97,51 +98,48 @@ def _compute_kspace( |
97 | 98 | charges: torch.Tensor, |
98 | 99 | cell: torch.Tensor, |
99 | 100 | positions: torch.Tensor, |
| 101 | + periodic: Optional[torch.Tensor] = None, |
100 | 102 | ) -> torch.Tensor: |
101 | 103 | # TODO: Kernel function `G` and initialization of `MeshInterpolator` only depend |
102 | 104 | # on `cell`. Caching may save up to 15% but issues with AD need to be resolved. |
103 | 105 |
|
104 | | - with profiler.record_function("init 0: preparation"): |
105 | | - # Compute number of times each basis vector of the reciprocal space can be |
106 | | - # scaled until the cutoff is reached |
107 | | - ns = get_ns_mesh(cell, self.mesh_spacing) |
| 106 | + # Compute number of times each basis vector of the reciprocal space can be |
| 107 | + # scaled until the cutoff is reached |
| 108 | + ns = get_ns_mesh(cell, self.mesh_spacing) |
108 | 109 |
|
109 | | - with profiler.record_function("init 1: update mesh interpolator"): |
110 | | - self.mesh_interpolator.update(cell, ns) |
| 110 | + self.mesh_interpolator.update(cell, ns) |
111 | 111 |
|
112 | | - with profiler.record_function("update the mesh for the k-space filter"): |
113 | | - self.kspace_filter.update(cell, ns) |
| 112 | + self.kspace_filter.update(cell, ns) |
114 | 113 |
|
115 | | - with profiler.record_function("step 1: compute density interpolation"): |
116 | | - self.mesh_interpolator.compute_weights(positions) |
117 | | - rho_mesh = self.mesh_interpolator.points_to_mesh(particle_weights=charges) |
| 114 | + self.mesh_interpolator.compute_weights(positions) |
| 115 | + rho_mesh = self.mesh_interpolator.points_to_mesh(particle_weights=charges) |
118 | 116 |
|
119 | | - with profiler.record_function("step 2: perform actual convolution using FFT"): |
120 | | - potential_mesh = self.kspace_filter.forward(rho_mesh) |
| 117 | + potential_mesh = self.kspace_filter.forward(rho_mesh) |
121 | 118 |
|
122 | | - with profiler.record_function("step 3: back interpolation + volume scaling"): |
123 | | - ivolume = torch.abs(cell.det()).pow(-1) |
124 | | - interpolated_potential = ( |
125 | | - self.mesh_interpolator.mesh_to_points(potential_mesh) * ivolume |
126 | | - ) |
| 119 | + ivolume = torch.abs(cell.det()).pow(-1) |
| 120 | + interpolated_potential = ( |
| 121 | + self.mesh_interpolator.mesh_to_points(potential_mesh) * ivolume |
| 122 | + ) |
127 | 123 |
|
128 | | - with profiler.record_function("step 4: remove the self-contribution"): |
129 | | - # Using the Coulomb potential as an example, this is the potential generated |
130 | | - # at the origin by the fictituous Gaussian charge density in order to split |
131 | | - # the potential into a SR and LR part. This contribution always should be |
132 | | - # subtracted since it depends on the smearing parameter, which is purely a |
133 | | - # convergence parameter. |
134 | | - interpolated_potential -= charges * self.potential.self_contribution() |
135 | | - |
136 | | - with profiler.record_function("step 5: charge neutralization"): |
137 | | - # If the cell has a net charge (i.e. if sum(charges) != 0), the method |
138 | | - # implicitly assumes that a homogeneous background charge of the opposite |
139 | | - # sign is present to make the cell neutral. In this case, the potential has |
140 | | - # to be adjusted to compensate for this. An extra factor of 2 is added to |
141 | | - # compensate for the division by 2 later on |
142 | | - charge_tot = torch.sum(charges, dim=0) |
143 | | - prefac = self.potential.background_correction() |
144 | | - interpolated_potential -= 2 * prefac * charge_tot * ivolume |
| 124 | + # Using the Coulomb potential as an example, this is the potential generated |
| 125 | + # at the origin by the fictituous Gaussian charge density in order to split |
| 126 | + # the potential into a SR and LR part. This contribution always should be |
| 127 | + # subtracted since it depends on the smearing parameter, which is purely a |
| 128 | + # convergence parameter. |
| 129 | + interpolated_potential -= charges * self.potential.self_contribution() |
| 130 | + |
| 131 | + # If the cell has a net charge (i.e. if sum(charges) != 0), the method |
| 132 | + # implicitly assumes that a homogeneous background charge of the opposite |
| 133 | + # sign is present to make the cell neutral. In this case, the potential has |
| 134 | + # to be adjusted to compensate for this. An extra factor of 2 is added to |
| 135 | + # compensate for the division by 2 later on |
| 136 | + charge_tot = torch.sum(charges, dim=0) |
| 137 | + prefac = self.potential.background_correction() |
| 138 | + interpolated_potential -= 2 * prefac * charge_tot * ivolume |
| 139 | + |
| 140 | + interpolated_potential += self.potential.pbc_correction( |
| 141 | + periodic, positions, cell, charges |
| 142 | + ) |
145 | 143 |
|
146 | 144 | # Compensate for double counting of pairs (i,j) and (j,i) |
147 | 145 | return interpolated_potential / 2 |
0 commit comments