Skip to content

Commit a7da0c5

Browse files
committed
ranktable 表适配
1 parent 5b17ebe commit a7da0c5

File tree

1 file changed

+141
-31
lines changed

1 file changed

+141
-31
lines changed

vllm_ascend/distributed/llmdatadist_c_mgr_connector.py

Lines changed: 141 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from concurrent.futures import ThreadPoolExecutor
1111
from dataclasses import dataclass
1212
from enum import Enum
13-
from typing import Any, Callable, Optional, Tuple
13+
from typing import Any, Callable, Optional, Tuple, List, Dict
1414

1515
import llm_datadist # type: ignore
1616
import msgspec
@@ -70,6 +70,11 @@ class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
7070
super_device_id: str
7171
cluster_id: int
7272

73+
class LLMDataDistCMgrAgentMetaDataA5(msgspec.Struct):
74+
server_id: str
75+
device_id: str
76+
cluster_id: int
77+
level_list: List[Dict]
7378

7479
@dataclass
7580
class ReqMeta:
@@ -351,8 +356,10 @@ def __init__(self, vllm_config: VllmConfig):
351356
self.dcp_size = get_dcp_group().world_size
352357
self.local_ip = get_ip()
353358
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
354-
self.local_agent_metadata: Optional[
355-
LLMDataDistCMgrAgentMetadata] = None
359+
if is_A5():
360+
self.local_agent_metadata: Optional[LLMDataDistCMgrAgentMetaDataA5] = None
361+
else:
362+
self.local_agent_metadata: Optional[LLMDataDistCMgrAgentMetadata] = None
356363
self.vllm_config = vllm_config
357364
self.executor = ThreadPoolExecutor(1)
358365
self.thread_lock = threading.Lock()
@@ -409,12 +416,18 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
409416
event_msg = LLMDataDistCMgrEvent(event_msg)
410417
if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
411418
if "cluster_id" in decode_msg:
412-
decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
419+
if is_A5():
420+
decode_msg = LLMDataDistCMgrAgentMetadataA5(**decode_msg)
421+
else :
422+
decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
413423
logger.info(
414424
f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
415425
)
416426
sock.send_multipart((identity, b"", msg_to_send))
417-
self.add_remote_agent(decode_msg)
427+
if is_A5():
428+
self.add_remote_agentA5(decode_msg)
429+
else :
430+
self.add_remote_agent(decode_msg)
418431
else:
419432
logger.warning(
420433
f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
@@ -483,6 +496,17 @@ def _get_visible_devices() -> Callable[[str], bool]:
483496
visible_device_list = visible_devices.split(",")
484497
return lambda device_id: device_id in visible_device_list
485498

499+
def get_device_info(self, global_rank_table, device_filter, device_type):
500+
device_list = global_rank_table[device_type]
501+
device_list = [
502+
d for d in device_list if d.get("server_id") == self.local_ip
503+
and device_filter(d.get("device_id", ""))
504+
]
505+
if len(device_list) <= self.pcp_rank * self.tp_size + self.tp_rank:
506+
retunr None
507+
device_info = device_list[self.pcp_rank * self.tp_size + self.tp_rank]
508+
return device_info
509+
486510
def read_agent_metadata(self, global_rank_table):
487511
device_filter = LLMDataDistCMgrConnectorWorker._get_visible_devices()
488512
devices_type_list = []
@@ -494,30 +518,40 @@ def read_agent_metadata(self, global_rank_table):
494518
else:
495519
devices_type_list.append("prefill_device_list")
496520
devices_type_list.append("decode_device_list")
497-
for device_type in devices_type_list:
498-
device_list = global_rank_table[device_type]
499-
device_list = [
500-
d for d in device_list if d.get("server_id") == self.local_ip
501-
and device_filter(d.get("device_id", ""))
502-
]
503-
if len(device_list) <= self.pcp_rank * self.tp_size + self.tp_rank:
504-
continue
505-
device_info = device_list[self.pcp_rank * self.tp_size +
506-
self.tp_rank]
507-
super_pod_id_ = device_info.get("super_pod_id", None)
508-
server_id_ = device_info["server_id"]
509-
device_id_ = device_info["device_id"]
510-
device_ip_ = device_info["device_ip"]
511-
super_device_id_ = device_info.get("super_device_id", None)
512-
cluster_id_ = int(device_info["cluster_id"])
513-
agent_metadata = LLMDataDistCMgrAgentMetadata(
514-
super_pod_id=super_pod_id_,
515-
server_id=server_id_,
516-
device_id=device_id_,
517-
device_ip=device_ip_,
518-
super_device_id=super_device_id_,
519-
cluster_id=cluster_id_,
520-
)
521+
522+
if is_A5():
523+
for device_type in devices_type_list:
524+
device_info = self.get_device_info(global_rank_table, device_filter, device_type)
525+
if not device_info:
526+
continue
527+
server_id_ = device_info["server_id"]
528+
device_id_ = device_info["device_id"]
529+
cluster_id_ = int(device_info["cluster_id"])
530+
level_list_ = device_info["level_list"]
531+
agent_metadata = LLMDataDistCMgrAgentMetadataA5(
532+
server_id=server_id_,
533+
device_id=device_id_,
534+
device_ip=device_ip_,
535+
cluster_id=cluster_id_,
536+
level_list = level_list_,
537+
)
538+
else :
539+
for device_type in devices_type_list:
540+
device_info = self.get_device_info(global_rank_table, device_filter, device_type)
541+
super_pod_id_ = device_info.get("super_pod_id", None)
542+
server_id_ = device_info["server_id"]
543+
device_id_ = device_info["device_id"]
544+
device_ip_ = device_info["device_ip"]
545+
super_device_id_ = device_info.get("super_device_id", None)
546+
cluster_id_ = int(device_info["cluster_id"])
547+
agent_metadata = LLMDataDistCMgrAgentMetadata(
548+
super_pod_id=super_pod_id_,
549+
server_id=server_id_,
550+
device_id=device_id_,
551+
device_ip=device_ip_,
552+
super_device_id=super_device_id_,
553+
cluster_id=cluster_id_,
554+
)
521555
assert agent_metadata is not None, f"Can't read the target server_id {self.local_ip} and device_rank {self.rank} from rank table"
522556
return agent_metadata
523557

@@ -817,6 +851,78 @@ def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
817851
)
818852
return remote_cluster_id
819853

