11from abc import ABC , abstractmethod
22from collections .abc import Callable
33from typing import Union
4+ import inspect
45
56import pyrecest .backend
67
7- # pylint: disable=no-name-in-module,no-member
8+ # pylint: disable=no-name-in-module,no-member,redefined-builtin
89from pyrecest .backend import empty , int32 , int64 , log , random , squeeze
910
1011
@@ -64,13 +65,14 @@ def set_mode(self, _):
6465 """
6566 raise NotImplementedError ("set_mode is not implemented for this distribution" )
6667
68+ # Need to use Union instead of | to support torch.dtype
6769 # Need to use Union instead of | to support torch.dtype
6870 def sample (self , n : Union [int , int32 , int64 ]):
6971 """Obtain n samples from the distribution."""
7072 return self .sample_metropolis_hastings (n )
7173
7274 # jscpd:ignore-start
73- # pylint: disable=too-many-positional-arguments
75+ # pylint: disable=too-many-positional-arguments,too-many-locals
7476 def sample_metropolis_hastings (
7577 self ,
7678 n : Union [int , int32 , int64 ],
@@ -81,30 +83,48 @@ def sample_metropolis_hastings(
8183 ):
8284 # jscpd:ignore-end
8385 """Metropolis Hastings sampling algorithm."""
84- assert (
85- pyrecest .backend .__backend_name__ != "jax"
86- ), "Not supported on this backend"
86+ if pyrecest .backend .__backend_name__ == "jax" :
87+ # Get a key from your global JAX random state *outside* of lax.scan
88+ import jax as _jax
89+
90+ key = random .jax_global_random_state ()
91+ key , key_for_mh = _jax .random .split (key )
92+ # Optionally update global state for future calls
93+ random .jax_global_random_state (key )
94+
95+ if proposal is None or start_point is None :
96+ raise NotImplementedError (
97+ "Default proposals and starting points should be set in inheriting classes."
98+ )
99+ _assert_proposal_supports_key (proposal )
100+
101+ samples , _ = sample_metropolis_hastings_jax (
102+ key = key_for_mh ,
103+ log_pdf = self .ln_pdf ,
104+ proposal = proposal , # must be (key, x) -> x_prop for JAX
105+ start_point = start_point ,
106+ n = int (n ),
107+ burn_in = int (burn_in ),
108+ skipping = int (skipping ),
109+ )
110+ # You could optionally stash `key_out` somewhere if you want chain continuation.
111+ return squeeze (samples )
112+
113+ # Non-JAX backends → your old NumPy/Torch code
87114 if proposal is None or start_point is None :
88115 raise NotImplementedError (
89116 "Default proposals and starting points should be set in inheriting classes."
90117 )
91118
92119 total_samples = burn_in + n * skipping
93- s = empty (
94- (
95- total_samples ,
96- self .input_dim ,
97- ),
98- )
120+ s = empty ((total_samples , self .input_dim ))
99121 x = start_point
100122 i = 0
101123 pdfx = self .pdf (x )
102124
103125 while i < total_samples :
104126 x_new = proposal (x )
105- assert (
106- x_new .shape == x .shape
107- ), "Proposal must return a vector of same shape as input"
127+ assert x_new .shape == x .shape , "Proposal must return a vector of same shape as input"
108128 pdfx_new = self .pdf (x_new )
109129 a = pdfx_new / pdfx
110130 if a .item () > 1 or a .item () > random .rand (1 ):
@@ -115,3 +135,111 @@ def sample_metropolis_hastings(
115135
116136 relevant_samples = s [burn_in ::skipping , :]
117137 return squeeze (relevant_samples )
138+
139+ # pylint: disable=too-many-positional-arguments,too-many-locals,too-many-arguments
140+ def sample_metropolis_hastings_jax (
141+ key ,
142+ log_pdf , # function: x -> log p(x)
143+ proposal , # function: (key, x) -> x_prop
144+ start_point ,
145+ n : int ,
146+ burn_in : int = 10 ,
147+ skipping : int = 5 ,
148+ ):
149+ """
150+ Metropolis-Hastings sampler in JAX.
151+
152+ key: jax.random.PRNGKey
153+ log_pdf: callable x -> log p(x)
154+ proposal: callable (key, x) -> x_proposed
155+ start_point: initial state (array)
156+ n: number of samples to return (after burn-in and thinning)
157+ """
158+ import jax .numpy as _jnp
159+ from jax import _random , _lax
160+
161+
162+ start_point = _jnp .asarray (start_point )
163+ total_steps = burn_in + n * skipping
164+
165+ def one_step (carry , _ ):
166+ key , x , log_px = carry
167+ key , key_prop , key_u = _random .split (key , 3 )
168+
169+ # Propose new state
170+ x_prop = proposal (key_prop , x )
171+ log_px_prop = log_pdf (x_prop )
172+
173+ # log_alpha = log p(x_prop) - log p(x)
174+ log_alpha = log_px_prop - log_px
175+
176+ # Draw u ~ Uniform(0, 1)
177+ u = _random .uniform (key_u , shape = ())
178+ log_u = _jnp .log (u )
179+
180+ # Accept if log u < min(0, log_alpha)
181+ # (equivalent to u < exp(min(0, log_alpha)))
182+ log_alpha_capped = _jnp .minimum (0.0 , log_alpha )
183+ accept = log_u < log_alpha_capped # scalar bool
184+
185+ # Branch without Python if
186+ x_new = _jnp .where (accept , x_prop , x )
187+ log_px_new = _jnp .where (accept , log_px_prop , log_px )
188+
189+ return (key , x_new , log_px_new ), x_new
190+
191+ init_carry = (key , start_point , log_pdf (start_point ))
192+ (key_out , _ , _ ), chain = _lax .scan (
193+ one_step ,
194+ init_carry ,
195+ xs = None ,
196+ length = total_steps ,
197+ )
198+
199+ samples = chain [burn_in ::skipping ]
200+ return samples , key_out
201+
202+
203+ def _assert_proposal_supports_key (proposal : Callable ):
204+ """
205+ Check that `proposal` can be called as proposal(key, x).
206+
207+ Raises a TypeError with a helpful message if this is not the case.
208+ """
209+ # Unwrap jitted / partial / decorated functions if possible
210+ func = proposal
211+ while hasattr (func , "__wrapped__" ):
212+ func = func .__wrapped__
213+
214+ try :
215+ sig = inspect .signature (func )
216+ except (TypeError , ValueError ):
217+ # Can't introspect (e.g. builtins); fall back to a generic error
218+ raise TypeError (
219+ "For the JAX backend, `proposal` must accept (key, x) as arguments, "
220+ "but its signature could not be inspected."
221+ ) from None
222+
223+ params = list (sig .parameters .values ())
224+
225+ # Count positional(-or-keyword) parameters
226+ num_positional = sum (
227+ p .kind in (inspect .Parameter .POSITIONAL_ONLY ,
228+ inspect .Parameter .POSITIONAL_OR_KEYWORD )
229+ for p in params
230+ )
231+ has_var_positional = any (
232+ p .kind == inspect .Parameter .VAR_POSITIONAL
233+ for p in params
234+ )
235+
236+ if has_var_positional or num_positional >= 2 :
237+ # Looks compatible with (key, x)
238+ return
239+
240+ raise TypeError (
241+ "For the JAX backend, `proposal` must accept `(key, x)` as arguments.\n "
242+ f"Got signature: { sig } \n "
243+ "Hint: change your proposal from `def proposal(x): ...` to\n "
244+ "`def proposal(key, x): ...` and use `jax.random` with the passed key."
245+ )
0 commit comments