Skip to content

Commit 0111b70

Browse files
committed
update to 0.1.3; minor fixes
1 parent 50ed12c commit 0111b70

File tree

3 files changed

+9
-16
lines changed

3 files changed

+9
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "deterministic_gaussian_sampling"
7-
version = "0.1.2"
7+
version = "0.1.3"
88
description = "Python library for Localized Distribution (LCD)-based Gaussian sampling."
99
readme = "README.md"
1010
requires-python = ">=3.8"

src/deterministic_gaussian_sampling/approximation/base_approximation.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,9 @@ def _check_numpy_ndarray(
2828
return None
2929
if not isinstance(arr, numpy.ndarray):
3030
raise TypeError("Input must be a numpy array")
31-
if (
32-
arr.dtype != float
33-
and arr.dtype != numpy.float16
34-
and arr.dtype != numpy.float32
35-
and arr.dtype != numpy.float64
36-
and arr.dtype != numpy.float96
37-
and arr.dtype != numpy.float128
38-
):
31+
if not numpy.issubdtype(arr.dtype, numpy.floating) and arr.dtype != float:
3932
raise TypeError(
40-
f"Input array must be of [float, numpy.float16, numpy.float32, numpy.float64, numpy.float96, numpy.float128], but got {arr.dtype}."
33+
f"Input array must be of a floating type, but got {arr.dtype}."
4134
)
4235
if arr.shape != (L, N):
4336
row, cols = arr.shape

src/deterministic_gaussian_sampling/approximation/gaussian_to_dirac.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def approximate_double(
4242
success = cdll.gm_to_dirac_short_double_approximate(
4343
self.gm_to_dirac_double,
4444
self._check_numpy_ndarray(covDiag, covDiag.shape[0], covDiag.shape[0]),
45-
ctypes.c_int(L),
46-
ctypes.c_int(N),
47-
ctypes.c_int(100),
45+
ctypes.c_size_t(L),
46+
ctypes.c_size_t(N),
47+
ctypes.c_size_t(100),
4848
self._check_numpy_ndarray(x, L, N),
4949
self._check_numpy_ndarray(wX, L, 1),
5050
ctypes.byref(minimizer_result),
@@ -72,9 +72,9 @@ def approximate_snd_double(
7272
minimizer_result = ctypes_wrapper.GslMinimizerResultCTypes()
7373
success = cdll.gm_to_dirac_short_standard_normal_deviation_double_approximate(
7474
self.gm_to_dirac_snd_double,
75-
ctypes.c_int(L),
76-
ctypes.c_int(N),
77-
ctypes.c_int(100),
75+
ctypes.c_size_t(L),
76+
ctypes.c_size_t(N),
77+
ctypes.c_size_t(100),
7878
self._check_numpy_ndarray(x, L, N),
7979
self._check_numpy_ndarray(wX, L, 1),
8080
ctypes.byref(minimizer_result),

0 commit comments

Comments
 (0)