Skip to content

Commit 7d371a8

Browse files
ForBetterCodeNineNSDie
authored andcommitted
[Test]Add ut test qwen3_moe and sfa (vllm-project#4121)
### What this PR does / why we need it? Currently, the UT tests lack coverage for the Qwen3_moe network and torchair_sfa. Therefore, supplementary tests are being added. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? by CI - vLLM version: v0.11.0 - vLLM main: vllm-project/vllm@83f478b --------- Signed-off-by: CodeNine-CJ <[email protected]> Signed-off-by: nsdie <[email protected]>
1 parent a7c3fcf commit 7d371a8

File tree

2 files changed

+381
-0
lines changed

2 files changed

+381
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from unittest.mock import Mock
2+
3+
import pytest
4+
from pytest_mock import MockerFixture
5+
from transformers import PretrainedConfig
6+
from vllm.distributed.parallel_state import GroupCoordinator
7+
8+
from tests.ut.base import PytestBase
9+
from vllm_ascend.torchair.models.qwen3_moe import CustomSparseMoeBlock
10+
11+
12+
class TestCustomSparseMoeBlock(PytestBase):
13+
14+
@pytest.fixture
15+
def setup_csmb(self, mocker: MockerFixture):
16+
config = PretrainedConfig(num_experts=64,
17+
hidden_size=2048,
18+
num_experts_per_tok=2,
19+
moe_intermediate_size=1408,
20+
norm_topk_prob=True)
21+
mocker.patch(
22+
'vllm_ascend.torchair.models.qwen3_moe.get_tensor_model_parallel_world_size',
23+
return_value=10)
24+
mocker.patch(
25+
'vllm.model_executor.layers.linear.ReplicatedLinear.__init__',
26+
return_value=None)
27+
mocker.patch(
28+
'vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__init__',
29+
return_value=None)
30+
31+
tp_group = Mock(spec=GroupCoordinator)
32+
tp_group.rank_in_group = 0
33+
tp_group.world_size = 1
34+
tp_group.device_group = Mock()
35+
36+
dp_group = Mock(spec=GroupCoordinator)
37+
dp_group.rank_in_group = 0
38+
dp_group.world_size = 1
39+
40+
ep_group = Mock(spec=GroupCoordinator)
41+
ep_group.rank_in_group = 0
42+
ep_group.world_size = 1
43+
44+
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_tp_group',
45+
return_value=tp_group)
46+
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_dp_group',
47+
return_value=dp_group)
48+
mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_ep_group',
49+
return_value=ep_group)
50+
ascend_config = mocker.MagicMock()
51+
ascend_config.max_num_batched_tokens = 2048
52+
ascend_config.max_model_len = 1024
53+
mocker.patch("vllm_ascend.utils.get_ascend_config",
54+
return_value=ascend_config)
55+
56+
custom_moe_block = CustomSparseMoeBlock(config, None, "")
57+
return custom_moe_block
58+
59+
def test_init(self, mocker: MockerFixture, setup_csmb):
60+
custom_moe_block = setup_csmb
61+
assert isinstance(custom_moe_block, CustomSparseMoeBlock)
Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import torch
4+
5+
from tests.ut.base import TestBase
6+
from vllm_ascend.attention.attention_v1 import AscendAttentionState
7+
from vllm_ascend.torchair.torchair_sfa import (
8+
AscendSFATorchairBackend, AscendSFATorchairDecodeMetadata,
9+
AscendSFATorchairImpl, AscendSFATorchairMetadata,
10+
AscendSFATorchairMetadataBuilder, AscendSFATorchairPrefillMetadata)
11+
12+
13+
class TestAscendSFATorchairBackend(TestBase):
14+
15+
def test_get_name(self):
16+
self.assertEqual(AscendSFATorchairBackend.get_name(),
17+
"ASCEND_SFA_TORCHAIR")
18+
19+
def test_get_metadata_cls(self):
20+
self.assertEqual(AscendSFATorchairBackend.get_metadata_cls(),
21+
AscendSFATorchairMetadata)
22+
23+
def test_get_builder_cls(self):
24+
self.assertEqual(AscendSFATorchairBackend.get_builder_cls(),
25+
AscendSFATorchairMetadataBuilder)
26+
27+
def test_get_kv_cache_shape(self):
28+
result = AscendSFATorchairBackend.get_kv_cache_shape(2, 4, 8, 128)
29+
self.assertEqual(result, (2, 4, 8, 128))
30+
31+
def test_get_impl_cls(self):
32+
result = AscendSFATorchairBackend.get_impl_cls()
33+
self.assertEqual(result, AscendSFATorchairImpl)
34+
35+
36+
class TestAscendSFATorchairPrefillMetadata(TestBase):
37+
38+
def test_ascend_sfa_prefill_metadata_default(self):
39+
attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
40+
query_lens = [1, 2]
41+
seq_lens = [2, 2]
42+
context_lens = torch.tensor([1, 2])
43+
input_positions = torch.tensor([0, 1, 0, 1])
44+
query_start_loc = torch.tensor([0, 1, 3])
45+
block_table = torch.tensor([[0, 1], [2, 3]])
46+
max_query_len = 2
47+
max_seq_lens = 2
48+
49+
metadata = AscendSFATorchairPrefillMetadata(
50+
attn_mask=attn_mask,
51+
query_lens=query_lens,
52+
seq_lens=seq_lens,
53+
context_lens=context_lens,
54+
input_positions=input_positions,
55+
query_start_loc=query_start_loc,
56+
block_table=block_table,
57+
max_query_len=max_query_len,
58+
sin=None,
59+
cos=None,
60+
max_seq_lens=max_seq_lens)
61+
self.assertIs(metadata.attn_mask, attn_mask)
62+
self.assertEqual(metadata.query_lens, query_lens)
63+
self.assertEqual(metadata.seq_lens, seq_lens)
64+
self.assertIs(metadata.context_lens, context_lens)
65+
self.assertIs(metadata.input_positions, input_positions)
66+
self.assertIs(metadata.query_start_loc, query_start_loc)
67+
self.assertIs(metadata.block_table, block_table)
68+
self.assertEqual(metadata.max_query_len, max_query_len)
69+
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
70+
self.assertIsNone(metadata.chunked_context)
71+
72+
def test_ascend_sfa_prefill_metadata_with_chunked_context(self):
73+
cu_seq_lens = torch.tensor([0, 2, 4])
74+
starts = torch.tensor([0, 2])
75+
seq_tot = [2, 2]
76+
max_seq_lens = [2, 2]
77+
workspace = torch.randn(2, 4)
78+
chunk_seq_lens = torch.tensor([2, 2])
79+
80+
chunked_context = AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata(
81+
cu_seq_lens=cu_seq_lens,
82+
starts=starts,
83+
seq_tot=seq_tot,
84+
max_seq_lens=max_seq_lens,
85+
workspace=workspace,
86+
chunk_seq_lens=chunk_seq_lens)
87+
88+
metadata = AscendSFATorchairPrefillMetadata(
89+
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
90+
query_lens=[1, 2],
91+
seq_lens=[2, 2],
92+
context_lens=torch.tensor([1, 2]),
93+
input_positions=torch.tensor([0, 1, 0, 1]),
94+
query_start_loc=torch.tensor([0, 1, 3]),
95+
block_table=torch.tensor([[0, 1], [2, 3]]),
96+
max_query_len=2,
97+
max_seq_lens=2,
98+
sin=None,
99+
cos=None,
100+
chunked_context=chunked_context)
101+
102+
self.assertIsNotNone(metadata.chunked_context)
103+
self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
104+
self.assertIs(metadata.chunked_context.starts, starts)
105+
self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
106+
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
107+
self.assertIs(metadata.chunked_context.workspace, workspace)
108+
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
109+
110+
111+
class TestAscendSFATorchairDecodeMetadata(TestBase):
112+
113+
def test_ascend_sfa_decode_metadata_default(self):
114+
input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
115+
block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
116+
seq_lens = torch.tensor([[2], [3]])
117+
max_seq_lens = 4
118+
seq_lens_list = [2, 3]
119+
attn_mask = None
120+
121+
metadata = AscendSFATorchairDecodeMetadata(input_positions,
122+
block_table, seq_lens,
123+
max_seq_lens, seq_lens_list,
124+
None, None, attn_mask)
125+
126+
self.assertIs(metadata.input_positions, input_positions)
127+
self.assertIs(metadata.block_table, block_table)
128+
self.assertIs(metadata.seq_lens, seq_lens)
129+
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
130+
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
131+
self.assertIsNone(attn_mask)
132+
133+
134+
class TestAscendSFATorchairMetadata(TestBase):
135+
136+
def test_ascend_sfa_metadata_default(self):
137+
num_actual_tokens = 100
138+
slot_mapping = torch.randn(100, 4, 1024)
139+
query_start_loc = torch.tensor([1, 2, 3, 4])
140+
seq_lens = [30, 50]
141+
block_tables = torch.randint(0, 100, (100, 4))
142+
143+
num_decodes = 4
144+
num_decode_tokens = 8
145+
num_prefills = 8
146+
147+
num_input_tokens = 2
148+
149+
query_lens = None
150+
head_dim = None
151+
attn_mask = None
152+
attn_state = AscendAttentionState.ChunkedPrefill
153+
154+
decode = None
155+
prefill = None
156+
157+
metadata = AscendSFATorchairMetadata(
158+
num_actual_tokens, slot_mapping, query_start_loc, seq_lens,
159+
block_tables, num_decodes, num_decode_tokens, num_prefills,
160+
num_input_tokens, query_lens, head_dim, attn_mask, attn_state,
161+
decode, prefill)
162+
163+
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
164+
self.assertIs(metadata.slot_mapping, slot_mapping)
165+
self.assertIs(metadata.query_start_loc, query_start_loc)
166+
self.assertEqual(metadata.seq_lens, seq_lens)
167+
self.assertIs(metadata.block_tables, block_tables)
168+
self.assertEqual(metadata.num_decodes, num_decodes)
169+
self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
170+
self.assertEqual(metadata.num_prefills, num_prefills)
171+
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
172+
self.assertEqual(metadata.query_lens, query_lens)
173+
self.assertEqual(metadata.head_dim, head_dim)
174+
self.assertEqual(metadata.attn_mask, attn_mask)
175+
self.assertEqual(metadata.attn_state, attn_state)
176+
self.assertEqual(metadata.decode, decode)
177+
self.assertEqual(metadata.prefill, prefill)
178+
179+
180+
class TestAscendSFATorchairMetadataBuilder(TestBase):
181+
182+
def test_ascend_sfa_metadata_builder_default(self):
183+
mock_vllm_config = MagicMock()
184+
mock_vllm_config.model_config.max_model_len = 1024
185+
mock_vllm_config.model_config.get_head_size.return_value = 64
186+
mock_vllm_config.model_config.dtype = torch.float16
187+
mock_vllm_config.cache_config.block_size = 16
188+
mock_vllm_config.scheduler_config.max_num_seqs = 4
189+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
190+
mock_device = 'cpu'
191+
192+
mock_vllm_config.speculative_config = None
193+
194+
ascend_config = MagicMock()
195+
ascend_config.torchair_graph_config = MagicMock()
196+
ascend_config.torchair_graph_config.enabled = True
197+
with patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config",
198+
return_value=ascend_config):
199+
builder = AscendSFATorchairMetadataBuilder(None, None,
200+
mock_vllm_config,
201+
mock_device)
202+
203+
self.assertEqual(builder.block_size,
204+
mock_vllm_config.cache_config.block_size)
205+
self.assertEqual(
206+
builder.chunked_prefill_enabled,
207+
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
208+
self.assertEqual(builder.torchair_graph_enabled, True)
209+
self.assertEqual(builder.max_blocks, (mock_vllm_config.model_config.max_model_len +
210+
mock_vllm_config.cache_config.block_size - 1) \
211+
// mock_vllm_config.cache_config.block_size)
212+
213+
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
214+
def test_reorder_batch_with_torchair_graph(self, ascend_config):
215+
mock_vllm_config = MagicMock()
216+
mock_vllm_config.model_config.max_model_len = 1024
217+
mock_vllm_config.cache_config.block_size = 16
218+
mock_vllm_config.scheduler_config.max_num_seqs = 4
219+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
220+
mock_device = 'cpu'
221+
ascend_config.torchair_graph_config = MagicMock()
222+
ascend_config.torchair_graph_config.enabled = True
223+
224+
mock_vllm_config.speculative_config = None
225+
226+
builder = AscendSFATorchairMetadataBuilder(None, None,
227+
mock_vllm_config,
228+
mock_device)
229+
230+
input_batch = MagicMock()
231+
input_batch.req_ids = [0, 1, 2, 3]
232+
233+
scheduler_output = MagicMock()
234+
scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
235+
scheduler_output.scheduled_spec_decode_tokens = {
236+
0: [1],
237+
1: [],
238+
2: [1, 1],
239+
3: []
240+
}
241+
242+
input_batch.swap_states = MagicMock()
243+
244+
modified = builder.reorder_batch(input_batch, scheduler_output)
245+
246+
self.assertFalse(modified)
247+
input_batch.swap_states.assert_not_called()
248+
249+
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
250+
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
251+
ascend_config = MagicMock()
252+
mock_ascend_config.return_value = ascend_config
253+
ascend_config.torchair_graph_config.enabled = False
254+
mock_vllm_config = MagicMock()
255+
mock_vllm_config.model_config.max_model_len = 1024
256+
mock_vllm_config.cache_config.block_size = 16
257+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
258+
mock_device = 'cpu'
259+
260+
mock_vllm_config.speculative_config = None
261+
262+
builder = AscendSFATorchairMetadataBuilder(None, None,
263+
mock_vllm_config,
264+
mock_device)
265+
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
266+
267+
result = builder._get_graph_runner_block_tables(3, block_tables)
268+
self.assertEqual(result.shape[0], 3)
269+
self.assertEqual(result.shape[1], 64)
270+
self.assertTrue(torch.equal(result[:, :10], block_tables))
271+
272+
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
273+
def test_ge_graph_runner_block_tables_truncated(self, mock_ascend_config):
274+
ascend_config = MagicMock()
275+
mock_ascend_config.return_value = ascend_config
276+
ascend_config.torchair_graph_config.enabled = False
277+
mock_vllm_config = MagicMock()
278+
mock_vllm_config.model_config.max_model_len = 64
279+
mock_vllm_config.cache_config.block_size = 16
280+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
281+
mock_device = 'cpu'
282+
283+
mock_vllm_config.speculative_config = None
284+
285+
builder = AscendSFATorchairMetadataBuilder(None, None,
286+
mock_vllm_config,
287+
mock_device)
288+
289+
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
290+
291+
result = builder._get_graph_runner_block_tables(3, block_tables)
292+
self.assertEqual(result.shape[0], 3)
293+
self.assertEqual(result.shape[1], 4)
294+
self.assertTrue(torch.equal(result, block_tables[:, :4]))
295+
296+
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config")
297+
def test_get_graph_runner_block_tables_from_numpy(self,
298+
mock_ascend_config):
299+
ascend_config = MagicMock()
300+
mock_ascend_config.return_value = ascend_config
301+
ascend_config.torchair_graph_config.enabled = False
302+
mock_vllm_config = MagicMock()
303+
mock_vllm_config.model_config.max_model_len = 1024
304+
mock_vllm_config.cache_config.block_size = 16
305+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
306+
mock_device = 'cpu'
307+
308+
mock_vllm_config.speculative_config = None
309+
310+
builder = AscendSFATorchairMetadataBuilder(None, None,
311+
mock_vllm_config,
312+
mock_device)
313+
314+
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
315+
316+
result = builder._get_graph_runner_block_tables(3, block_tables)
317+
318+
self.assertEqual(result.shape[0], 3)
319+
self.assertEqual(result.shape[1], 64)
320+
self.assertTrue(torch.equal(result[:, :10], block_tables))

0 commit comments

Comments
 (0)