854+
def create_ranktable(self, cluster_rank_info, prefill_metadata, decode_metadata):
855+
rank_list:List[Dict] =[]
856+
for cluster_id, rank_idx in cluster_rank_info.items():
857+
#确定当前处理的metaData
858+
if cluster_id == prefill_metadata.cluster_id:
859+
current_metadata = prefill_metadata
860+
else:
861+
current_metadata = decode_metadata
862+
#创建rank_info
863+
rank_info = {
864+
"rank_id":str(rank_idx),
865+
"device_id":current_metadata.device_id,
866+
"cluster_id": current_metadata.cluster_id,
867+
"level_list":current_metadata.level_list
868+
}
869+
rank_list.append(rank_info)
870+
rank_table = {
871+
"version": "2.0",
872+
"rank_count":"2",
873+
"status":"completed",
874+
"rank_list":rank_list
875+
}
876+
return rank_table
877+
878+
def add_remote_agentA5(self, metadata: LLMDataDistCMgrAgentMetadataA5) -> int:
879+
assert self.local_agent_metadata is not None
880+
remote_cluster_id = metadata.cluster_id
881+
if remote_cluster_id in self.linked_cluster:
882+
logger.debug(
883+
f"LLMDataDistCMgrConnectorWorker: remote cluster_id: {metadata.cluster_id} already linked with this server, skip the connection"
884+
)
885+
return remote_cluster_id
886+
if self.llm_datadist_role == LLMRole.PROMPT:
887+
prefill_metadata = self.local_agent_metadata
888+
decode_metadata = metadata
889+
else:
890+
prefill_metadata = metadata
891+
decode_metadata = self.local_agent_metadata
892+
comm_name = f"pd_comm_{prefill_metadata.device_id}_{decode_metadata.device_id}"
893+
cluster_rank_info = {
894+
prefill_metadata.cluster_id: 0,
895+
decode_metadata.cluster_id: 1
896+
}
897+
rank_table = self.create_rank_table(cluster_rank_info, prefill_metadata, decode_metadata)
898+
logger.info(
899+
f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: {comm_name}"
900+
)
901+
logger.info(f"rank table \n{rank_table}")
902+
logger.info(f"comm name: {comm_name}")
903+
logger.info(f"cluster rank info: {cluster_rank_info}")
904+
comm_id = self.llm_datadist.link(comm_name, cluster_rank_info,
905+
json.dumps(rank_table))
906+
while True:
907+
ret = self.llm_datadist.query_register_mem_status(comm_id=comm_id)
908+
if ret == llm_datadist.RegisterMemStatus.OK:
909+
logger.info(
910+
f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: {comm_id}"
911+
)
912+
break
913+
elif ret == llm_datadist.RegisterMemStatus.FAILED:
914+
raise RuntimeError(
915+
f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: {comm_id}"
916+
)
917+
time.sleep(1)
918+
logger.info("Checking query_register_mem_status again")
919+
self.linked_cluster.update({remote_cluster_id: comm_id})
920+
logger.info(f"cached linked cluster: {self.linked_cluster}")
921+
logger.info(
922+
f"Successfully build link with cluster id {remote_cluster_id} with cluster name {comm_name} !"
923+
)
924+
return remote_cluster_id
925+
820926
def remove_remote_agent(self, cluster_id: int):
821927
if cluster_id not in self.linked_cluster:
822928
logger.warning(
@@ -846,9 +952,13 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
846952
metadata_bytes = sock.recv()
847953
decoder = msgspec.msgpack.Decoder()
848954
metadata = decoder.decode(metadata_bytes)
849-
metadata = LLMDataDistCMgrAgentMetadata(**metadata)
955+
if is_A5():
956+
metadata = LLMDataDistCMgrAgentMetadataA5(**metadata)
957+
cluster_id = self.add_remote_agentA5(metadata)
958+
else :
959+
metadata = LLMDataDistCMgrAgentMetadata(**metadata)
960+
cluster_id = self.add_remote_agent(metadata)
850961
logger.info(f"recving metadata: {metadata}")
851-
cluster_id = self.add_remote_agent(metadata)
852962
return cluster_id
853963

854964
def send_finish_to_remote(self, host: str, ports: list[int], request_id):

0 commit comments

Comments
 (0)