@@ -53,16 +53,16 @@ def build_baseline(
     outputs: List[Dict[str, torch.Tensor]] = []
 
     inp = gen_inputs(defn, workload, device=device, stensors=loaded_stensors)
-    if "probs" in inp:
-        inp["probs"] = torch.softmax(
-            inp["probs"], dim=-1
-        )  # convert logits to probs for sampling
     inputs.append(inp)
 
-    freq_dist = _compute_frequency_distribution(
-        ref_runnable, inp, device, defn, num_trials=50000
-    )
-    outputs.append({"frequency_distribution": freq_dist})
+    thresholding_method = _detect_thresholding_method(defn)
+    params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
+    valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)
+
+    masked_probs = inp["probs"] * valid_mask.float()
+    expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+
+    outputs.append({"expected_probs": expected_probs})
 
     latencies: List[float] = []
     for inp in inputs:
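
As a rough illustration of the new baseline step above (not part of the diff; the toy probabilities and the assumed top_k = 2 mask are made up for the example), the masked renormalization works like this:

import torch

# Toy 4-token vocabulary, batch of 1, with an assumed top_k = 2 mask.
probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
valid_mask = torch.tensor([[True, True, False, False]])  # only the top-2 tokens survive

masked_probs = probs * valid_mask.float()
expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
# expected_probs == [[0.625, 0.375, 0.0, 0.0]]
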
@@ -94,15 +94,20 @@ def check_correctness(
     log_path: str,
     device: str,
 ) -> Tuple[Optional[Correctness], Optional[Evaluation]]:
-    ref_freq = ref_outputs[0]["frequency_distribution"]
-    vocab_size = ref_freq.shape[0]
+    expected_probs = ref_outputs[0]["expected_probs"]
+    vocab_size = expected_probs.shape[-1]
 
     inp = inputs[0]
     params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
 
     output_names = list(defn.outputs.keys())
     output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
 
+    # Compute valid sampling mask based on thresholding
+    thresholding_method = _detect_thresholding_method(defn)
+    probs = inp["probs"]
+    valid_mask = _compute_valid_sampling_mask(probs, thresholding_method, params)
+
     # Validate correct sampling token set
     for _ in range(cfg.sampling_validation_trials):
         try:
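
The body of _detect_thresholding_method is not shown in this diff; a plausible sketch, assuming it keys off which sampling parameters the Definition declares (the defn.inputs lookup is an assumption, not taken from the source), might look like:

def _detect_thresholding_method(defn: Definition) -> str:
    # Hypothetical sketch only; the real implementation may differ.
    has_k = "top_k" in defn.inputs
    has_p = "top_p" in defn.inputs
    if has_k and has_p:
        return "top_k_top_p"
    if has_k:
        return "top_k"
    if has_p:
        return "top_p"
    return "none"  # no thresholding
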
@@ -137,27 +142,32 @@ def check_correctness(
                 correctness=correctness,
             )
 
-        # Validate thresholding
-        thresholding_method = _detect_thresholding_method(defn)
-        probs = inp["probs"]
-        if not _check_thresholding(samples, probs, thresholding_method, params):
-            correctness = Correctness(
-                max_relative_error=float("inf"), max_absolute_error=float("inf")
-            )
-            message = (
-                f"Samples {samples.tolist()} does not meet {thresholding_method} thresholding"
-            )
-            print(message, file=sys.stderr)
-            return correctness, make_eval(
-                status=EvaluationStatus.INCORRECT_NUMERICAL,
-                device=device,
-                log_path=log_path,
-                correctness=correctness,
-            )
+        # Validate thresholding - check samples are within valid mask
+        if samples.dim() == 0:
+            samples_flat = samples.unsqueeze(0)
+        else:
+            samples_flat = samples.flatten()
+
+        batch_size = valid_mask.shape[0]
+        for i in range(len(samples_flat)):
+            batch_idx = i % batch_size
+            sample_idx = samples_flat[i].item()
+            if not valid_mask[batch_idx, sample_idx]:
+                correctness = Correctness(
+                    max_relative_error=float("inf"), max_absolute_error=float("inf")
+                )
+                message = f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
+                print(message, file=sys.stderr)
+                return correctness, make_eval(
+                    status=EvaluationStatus.INCORRECT_NUMERICAL,
+                    device=device,
+                    log_path=log_path,
+                    correctness=correctness,
+                )
 
     try:
-        sol_freq = _compute_frequency_distribution(
-            sol_runnable, inp, device, defn, num_trials=50000
+        sol_freqs = _sample_token_distributions(
+            sol_runnable, inp, device, defn, num_trials=500000
         )
         torch.cuda.synchronize(device)
     except Exception:
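
The membership check above walks the samples in a Python loop; an equivalent vectorized check could look like the following sketch (the helper name is made up, and it assumes the same samples and valid_mask tensors as the diff):

import torch

def samples_within_mask(samples: torch.Tensor, valid_mask: torch.Tensor) -> bool:
    # True iff every sampled token index falls inside its batch row's valid mask.
    flat = samples.reshape(-1).long()
    rows = torch.arange(flat.numel(), device=valid_mask.device) % valid_mask.shape[0]
    return bool(valid_mask[rows, flat].all())
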
@@ -166,13 +176,29 @@ def check_correctness(
             status=EvaluationStatus.RUNTIME_ERROR, device=device, log_path=log_path
         )
 
-    # total variation distance
-    tvd = 0.5 * torch.sum(torch.abs(sol_freq - ref_freq)).item()
-    max_abs, max_rel, _, _ = compute_error_stats(sol_freq, ref_freq, cfg)
+    batch_size = expected_probs.shape[0]
+    tvds = []
+    max_abs_errors = []
+    max_rel_errors = []
+
+    for i in range(batch_size):
+        tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
+        tvds.append(tvd_i)
+
+        max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg)
+        max_abs_errors.append(max_abs_i)
+        max_rel_errors.append(max_rel_i)
 
-    numerical_incorrect = tvd > cfg.sampling_tvd_threshold
+    # Use the worst (max) TVD and errors across all batch elements
+    max_tvd = max(tvds)
+    max_abs = max(max_abs_errors)
+    max_rel = max(max_rel_errors)
+
+    numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
     correctness = Correctness(
-        max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
+        max_relative_error=max_rel,
+        max_absolute_error=max_abs,
+        extra={"tvd": max_tvd, "tvds_per_batch": tvds},
     )
     if numerical_incorrect:
         return correctness, make_eval(
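
For reference, the per-batch statistic in this hunk is the total variation distance, 0.5 * sum(|p_i - q_i|); a toy computation with made-up numbers:

import torch

expected = torch.tensor([0.625, 0.375, 0.0, 0.0])      # analytic distribution
observed = torch.tensor([0.640, 0.350, 0.008, 0.002])  # hypothetical empirical frequencies

tvd = 0.5 * torch.sum(torch.abs(observed - expected)).item()
# 0.5 * (0.015 + 0.025 + 0.008 + 0.002) == 0.025
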
@@ -201,23 +227,117 @@ def _detect_thresholding_method(defn: Definition) -> str:
     return "none"  # no thresholding
 
 
-def _compute_frequency_distribution(
+def _compute_valid_sampling_mask(
+    probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 5e-2
+) -> torch.Tensor:
+    """Compute a boolean mask of tokens that may validly be sampled under the given method.
+
+    Allows ties in top_k (prob >= the k-th largest) and an eps tolerance at the top_p nucleus boundary.
+    """
+    if probs.dim() == 1:
+        probs = probs.unsqueeze(0)
+
+    batch_size, vocab_size = probs.shape
+    device = probs.device
+
+    if method == "none":
+        return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    mask = torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    if method in ["top_k", "top_k_top_p"]:
+        if "top_k" not in params:
+            raise ValueError(f"top_k parameter required for {method} but not found")
+
+        top_k_param = params["top_k"]
+        for i in range(batch_size):
+            k = int(top_k_param[i].item()) if top_k_param.dim() > 0 else int(top_k_param.item())
+
+            if 0 < k < vocab_size:
+                sorted_probs, _ = torch.sort(probs[i], descending=True)
+                # k-th largest value (0-indexed, so k-1)
+                pivot = sorted_probs[k - 1]
+                mask[i] = probs[i] >= pivot  # tie-breaking handling
+
+    # Apply top_p mask with epsilon tolerance
+    if method in ["top_p", "top_k_top_p"]:
+        if "top_p" not in params:
+            raise ValueError(f"top_p parameter required for {method} but not found")
+
+        top_p_param = params["top_p"]
+        for i in range(batch_size):
+            p = float(top_p_param[i].item()) if top_p_param.dim() > 0 else float(top_p_param.item())
+
+            if 0 < p < 1:
+                sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
+                cumsum = torch.cumsum(sorted_probs, dim=0)
+
+                # Find tokens in nucleus (cumsum <= p + eps for numerical tolerance)
+                nucleus_mask = cumsum <= (p + eps)
+
+                if not nucleus_mask.any():
+                    nucleus_mask[0] = True
+
+                # Map back to original indices
+                p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
+                p_mask[sorted_indices[nucleus_mask]] = True
+
+                mask[i] = mask[i] & p_mask
+
+    return mask
+
+
+def _sample_token_distributions(
     runnable: Runnable,
     inputs: Dict[str, Any],
     device: str,
     defn: Definition,
-    num_trials: int = 10000,
+    num_trials: int = 500000,
 ) -> torch.Tensor:
-    batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
+    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
     vocab_size = inputs["probs"].shape[-1]
-    counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device))
 
-    trials_needed = (num_trials + batch_size - 1) // batch_size
-    total_samples_collected = 0
+    # Repeat entire input batch to fill up to target_batch_size for efficient sampling
+    target_batch_size = 10000
+    repeat_count = target_batch_size // original_batch_size
+    actual_batch_size = repeat_count * original_batch_size
+
+    padded_inputs = {}
+    for key, value in inputs.items():
+        if isinstance(value, torch.Tensor) and value.dim() > 0:
+            if key == "probs":
+                # For probs, repeat the entire batch
+                if value.dim() == 1:
+                    value = value.unsqueeze(0)
+                # Repeat the entire batch repeat_count times
+                padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            elif key in ["top_k", "top_p"]:
+                # For sampling parameters, repeat the entire batch
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count)
+            else:
+                # For other tensors, repeat entire batch along batch dimension
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            padded_inputs[key] = padded_value
+        else:
+            # For non-tensor inputs, keep as is
+            padded_inputs[key] = value
+
+    counters = torch.zeros(
+        (original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device)
+    )
+
+    trials_needed = (num_trials + repeat_count - 1) // repeat_count
+    total_samples_per_batch = 0
 
     for _ in range(trials_needed):
         with torch.no_grad():
-            out = runnable(**inputs)
+            out = runnable(**padded_inputs)
 
         output_names = list(defn.outputs.keys())
         output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
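
A small usage sketch of the _compute_valid_sampling_mask helper introduced in this hunk, on a toy distribution (values chosen only to show the eps-tolerant nucleus):

import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
params = {"top_p": torch.tensor([0.8])}

mask = _compute_valid_sampling_mask(probs, "top_p", params)
# cumulative sums (descending) are [0.5, 0.8, 0.95, 1.0]; with eps = 5e-2 the
# nucleus keeps the first two tokens, so mask == [[True, True, False, False]]
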
@@ -229,118 +349,19 @@ def _compute_frequency_distribution(
         samples = out_normalized["samples"]
 
         if samples.dim() == 0:
+            # Single sample - assign to first batch element
             sample_idx = samples.item()
-            counter[sample_idx] += 1
-            total_samples_collected += 1
-        else:  # Batch of samples
-            for i in range(samples.numel()):
-                sample_idx = samples.flatten()[i].item()
-                counter[sample_idx] += 1
-                total_samples_collected += 1
-
-    frequency = counter.float() / total_samples_collected
-    return frequency
-
-
-def _check_thresholding(
-    samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any]
-) -> bool:
-    """Check if samples conform to the specified thresholding method.
-
-    Parameters
-    ----------
-    samples : torch.Tensor
-        Sampled token indices.
-    probs : torch.Tensor
-        Probability distribution used for sampling.
-    method : str
-        Thresholding method: "top_k", "top_p", "top_k_top_p", or "none".
-    params : Dict[str, Any]
-        Sampling parameters (top_k, top_p values).
-
-    Returns
-    -------
-    bool
-        True if samples are valid, False otherwise.
-    """
-    batch_size, vocab_size = probs.shape
-    device = probs.device
-
-    for i in range(batch_size):
-        prob_row = probs[i]
-        sample = samples[i].item()
-
-        if method == "top_k":
-            if "top_k" not in params:
-                raise ValueError("top_k parameter is required for top_k thresholding but not found")
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-                if mask_top_k[sample] != 1:
-                    return False
-
-        elif method == "top_p":
-            if "top_p" not in params:
-                raise ValueError("top_p parameter is required for top_p thresholding but not found")
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < p < 1:
-                eps = 1e-4  # numerical stability
-                sorted_probs, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs, dim=0)
-                valid_mask = cdf > (1 - p) - eps
-                valid_indices = indices[valid_mask]
-
-                if sample not in valid_indices:
-                    return False
-
-        elif method == "top_k_top_p":
-            if "top_k" not in params or "top_p" not in params:
-                raise ValueError(
-                    "top_k and top_p parameters are both required for top_k_top_p thresholding but not found"
-                )
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-            else:
-                mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            if 0 < p < 1:
-                eps = 1e-4
-                sorted_probs_asc, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs_asc, dim=0)
-                mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
-                valid_p_mask = cdf > (1 - p) - eps
-                mask_top_p[indices[valid_p_mask]] = 1
-            else:
-                mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            joint_mask = torch.minimum(mask_top_k, mask_top_p)
-
-            if joint_mask[sample] != 1:
-                return False
-
-    return True
+            counters[0, sample_idx] += 1
+            total_samples_per_batch += 1
+        else:
+            # slice and accumulate per original batch element
+            samples_flat = samples.flatten()
+            for i in range(samples_flat.numel()):
+                batch_idx = i % original_batch_size
+                sample_idx = samples_flat[i].item()
+                counters[batch_idx, sample_idx] += 1
+            total_samples_per_batch += repeat_count
+
+    # [batch_size, vocab_size]
+    frequencies = counters.float() / total_samples_per_batch
+    return frequencies
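
Taken together, the new helpers compare Monte-Carlo frequencies against the analytically expected distribution. A standalone sketch of that comparison, using torch.multinomial as a stand-in for the solution runnable (illustrative only, not the harness's code path):

import torch

expected_probs = torch.tensor([[0.625, 0.375, 0.0, 0.0]])  # from the baseline step
num_trials = 500_000

# Stand-in sampler: draw from the expected distribution itself.
draws = torch.multinomial(expected_probs[0], num_samples=num_trials, replacement=True)
freqs = torch.bincount(draws, minlength=expected_probs.shape[-1]).float() / num_trials

tvd = 0.5 * torch.sum(torch.abs(freqs - expected_probs[0])).item()
# tvd should land well below a threshold such as cfg.sampling_tvd_threshold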