Commit 2a7154c

refactor: update sampling evaluation logic (#104)
This PR fixes several issues in the previous sampling evaluation logic:

* The previous version compressed all input probabilities into a single flat frequency vector, which gave incorrect evaluations when the input tensor's batch size was > 1. This PR retains the input shape for sampled token distributions.
* For sampled tokens, we compute a per-input total variation distance (TVD) between the sampled distribution and the ground truth. The `Evaluation` record keeps the worst (max) TVD among all batch elements.
* To reduce correctness-sampling iterations, we repeat the original input tensor `10000 // original_batch_size` times. This still samples the non-deterministic kernel thoroughly while running fewer forward passes, reducing benchmarking time.

## Summary by CodeRabbit

* **New Features**
  * Baseline now reports threshold-aware expected probabilities and uses them for downstream checks.
  * Added public helpers to support threshold-aware sampling and validation.
* **Bug Fixes**
  * More accurate detection of valid tokens under top-k/top-p, with tie and tolerance handling.
  * Enhanced per-batch error reporting, including TVD and max absolute/relative errors.
* **Refactor**
  * Increased sampling trials and optimized batched sampling for stronger statistical validation.
  * Streamlined the validation flow to apply masks per batch element.
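For reference, the per-batch statistic used below is the total variation distance between the sampled frequency distribution $P_i$ and the expected (threshold-masked, renormalized) distribution $Q_i$ of batch element $i$:

$$\mathrm{TVD}(P_i, Q_i) = \frac{1}{2} \sum_{v=1}^{|V|} \bigl|P_i(v) - Q_i(v)\bigr|$$

The evaluation records $\max_i \mathrm{TVD}(P_i, Q_i)$ across the batch.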
1 parent 4a2ebc6 commit 2a7154c

File tree

1 file changed (+176, -155 lines)


flashinfer_bench/bench/evaluators/sampling.py

Lines changed: 176 additions & 155 deletions
@@ -53,16 +53,16 @@ def build_baseline(
     outputs: List[Dict[str, torch.Tensor]] = []
 
     inp = gen_inputs(defn, workload, device=device, stensors=loaded_stensors)
-    if "probs" in inp:
-        inp["probs"] = torch.softmax(
-            inp["probs"], dim=-1
-        )  # convert logits to probs for sampling
     inputs.append(inp)
 
-    freq_dist = _compute_frequency_distribution(
-        ref_runnable, inp, device, defn, num_trials=50000
-    )
-    outputs.append({"frequency_distribution": freq_dist})
+    thresholding_method = _detect_thresholding_method(defn)
+    params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
+    valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)
+
+    masked_probs = inp["probs"] * valid_mask.float()
+    expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+
+    outputs.append({"expected_probs": expected_probs})
 
     latencies: List[float] = []
     for inp in inputs:
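A minimal sketch of the mask-then-renormalize step the baseline now performs, using a toy `probs` tensor and a hand-written mask standing in for `_compute_valid_sampling_mask`:

```python
import torch

# Toy batch of 2 distributions over a 4-token vocabulary.
probs = torch.tensor([[0.50, 0.30, 0.15, 0.05],
                      [0.40, 0.40, 0.10, 0.10]])
# Hand-written stand-in for the mask returned by _compute_valid_sampling_mask
# (here: top-k with k=2 on each row).
valid_mask = torch.tensor([[True, True, False, False],
                           [True, True, False, False]])

# Zero out invalid tokens, then renormalize each row to sum to 1.
masked_probs = probs * valid_mask.float()
expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
print(expected_probs)
# tensor([[0.6250, 0.3750, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000]])
```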
@@ -94,15 +94,20 @@ def check_correctness(
     log_path: str,
     device: str,
 ) -> Tuple[Optional[Correctness], Optional[Evaluation]]:
-    ref_freq = ref_outputs[0]["frequency_distribution"]
-    vocab_size = ref_freq.shape[0]
+    expected_probs = ref_outputs[0]["expected_probs"]
+    vocab_size = expected_probs.shape[-1]
 
     inp = inputs[0]
     params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
 
     output_names = list(defn.outputs.keys())
     output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
 
+    # Compute valid sampling mask based on thresholding
+    thresholding_method = _detect_thresholding_method(defn)
+    probs = inp["probs"]
+    valid_mask = _compute_valid_sampling_mask(probs, thresholding_method, params)
+
     # Validate correct sampling token set
     for _ in range(cfg.sampling_validation_trials):
         try:
@@ -137,27 +142,32 @@
                 correctness=correctness,
             )
 
-        # Validate thresholding
-        thresholding_method = _detect_thresholding_method(defn)
-        probs = inp["probs"]
-        if not _check_thresholding(samples, probs, thresholding_method, params):
-            correctness = Correctness(
-                max_relative_error=float("inf"), max_absolute_error=float("inf")
-            )
-            message = (
-                f"Samples {samples.tolist()} does not meet {thresholding_method} thresholding"
-            )
-            print(message, file=sys.stderr)
-            return correctness, make_eval(
-                status=EvaluationStatus.INCORRECT_NUMERICAL,
-                device=device,
-                log_path=log_path,
-                correctness=correctness,
-            )
+        # Validate thresholding - check samples are within valid mask
+        if samples.dim() == 0:
+            samples_flat = samples.unsqueeze(0)
+        else:
+            samples_flat = samples.flatten()
+
+        batch_size = valid_mask.shape[0]
+        for i in range(len(samples_flat)):
+            batch_idx = i % batch_size
+            sample_idx = samples_flat[i].item()
+            if not valid_mask[batch_idx, sample_idx]:
+                correctness = Correctness(
+                    max_relative_error=float("inf"), max_absolute_error=float("inf")
+                )
+                message = f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
+                print(message, file=sys.stderr)
+                return correctness, make_eval(
+                    status=EvaluationStatus.INCORRECT_NUMERICAL,
+                    device=device,
+                    log_path=log_path,
+                    correctness=correctness,
+                )
 
     try:
-        sol_freq = _compute_frequency_distribution(
-            sol_runnable, inp, device, defn, num_trials=50000
+        sol_freqs = _sample_token_distributions(
+            sol_runnable, inp, device, defn, num_trials=500000
         )
         torch.cuda.synchronize(device)
     except Exception:
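Because padded batches are built by repeating the whole original batch (see `_sample_token_distributions` below), a flattened sample at position `i` always maps back to original batch element `i % batch_size`. A toy illustration of that index mapping, assuming a batch of 3 repeated twice:

```python
# Padded batch layout: [b0, b1, b2, b0, b1, b2] for batch_size=3, repeated twice.
batch_size = 3
samples_flat = [17, 4, 256, 17, 9, 256]  # illustrative sampled token ids
for i, token in enumerate(samples_flat):
    batch_idx = i % batch_size
    print(f"flat position {i}: token {token} belongs to batch element {batch_idx}")
```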
@@ -166,13 +176,29 @@
             status=EvaluationStatus.RUNTIME_ERROR, device=device, log_path=log_path
         )
 
-    # total variation distance
-    tvd = 0.5 * torch.sum(torch.abs(sol_freq - ref_freq)).item()
-    max_abs, max_rel, _, _ = compute_error_stats(sol_freq, ref_freq, cfg)
+    batch_size = expected_probs.shape[0]
+    tvds = []
+    max_abs_errors = []
+    max_rel_errors = []
+
+    for i in range(batch_size):
+        tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
+        tvds.append(tvd_i)
+
+        max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg)
+        max_abs_errors.append(max_abs_i)
+        max_rel_errors.append(max_rel_i)
 
-    numerical_incorrect = tvd > cfg.sampling_tvd_threshold
+    # Use the worst (max) TVD and errors across all batch elements
+    max_tvd = max(tvds)
+    max_abs = max(max_abs_errors)
+    max_rel = max(max_rel_errors)
+
+    numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
     correctness = Correctness(
-        max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
+        max_relative_error=max_rel,
+        max_absolute_error=max_abs,
+        extra={"tvd": max_tvd, "tvds_per_batch": tvds},
    )
     if numerical_incorrect:
         return correctness, make_eval(
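A small worked example of the per-batch TVD computation above, with made-up frequencies; the evaluation keeps the worst row:

```python
import torch

# Made-up sampled frequencies vs. expected probabilities, batch of 2.
sol_freqs = torch.tensor([[0.62, 0.38, 0.00, 0.00],
                          [0.45, 0.55, 0.00, 0.00]])
expected_probs = torch.tensor([[0.625, 0.375, 0.0, 0.0],
                               [0.500, 0.500, 0.0, 0.0]])

tvds = [0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
        for i in range(expected_probs.shape[0])]
print(tvds)       # ~[0.005, 0.05]
print(max(tvds))  # 0.05 -> compared against cfg.sampling_tvd_threshold
```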
@@ -201,23 +227,117 @@ def _detect_thresholding_method(defn: Definition) -> str:
     return "none"  # no thresholding
 
 
-def _compute_frequency_distribution(
+def _compute_valid_sampling_mask(
+    probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 5e-2
+) -> torch.Tensor:
+    """
+    For tie-breaking in top_k (allows any token with prob >= k-th largest)
+    and numerical precision in top_p (allows tokens within eps of nucleus boundary).
+    """
+    if probs.dim() == 1:
+        probs = probs.unsqueeze(0)
+
+    batch_size, vocab_size = probs.shape
+    device = probs.device
+
+    if method == "none":
+        return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    mask = torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    if method in ["top_k", "top_k_top_p"]:
+        if "top_k" not in params:
+            raise ValueError(f"top_k parameter required for {method} but not found")
+
+        top_k_param = params["top_k"]
+        for i in range(batch_size):
+            k = int(top_k_param[i].item()) if top_k_param.dim() > 0 else int(top_k_param.item())
+
+            if 0 < k < vocab_size:
+                sorted_probs, _ = torch.sort(probs[i], descending=True)
+                # k-th largest value (0-indexed, so k-1)
+                pivot = sorted_probs[k - 1]
+                mask[i] = probs[i] >= pivot  # tie-breaking handling
+
+    # Apply top_p mask with epsilon tolerance
+    if method in ["top_p", "top_k_top_p"]:
+        if "top_p" not in params:
+            raise ValueError(f"top_p parameter required for {method} but not found")
+
+        top_p_param = params["top_p"]
+        for i in range(batch_size):
+            p = float(top_p_param[i].item()) if top_p_param.dim() > 0 else float(top_p_param.item())
+
+            if 0 < p < 1:
+                sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
+                cumsum = torch.cumsum(sorted_probs, dim=0)
+
+                # Find tokens in nucleus (cumsum <= p + eps for numerical tolerance)
+                nucleus_mask = cumsum <= (p + eps)
+
+                if not nucleus_mask.any():
+                    nucleus_mask[0] = True
+
+                # Map back to original indices
+                p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
+                p_mask[sorted_indices[nucleus_mask]] = True
+
+                mask[i] = mask[i] & p_mask
+
+    return mask
+
+
+def _sample_token_distributions(
     runnable: Runnable,
     inputs: Dict[str, Any],
     device: str,
     defn: Definition,
-    num_trials: int = 10000,
+    num_trials: int = 500000,
 ) -> torch.Tensor:
-    batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
+    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
     vocab_size = inputs["probs"].shape[-1]
-    counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device))
 
-    trials_needed = (num_trials + batch_size - 1) // batch_size
-    total_samples_collected = 0
+    # Repeat entire input batch to fill up to target_batch_size for efficient sampling
+    target_batch_size = 10000
+    repeat_count = target_batch_size // original_batch_size
+    actual_batch_size = repeat_count * original_batch_size
+
+    padded_inputs = {}
+    for key, value in inputs.items():
+        if isinstance(value, torch.Tensor) and value.dim() > 0:
+            if key == "probs":
+                # For probs, repeat the entire batch
+                if value.dim() == 1:
+                    value = value.unsqueeze(0)
+                # Repeat the entire batch repeat_count times
+                padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            elif key in ["top_k", "top_p"]:
+                # For sampling parameters, repeat the entire batch
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count)
+            else:
+                # For other tensors, repeat entire batch along batch dimension
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            padded_inputs[key] = padded_value
+        else:
+            # For non-tensor inputs, keep as is
+            padded_inputs[key] = value
+
+    counters = torch.zeros(
+        (original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device)
+    )
+
+    trials_needed = (num_trials + repeat_count - 1) // repeat_count
+    total_samples_per_batch = 0
 
     for _ in range(trials_needed):
        with torch.no_grad():
-            out = runnable(**inputs)
+            out = runnable(**padded_inputs)
 
        output_names = list(defn.outputs.keys())
        output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
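To make the cost savings concrete, here is the trial arithmetic from the hunk above under an assumed original batch size of 4 (the batch size is illustrative; only `target_batch_size` and `num_trials` come from the code):

```python
num_trials = 500_000        # samples wanted per original batch element
target_batch_size = 10_000  # padded batch size used per forward pass
original_batch_size = 4     # assumed example workload

repeat_count = target_batch_size // original_batch_size          # 2500
actual_batch_size = repeat_count * original_batch_size           # 10000
trials_needed = (num_trials + repeat_count - 1) // repeat_count  # 200

# Each forward pass yields repeat_count samples per original element,
# so 200 passes x 2500 samples/pass = 500,000 samples per element,
# versus 125,000 passes for the same budget with the original batch alone.
print(repeat_count, actual_batch_size, trials_needed)
```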
@@ -229,118 +349,19 @@ def _compute_frequency_distribution(
        samples = out_normalized["samples"]
 
        if samples.dim() == 0:
+            # Single sample - assign to first batch element
            sample_idx = samples.item()
-            counter[sample_idx] += 1
-            total_samples_collected += 1
-        else:  # Batch of samples
-            for i in range(samples.numel()):
-                sample_idx = samples.flatten()[i].item()
-                counter[sample_idx] += 1
-                total_samples_collected += 1
-
-    frequency = counter.float() / total_samples_collected
-    return frequency
-
-
-def _check_thresholding(
-    samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any]
-) -> bool:
-    """Check if samples conform to the specified thresholding method.
-
-    Parameters
-    ----------
-    samples : torch.Tensor
-        Sampled token indices.
-    probs : torch.Tensor
-        Probability distribution used for sampling.
-    method : str
-        Thresholding method: "top_k", "top_p", "top_k_top_p", or "none".
-    params : Dict[str, Any]
-        Sampling parameters (top_k, top_p values).
-
-    Returns
-    -------
-    bool
-        True if samples are valid, False otherwise.
-    """
-    batch_size, vocab_size = probs.shape
-    device = probs.device
-
-    for i in range(batch_size):
-        prob_row = probs[i]
-        sample = samples[i].item()
-
-        if method == "top_k":
-            if "top_k" not in params:
-                raise ValueError("top_k parameter is required for top_k thresholding but not found")
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-                if mask_top_k[sample] != 1:
-                    return False
-
-        elif method == "top_p":
-            if "top_p" not in params:
-                raise ValueError("top_p parameter is required for top_p thresholding but not found")
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < p < 1:
-                eps = 1e-4  # numerical stability
-                sorted_probs, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs, dim=0)
-                valid_mask = cdf > (1 - p) - eps
-                valid_indices = indices[valid_mask]
-
-                if sample not in valid_indices:
-                    return False
-
-        elif method == "top_k_top_p":
-            if "top_k" not in params or "top_p" not in params:
-                raise ValueError(
-                    "top_k and top_p parameters are both required for top_k_top_p thresholding but not found"
-                )
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-            else:
-                mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            if 0 < p < 1:
-                eps = 1e-4
-                sorted_probs_asc, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs_asc, dim=0)
-                mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
-                valid_p_mask = cdf > (1 - p) - eps
-                mask_top_p[indices[valid_p_mask]] = 1
-            else:
-                mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            joint_mask = torch.minimum(mask_top_k, mask_top_p)
-
-            if joint_mask[sample] != 1:
-                return False
-
-    return True
+            counters[0, sample_idx] += 1
+            total_samples_per_batch += 1
+        else:
+            # slice and accumulate per original batch element
+            samples_flat = samples.flatten()
+            for i in range(samples_flat.numel()):
+                batch_idx = i % original_batch_size
+                sample_idx = samples_flat[i].item()
+                counters[batch_idx, sample_idx] += 1
+            total_samples_per_batch += repeat_count
+
+    # [batch_size, vocab_size]
+    frequencies = counters.float() / total_samples_per_batch
+    return frequencies
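For readers who want to poke at the new masking logic outside the harness, a self-contained sketch mirroring the top-p branch of `_compute_valid_sampling_mask` (single row, no top-k; `eps` default as in the diff):

```python
import torch

def nucleus_mask(probs: torch.Tensor, p: float, eps: float = 5e-2) -> torch.Tensor:
    """Boolean mask of tokens inside the top-p nucleus, with eps tolerance."""
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=0)
    keep = cumsum <= (p + eps)          # tokens within the (tolerant) nucleus
    if not keep.any():
        keep[0] = True                  # always keep the most likely token
    mask = torch.zeros_like(probs, dtype=torch.bool)
    mask[sorted_indices[keep]] = True   # map back to original vocabulary order
    return mask

probs = torch.tensor([0.05, 0.50, 0.30, 0.15])
print(nucleus_mask(probs, p=0.8))
# tensor([False,  True,  True, False])  -- cumsum 0.50, 0.80 <= 0.85
```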
