File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed
pyrecest/distributions/nonperiodic Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -48,6 +48,7 @@ def pdf(self, xs):
4848 elif pyrecest .backend .__name__ == "pyrecest.pytorch" :
4949 # Disable import errors for megalinter
5050 import torch as _torch # pylint: disable=import-error
51+
5152 distribution = _torch .distributions .MultivariateNormal (self .mu , self .C )
5253 if xs .ndim == 1 and self .dim == 1 :
5354 # For 1-D distributions, we need to reshape the input to a 2-D tensor
@@ -57,7 +58,9 @@ def pdf(self, xs):
5758 pdfvals = _torch .exp (log_probs )
5859 elif pyrecest .backend .__name__ == "pyrecest.jax" :
5960 from jax import numpy as jnp # pylint: disable=import-error
60- from jax .scipy .stats import multivariate_normal # pylint: disable=import-error
61+ from jax .scipy .stats import ( # pylint: disable=import-error
62+ multivariate_normal ,
63+ )
6164
6265 if xs .ndim == 1 and self .dim == 1 :
6366 # For 1-D distributions, we need to reshape the input to a 2-D tensor
You can’t perform that action at this time.
0 commit comments