1010from concurrent .futures import ThreadPoolExecutor
1111from dataclasses import dataclass
1212from enum import Enum
13- from typing import Any , Callable , Optional , Tuple
13+ from typing import Any , Callable , Optional , Tuple , List , Dict
1414
1515import llm_datadist # type: ignore
1616import msgspec
@@ -70,6 +70,11 @@ class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
7070 super_device_id : str
7171 cluster_id : int
7272
73+ class LLMDataDistCMgrAgentMetaDataA5 (msgspec .Struct ):
74+ server_id : str
75+ device_id : str
76+ cluster_id : int
77+ level_list : List [Dict ]
7378
7479@dataclass
7580class ReqMeta :
@@ -351,8 +356,10 @@ def __init__(self, vllm_config: VllmConfig):
351356 self .dcp_size = get_dcp_group ().world_size
352357 self .local_ip = get_ip ()
353358 self .kv_transfer_config : KVTransferConfig = vllm_config .kv_transfer_config
354- self .local_agent_metadata : Optional [
355- LLMDataDistCMgrAgentMetadata ] = None
359+ if is_A5 ():
360+ self .local_agent_metadata : Optional [LLMDataDistCMgrAgentMetaDataA5 ] = None
361+ else :
362+ self .local_agent_metadata : Optional [LLMDataDistCMgrAgentMetadata ] = None
356363 self .vllm_config = vllm_config
357364 self .executor = ThreadPoolExecutor (1 )
358365 self .thread_lock = threading .Lock ()
@@ -409,12 +416,18 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
409416 event_msg = LLMDataDistCMgrEvent (event_msg )
410417 if event_msg == LLMDataDistCMgrEvent .ReqForMetadata :
411418 if "cluster_id" in decode_msg :
412- decode_msg = LLMDataDistCMgrAgentMetadata (** decode_msg )
419+ if is_A5 ():
420+ decode_msg = LLMDataDistCMgrAgentMetadataA5 (** decode_msg )
421+ else :
422+ decode_msg = LLMDataDistCMgrAgentMetadata (** decode_msg )
413423 logger .info (
414424 f"LLMDataDistCMgrConnectorWorker: Receive message from cluster { decode_msg .cluster_id } "
415425 )
416426 sock .send_multipart ((identity , b"" , msg_to_send ))
417- self .add_remote_agent (decode_msg )
427+ if is_A5 ():
428+ self .add_remote_agentA5 (decode_msg )
429+ else :
430+ self .add_remote_agent (decode_msg )
418431 else :
419432 logger .warning (
420433 f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data { decode_msg } "
@@ -483,6 +496,17 @@ def _get_visible_devices() -> Callable[[str], bool]:
483496 visible_device_list = visible_devices .split ("," )
484497 return lambda device_id : device_id in visible_device_list
485498
499+ def get_device_info (self , global_rank_table , device_filter , device_type ):
500+ device_list = global_rank_table [device_type ]
501+ device_list = [
502+ d for d in device_list if d .get ("server_id" ) == self .local_ip
503+ and device_filter (d .get ("device_id" , "" ))
504+ ]
505+ if len (device_list ) <= self .pcp_rank * self .tp_size + self .tp_rank :
506+ retunr None
507+ device_info = device_list [self .pcp_rank * self .tp_size + self .tp_rank ]
508+ return device_info
509+
486510 def read_agent_metadata (self , global_rank_table ):
487511 device_filter = LLMDataDistCMgrConnectorWorker ._get_visible_devices ()
488512 devices_type_list = []
@@ -494,30 +518,40 @@ def read_agent_metadata(self, global_rank_table):
494518 else :
495519 devices_type_list .append ("prefill_device_list" )
496520 devices_type_list .append ("decode_device_list" )
497- for device_type in devices_type_list :
498- device_list = global_rank_table [device_type ]
499- device_list = [
500- d for d in device_list if d .get ("server_id" ) == self .local_ip
501- and device_filter (d .get ("device_id" , "" ))
502- ]
503- if len (device_list ) <= self .pcp_rank * self .tp_size + self .tp_rank :
504- continue
505- device_info = device_list [self .pcp_rank * self .tp_size +
506- self .tp_rank ]
507- super_pod_id_ = device_info .get ("super_pod_id" , None )
508- server_id_ = device_info ["server_id" ]
509- device_id_ = device_info ["device_id" ]
510- device_ip_ = device_info ["device_ip" ]
511- super_device_id_ = device_info .get ("super_device_id" , None )
512- cluster_id_ = int (device_info ["cluster_id" ])
513- agent_metadata = LLMDataDistCMgrAgentMetadata (
514- super_pod_id = super_pod_id_ ,
515- server_id = server_id_ ,
516- device_id = device_id_ ,
517- device_ip = device_ip_ ,
518- super_device_id = super_device_id_ ,
519- cluster_id = cluster_id_ ,
520- )
521+
522+ if is_A5 ():
523+ for device_type in devices_type_list :
524+ device_info = self .get_device_info (global_rank_table , device_filter , device_type )
525+ if not device_info :
526+ continue
527+ server_id_ = device_info ["server_id" ]
528+ device_id_ = device_info ["device_id" ]
529+ cluster_id_ = int (device_info ["cluster_id" ])
530+ level_list_ = device_info ["level_list" ]
531+ agent_metadata = LLMDataDistCMgrAgentMetadataA5 (
532+ server_id = server_id_ ,
533+ device_id = device_id_ ,
534+ device_ip = device_ip_ ,
535+ cluster_id = cluster_id_ ,
536+ level_list = level_list_ ,
537+ )
538+ else :
539+ for device_type in devices_type_list :
540+ device_info = self .get_device_info (global_rank_table , device_filter , device_type )
541+ super_pod_id_ = device_info .get ("super_pod_id" , None )
542+ server_id_ = device_info ["server_id" ]
543+ device_id_ = device_info ["device_id" ]
544+ device_ip_ = device_info ["device_ip" ]
545+ super_device_id_ = device_info .get ("super_device_id" , None )
546+ cluster_id_ = int (device_info ["cluster_id" ])
547+ agent_metadata = LLMDataDistCMgrAgentMetadata (
548+ super_pod_id = super_pod_id_ ,
549+ server_id = server_id_ ,
550+ device_id = device_id_ ,
551+ device_ip = device_ip_ ,
552+ super_device_id = super_device_id_ ,
553+ cluster_id = cluster_id_ ,
554+ )
521555 assert agent_metadata is not None , f"Can't read the target server_id { self .local_ip } and device_rank { self .rank } from rank table"
522556 return agent_metadata
523557
@@ -817,6 +851,78 @@ def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
817851 )
818852 return remote_cluster_id
819853
854+ def create_ranktable (self , cluster_rank_info , prefill_metadata , decode_metadata ):
855+ rank_list :List [Dict ] = []
856+ for cluster_id , rank_idx in cluster_rank_info .items ():
857+ #确定当前处理的metaData
858+ if cluster_id == prefill_metadata .cluster_id :
859+ current_metadata = prefill_metadata
860+ else :
861+ current_metadata = decode_metadata
862+ #创建rank_info
863+ rank_info = {
864+ "rank_id" :str (rank_idx ),
865+ "device_id" :current_metadata .device_id ,
866+ "cluster_id" : current_metadata .cluster_id ,
867+ "level_list" :current_metadata .level_list
868+ }
869+ rank_list .append (rank_info )
870+ rank_table = {
871+ "version" : "2.0" ,
872+ "rank_count" :"2" ,
873+ "status" :"completed" ,
874+ "rank_list" :rank_list
875+ }
876+ return rank_table
877+
878+ def add_remote_agentA5 (self , metadata : LLMDataDistCMgrAgentMetadataA5 ) -> int :
879+ assert self .local_agent_metadata is not None
880+ remote_cluster_id = metadata .cluster_id
881+ if remote_cluster_id in self .linked_cluster :
882+ logger .debug (
883+ f"LLMDataDistCMgrConnectorWorker: remote cluster_id: { metadata .cluster_id } already linked with this server, skip the connection"
884+ )
885+ return remote_cluster_id
886+ if self .llm_datadist_role == LLMRole .PROMPT :
887+ prefill_metadata = self .local_agent_metadata
888+ decode_metadata = metadata
889+ else :
890+ prefill_metadata = metadata
891+ decode_metadata = self .local_agent_metadata
892+ comm_name = f"pd_comm_{ prefill_metadata .device_id } _{ decode_metadata .device_id } "
893+ cluster_rank_info = {
894+ prefill_metadata .cluster_id : 0 ,
895+ decode_metadata .cluster_id : 1
896+ }
897+ rank_table = self .create_rank_table (cluster_rank_info , prefill_metadata , decode_metadata )
898+ logger .info (
899+ f"LLMDataDistCMgrConnectorWorker: try link with remote, comm id: { comm_name } "
900+ )
901+ logger .info (f"rank table \n { rank_table } " )
902+ logger .info (f"comm name: { comm_name } " )
903+ logger .info (f"cluster rank info: { cluster_rank_info } " )
904+ comm_id = self .llm_datadist .link (comm_name , cluster_rank_info ,
905+ json .dumps (rank_table ))
906+ while True :
907+ ret = self .llm_datadist .query_register_mem_status (comm_id = comm_id )
908+ if ret == llm_datadist .RegisterMemStatus .OK :
909+ logger .info (
910+ f"LLMDataDistCMgrConnectorWorker: Linking success, comm id: { comm_id } "
911+ )
912+ break
913+ elif ret == llm_datadist .RegisterMemStatus .FAILED :
914+ raise RuntimeError (
915+ f"LLMDataDistCMgrConnectorWorker: Linking failed, comm id: { comm_id } "
916+ )
917+ time .sleep (1 )
918+ logger .info ("Checking query_register_mem_status again" )
919+ self .linked_cluster .update ({remote_cluster_id : comm_id })
920+ logger .info (f"cached linked cluster: { self .linked_cluster } " )
921+ logger .info (
922+ f"Successfully build link with cluster id { remote_cluster_id } with cluster name { comm_name } !"
923+ )
924+ return remote_cluster_id
925+
820926 def remove_remote_agent (self , cluster_id : int ):
821927 if cluster_id not in self .linked_cluster :
822928 logger .warning (
@@ -846,9 +952,13 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
846952 metadata_bytes = sock .recv ()
847953 decoder = msgspec .msgpack .Decoder ()
848954 metadata = decoder .decode (metadata_bytes )
849- metadata = LLMDataDistCMgrAgentMetadata (** metadata )
955+ if is_A5 ():
956+ metadata = LLMDataDistCMgrAgentMetadataA5 (** metadata )
957+ cluster_id = self .add_remote_agentA5 (metadata )
958+ else :
959+ metadata = LLMDataDistCMgrAgentMetadata (** metadata )
960+ cluster_id = self .add_remote_agent (metadata )
850961 logger .info (f"recving metadata: { metadata } " )
851- cluster_id = self .add_remote_agent (metadata )
852962 return cluster_id
853963
854964 def send_finish_to_remote (self , host : str , ports : list [int ], request_id ):
0 commit comments