Skip to content

Commit 7c37925

Browse files
committed
Added metropolis hastings sampling for jax
1 parent 11dc74f commit 7c37925

File tree

3 files changed

+148
-24
lines changed

3 files changed

+148
-24
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ ignore-patterns=^\.#
6363
# (useful for modules/projects where namespaces are manipulated during runtime
6464
# and thus existing member attributes cannot be deduced by static analysis). It
6565
# supports qualified module names, as well as Unix pattern matching.
66-
ignored-modules=pyrecest.backend
66+
ignored-modules=pyrecest.backend, jax
6767

6868
# Python code to execute, usually for sys.path manipulation such as
6969
# pygtk.require().

pyrecest/distributions/abstract_manifold_specific_distribution.py

Lines changed: 142 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from abc import ABC, abstractmethod
22
from collections.abc import Callable
33
from typing import Union
4+
import inspect
45

56
import pyrecest.backend
67

7-
# pylint: disable=no-name-in-module,no-member
8+
# pylint: disable=no-name-in-module,no-member,redefined-builtin
89
from pyrecest.backend import empty, int32, int64, log, random, squeeze
910

1011

@@ -64,13 +65,14 @@ def set_mode(self, _):
6465
"""
6566
raise NotImplementedError("set_mode is not implemented for this distribution")
6667

68+
# Need to use Union instead of | to support torch.dtype
6769
# Need to use Union instead of | to support torch.dtype
6870
def sample(self, n: Union[int, int32, int64]):
6971
"""Obtain n samples from the distribution."""
7072
return self.sample_metropolis_hastings(n)
7173

7274
# jscpd:ignore-start
73-
# pylint: disable=too-many-positional-arguments
75+
# pylint: disable=too-many-positional-arguments,too-many-locals
7476
def sample_metropolis_hastings(
7577
self,
7678
n: Union[int, int32, int64],
@@ -81,30 +83,48 @@ def sample_metropolis_hastings(
8183
):
8284
# jscpd:ignore-end
8385
"""Metropolis Hastings sampling algorithm."""
84-
assert (
85-
pyrecest.backend.__backend_name__ != "jax"
86-
), "Not supported on this backend"
86+
if pyrecest.backend.__backend_name__ == "jax":
87+
# Get a key from your global JAX random state *outside* of lax.scan
88+
import jax as _jax
89+
90+
key = random.jax_global_random_state()
91+
key, key_for_mh = _jax.random.split(key)
92+
# Optionally update global state for future calls
93+
random.jax_global_random_state(key)
94+
95+
if proposal is None or start_point is None:
96+
raise NotImplementedError(
97+
"Default proposals and starting points should be set in inheriting classes."
98+
)
99+
_assert_proposal_supports_key(proposal)
100+
101+
samples, _ = sample_metropolis_hastings_jax(
102+
key=key_for_mh,
103+
log_pdf=self.ln_pdf,
104+
proposal=proposal, # must be (key, x) -> x_prop for JAX
105+
start_point=start_point,
106+
n=int(n),
107+
burn_in=int(burn_in),
108+
skipping=int(skipping),
109+
)
110+
# You could optionally stash `key_out` somewhere if you want chain continuation.
111+
return squeeze(samples)
112+
113+
# Non-JAX backends → your old NumPy/Torch code
87114
if proposal is None or start_point is None:
88115
raise NotImplementedError(
89116
"Default proposals and starting points should be set in inheriting classes."
90117
)
91118

92119
total_samples = burn_in + n * skipping
93-
s = empty(
94-
(
95-
total_samples,
96-
self.input_dim,
97-
),
98-
)
120+
s = empty((total_samples, self.input_dim))
99121
x = start_point
100122
i = 0
101123
pdfx = self.pdf(x)
102124

103125
while i < total_samples:
104126
x_new = proposal(x)
105-
assert (
106-
x_new.shape == x.shape
107-
), "Proposal must return a vector of same shape as input"
127+
assert x_new.shape == x.shape, "Proposal must return a vector of same shape as input"
108128
pdfx_new = self.pdf(x_new)
109129
a = pdfx_new / pdfx
110130
if a.item() > 1 or a.item() > random.rand(1):
@@ -115,3 +135,111 @@ def sample_metropolis_hastings(
115135

116136
relevant_samples = s[burn_in::skipping, :]
117137
return squeeze(relevant_samples)
138+
139+
# pylint: disable=too-many-positional-arguments,too-many-locals,too-many-arguments
140+
def sample_metropolis_hastings_jax(
141+
key,
142+
log_pdf, # function: x -> log p(x)
143+
proposal, # function: (key, x) -> x_prop
144+
start_point,
145+
n: int,
146+
burn_in: int = 10,
147+
skipping: int = 5,
148+
):
149+
"""
150+
Metropolis-Hastings sampler in JAX.
151+
152+
key: jax.random.PRNGKey
153+
log_pdf: callable x -> log p(x)
154+
proposal: callable (key, x) -> x_proposed
155+
start_point: initial state (array)
156+
n: number of samples to return (after burn-in and thinning)
157+
"""
158+
import jax.numpy as _jnp
159+
from jax import _random, _lax
160+
161+
162+
start_point = _jnp.asarray(start_point)
163+
total_steps = burn_in + n * skipping
164+
165+
def one_step(carry, _):
166+
key, x, log_px = carry
167+
key, key_prop, key_u = _random.split(key, 3)
168+
169+
# Propose new state
170+
x_prop = proposal(key_prop, x)
171+
log_px_prop = log_pdf(x_prop)
172+
173+
# log_alpha = log p(x_prop) - log p(x)
174+
log_alpha = log_px_prop - log_px
175+
176+
# Draw u ~ Uniform(0, 1)
177+
u = _random.uniform(key_u, shape=())
178+
log_u = _jnp.log(u)
179+
180+
# Accept if log u < min(0, log_alpha)
181+
# (equivalent to u < exp(min(0, log_alpha)))
182+
log_alpha_capped = _jnp.minimum(0.0, log_alpha)
183+
accept = log_u < log_alpha_capped # scalar bool
184+
185+
# Branch without Python if
186+
x_new = _jnp.where(accept, x_prop, x)
187+
log_px_new = _jnp.where(accept, log_px_prop, log_px)
188+
189+
return (key, x_new, log_px_new), x_new
190+
191+
init_carry = (key, start_point, log_pdf(start_point))
192+
(key_out, _, _), chain = _lax.scan(
193+
one_step,
194+
init_carry,
195+
xs=None,
196+
length=total_steps,
197+
)
198+
199+
samples = chain[burn_in::skipping]
200+
return samples, key_out
201+
202+
203+
def _assert_proposal_supports_key(proposal: Callable):
204+
"""
205+
Check that `proposal` can be called as proposal(key, x).
206+
207+
Raises a TypeError with a helpful message if this is not the case.
208+
"""
209+
# Unwrap jitted / partial / decorated functions if possible
210+
func = proposal
211+
while hasattr(func, "__wrapped__"):
212+
func = func.__wrapped__
213+
214+
try:
215+
sig = inspect.signature(func)
216+
except (TypeError, ValueError):
217+
# Can't introspect (e.g. builtins); fall back to a generic error
218+
raise TypeError(
219+
"For the JAX backend, `proposal` must accept (key, x) as arguments, "
220+
"but its signature could not be inspected."
221+
) from None
222+
223+
params = list(sig.parameters.values())
224+
225+
# Count positional(-or-keyword) parameters
226+
num_positional = sum(
227+
p.kind in (inspect.Parameter.POSITIONAL_ONLY,
228+
inspect.Parameter.POSITIONAL_OR_KEYWORD)
229+
for p in params
230+
)
231+
has_var_positional = any(
232+
p.kind == inspect.Parameter.VAR_POSITIONAL
233+
for p in params
234+
)
235+
236+
if has_var_positional or num_positional >= 2:
237+
# Looks compatible with (key, x)
238+
return
239+
240+
raise TypeError(
241+
"For the JAX backend, `proposal` must accept `(key, x)` as arguments.\n"
242+
f"Got signature: {sig}\n"
243+
"Hint: change your proposal from `def proposal(x): ...` to\n"
244+
"`def proposal(key, x): ...` and use `jax.random` with the passed key."
245+
)

pyrecest/tests/distributions/test_abstract_mixture.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@ def _test_sample(self, mix, n):
2525
self.assertEqual(s.shape, (n, mix.input_dim))
2626
return s
2727

28-
@unittest.skipIf(
29-
pyrecest.backend.__backend_name__ == "jax",
30-
reason="Not supported on this backend",
31-
)
3228
def test_sample_metropolis_hastings_basics_only_t2(self):
3329
vmf = ToroidalWrappedNormalDistribution(array([1.0, 0.0]), eye(2))
3430
mix = HypertoroidalMixture(
@@ -37,28 +33,28 @@ def test_sample_metropolis_hastings_basics_only_t2(self):
3733
self._test_sample(mix, 10)
3834

3935
@unittest.skipIf(
40-
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
36+
pyrecest.backend.__backend_name__ in ("pytorch",),
4137
reason="Not supported on this backend",
4238
)
4339
def test_sample_metropolis_hastings_basics_only_s2(self):
4440
vmf1 = VonMisesFisherDistribution(
4541
array([1.0, 0.0, 0.0]), 2.0
46-
) # Needs to be float for scipy
42+
)
4743
vmf2 = VonMisesFisherDistribution(
4844
array([0.0, 1.0, 0.0]), 2.0
49-
) # Needs to be float for scipy
45+
)
5046
mix = HypersphericalMixture([vmf1, vmf2], array([0.5, 0.5]))
5147
s = self._test_sample(mix, 10)
5248
self.assertTrue(allclose(linalg.norm(s, axis=1), ones(10), rtol=1e-10))
5349

5450
@unittest.skipIf(
55-
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
51+
pyrecest.backend.__backend_name__ in ("pytorch",),
5652
reason="Not supported on this backend",
5753
)
5854
def test_sample_metropolis_hastings_basics_only_h2(self):
5955
vmf = VonMisesFisherDistribution(
6056
array([1.0, 0.0, 0.0]), 2.0
61-
) # Needs to be float for scipy
57+
)
6258
mix = CustomHyperhemisphericalDistribution(
6359
lambda x: vmf.pdf(x) + vmf.pdf(-x), 2
6460
)

0 commit comments

Comments
 (0)