Conversation

@kip-cxj
Contributor

@kip-cxj kip-cxj commented Dec 16, 2025

Motivation

Add a stateless communication group. This enables more flexible creation of communication groups and resolves compatibility issues with other programs that also use torch.distributed.
vLLM is currently supported; sglang does not yet support pyhccl, so this feature depends on adding pyhccl to sglang.
If the current approach is acceptable, we will provide an sglang version soon.

@kip-cxj kip-cxj force-pushed the main branch 2 times, most recently from 1b27b3f to f989a80 Compare December 17, 2025 07:12
@kip-cxj kip-cxj changed the title draft: add collective communication for npu draft: add stateless communication for npu Dec 30, 2025
@x1314aq

x1314aq commented Jan 7, 2026

@weixiao-huang @HubertZhang please review this PR.

Tested on both NPU and CUDA:

| Model | Device | Info | device_type | GatherMetas | Update (Broadcast) | Update (P2P) |
|---|---|---|---|---|---|---|
| Qwen3-8b | 8xNvidia-A100 | TP4 | cuda | 0.01s | 1.28s (1.46GiB) | 7.81s (1.72GiB) |
| Qwen3-8b | 8xAscend-A3 | TP4 | npu | 0.02s | 1.37s (1.59GiB) | 2.02s (1.47GiB) |

Tested the same model using the default torch.distributed module:

| Model | Device | Info | device_type | GatherMetas | Update (Broadcast) | Update (P2P) |
|---|---|---|---|---|---|---|
| Qwen3-8b | 8xNvidia-A100 | TP4 | torch | 0.01s | 1.15s (1.46GiB) | 7.68s (1.71GiB) |
| Qwen3-8b | 8xAscend-A3 | TP4 | torch | 0.03s | 1.44s (1.59GiB) | 3.83s (1.46GiB) |

@kip-cxj kip-cxj changed the title draft: add stateless communication for npu feat: Replace torch.distributed with StatelessProcessGroup Jan 8, 2026
@kip-cxj kip-cxj changed the title feat: Replace torch.distributed with StatelessProcessGroup feat: add StatelessProcessGroup to extend collective library Jan 8, 2026
@weixiao-huang
Collaborator

It seems this PR depends on vLLM, which is heavy and not an elegant approach. I think ps.py should be a lightweight component that does not depend on other heavy frameworks.

@hanhan-networking

It seems this PR depends on vLLM, which is heavy and not an elegant approach. I think ps.py should be a lightweight component that does not depend on other heavy frameworks.

The default communication method is still torch.distributed; StatelessProcessGroup is only needed when communicating across resource pools. If this isn't supported, it can't be merged into verl 😆 — the disaggregated training/inference architecture wouldn't work.

@HubertZhang
Collaborator

Would it be better to design a protocol DistributedLib and pass a dist: DistributedLib into ps? The current import-based approach doesn't feel sufficiently isolated.

@x1314aq

x1314aq commented Jan 12, 2026

Would it be better to design a protocol DistributedLib and pass a dist: DistributedLib into ps? The current import-based approach doesn't feel sufficiently isolated.

I added a dist_wrapper.py file and moved the dist-handling logic into it, so ps.py and update.py can use the dist module directly via from dist_wrapper import dist without caring about the concrete implementation.

@x1314aq

x1314aq commented Jan 12, 2026

It seems this PR depends on vLLM, which is heavy and not an elegant approach. I think ps.py should be a lightweight component that does not depend on other heavy frameworks.

In most cases, the logic remains consistent with before. vLLM is only needed when the custom distributed module is required. It does not change the fact that checkpoint-engine is a lightweight component.

@x1314aq

x1314aq commented Jan 13, 2026

Would it be better to design a protocol DistributedLib and pass a dist: DistributedLib into ps? The current import-based approach doesn't feel sufficiently isolated.

I added a dist_wrapper.py file and moved the dist-handling logic into it, so ps.py and update.py can use the dist module directly via from dist_wrapper import dist without caring about the concrete implementation.

I tried the dist_wrapper.py wrapper locally and ran into problems, so it is now changed to wrap torch.distributed inside distributed/base.py. Usage-wise, you only need to replace import torch.distributed as dist with import checkpoint_engine.distributed as dist:

# import torch.distributed as dist
import checkpoint_engine.distributed as dist

dist.init_process_group()
dist.all_reduce()
dist.xxxx()

If the custom distributed module is needed, just pass custom_dist=True to ps.
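As a rough illustration of the module facade described above (the backend classes and the use_custom_dist helper here are hypothetical stand-ins added for this sketch, not the actual PR code), the module-level dispatch could look like:

```python
# Hypothetical sketch of a distributed facade module: callers write
# dist.broadcast(...) and never see which backend is active.

class TorchBackend:
    """Stand-in for torch.distributed."""
    def broadcast(self, obj, src=0):
        return ("torch", obj, src)

class StatelessBackend:
    """Stand-in for a StatelessProcessGroup-based backend."""
    def broadcast(self, obj, src=0):
        return ("stateless", obj, src)

_backend = TorchBackend()

def use_custom_dist(enabled: bool) -> None:
    """Switch the module-level backend, mirroring custom_dist=True."""
    global _backend
    _backend = StatelessBackend() if enabled else TorchBackend()

def broadcast(obj, src=0):
    # Forward to whichever backend is currently selected.
    return _backend.broadcast(obj, src)
```

Call sites stay unchanged when the backend is swapped, which is the point of routing everything through the facade.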

@x1314aq

x1314aq commented Jan 15, 2026

@weixiao-huang @HubertZhang

If there are no further review comments, could this be merged?

@HubertZhang
Collaborator

HubertZhang commented Jan 15, 2026

By the way, should we abstract StatelessProcessGroup rather than dist? Calling the wrapped high-level methods directly in ps looks much more convenient. I imagine the sub-group part might get a bit more complex, but everything else should be much simpler.

# dist/vllm.py
from vllm.distributed import StatelessProcessGroup
VLLMStatelessProcessGroup = StatelessProcessGroup

# ps.py
class ParameterServer:
    def __init__(self, group):
        self.group = group
        ...
    def gather(self):
        self.group.broadcast(self.metas)

@kip-cxj
Contributor Author

kip-cxj commented Jan 19, 2026

By the way, should we abstract StatelessProcessGroup rather than dist? Calling the wrapped high-level methods directly in ps looks much more convenient. I imagine the sub-group part might get a bit more complex, but everything else should be much simpler.

# dist/vllm.py
from vllm.distributed import StatelessProcessGroup
VLLMStatelessProcessGroup = StatelessProcessGroup

# ps.py
class ParameterServer:
    def __init__(self, group):
        self.group = group
        ...
    def gather(self):
        self.group.broadcast(self.metas)

I don't quite understand this part. The collective communication is implemented by the NCCLLibrary / HCCLLibrary inside PyNcclCommunicator (PyHcclCommunicator); StatelessProcessGroup is only used by the Communicator at init time.

@HubertZhang
Collaborator

HubertZhang commented Jan 20, 2026


I don't quite understand this part. The collective communication is implemented by the NCCLLibrary / HCCLLibrary inside PyNcclCommunicator (PyHcclCommunicator); StatelessProcessGroup is only used by the Communicator at init time.

I looked at it more carefully. What about having ParameterServer.gather_metas and ParameterServer.update_weights accept a Distributed implementation that supports subgroups:

class Distributed(ABC):
    ...

    @abstractmethod
    def sub_group(self, ranks: list[int]) -> "AbstractProcessGroup":
        ...

The communication code in those two functions would then be much simpler: just call process_group.all_gather and process_group.broadcast on the group passed in.
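A minimal sketch of the proposed interface (the LocalGroup / LocalDistributed classes below are single-process stand-ins added for illustration, not real communicators):

```python
from abc import ABC, abstractmethod

class AbstractProcessGroup(ABC):
    @abstractmethod
    def broadcast(self, obj, src: int = 0): ...

    @abstractmethod
    def all_gather(self, obj) -> list: ...

class Distributed(ABC):
    @abstractmethod
    def sub_group(self, ranks: list[int]) -> AbstractProcessGroup: ...

# Single-process stand-ins that only demonstrate the call shape.
class LocalGroup(AbstractProcessGroup):
    def __init__(self, ranks: list[int]):
        self.ranks = ranks

    def broadcast(self, obj, src: int = 0):
        # A real implementation would dispatch to NCCL/HCCL here.
        return obj

    def all_gather(self, obj) -> list:
        # One contribution per member rank of this subgroup.
        return [obj for _ in self.ranks]

class LocalDistributed(Distributed):
    def sub_group(self, ranks: list[int]) -> AbstractProcessGroup:
        return LocalGroup(ranks)
```

gather_metas would then simply call group.all_gather(metas) on whichever group it was handed, with the subgroup membership resolved once at construction time.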

@kip-cxj
Contributor Author

kip-cxj commented Jan 22, 2026


I looked at it more carefully. What about having ParameterServer.gather_metas and ParameterServer.update_weights accept a Distributed implementation that supports subgroups:

class Distributed(ABC):
    ...

    @abstractmethod
    def sub_group(self, ranks: list[int]) -> "AbstractProcessGroup":
        ...

The communication code in those two functions would then be much simpler: just call process_group.all_gather and process_group.broadcast on the group passed in.

I think the current abstraction follows the conventional way collective communication APIs are used; changing it as proposed would break that convention.

@HubertZhang
Collaborator


I think the current abstraction follows the conventional way collective communication APIs are used; changing it as proposed would break that convention.

In vLLM, StatelessProcessGroup is used directly as pg.broadcast, isn't it? dist.broadcast(..., pg) -> pg.broadcast feels clearer 🤔

@kip-cxj
Contributor Author

kip-cxj commented Jan 22, 2026


In vLLM, StatelessProcessGroup is used directly as pg.broadcast, isn't it? dist.broadcast(..., pg) -> pg.broadcast feels clearer 🤔

In vLLM, StatelessProcessGroup is only used to transfer metadata: https://github.com/vllm-project/vllm/blob/main/vllm/distributed/utils.py#L146 . Data-plane transfer still uses pynccl (pyhccl): https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/rlhf_utils.py#L15

@HubertZhang
Collaborator


In vLLM, StatelessProcessGroup is only used to transfer metadata: https://github.com/vllm-project/vllm/blob/main/vllm/distributed/utils.py#L146 . Data-plane transfer still uses pynccl (pyhccl): https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/rlhf_utils.py#L15

Oh, I'm not saying StatelessProcessGroup should be used for data-plane transfer. I mean the data-plane broadcast interface could follow the pg.broadcast design: the xxx in xxx.broadcast would already encapsulate the sub_ranks along with the corresponding process_group and nccl_comm_t, so there would be no need to pass pg as a broadcast argument.

In other words, I'd like the xxx in xxx.broadcast to be an object rather than a module 😂
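To make the object-vs-module distinction concrete (all names here are illustrative, not the PR's actual API): instead of a module-level dist.broadcast(buf, group=pg), the group object captures its ranks and communicator handle once at construction:

```python
class ProcessGroup:
    """Illustrative object-style group: the member ranks and the
    communicator handle (standing in for an nccl_comm_t / HCCL comm)
    are bound at construction, so call sites pass no group argument."""

    def __init__(self, ranks: list[int], comm: str):
        self.ranks = ranks
        self.comm = comm  # stand-in for the native communicator handle

    def broadcast(self, buf, src: int = 0):
        # A real implementation would invoke ncclBroadcast / HcclBroadcast
        # on self.comm; here we just record which group handled the call.
        return (self.comm, tuple(self.ranks), buf, src)

# Object style: the group is self-contained.
pg = ProcessGroup(ranks=[0, 1], comm="comm0")
result = pg.broadcast("weights", src=0)
```

The design trade-off under discussion: pg.broadcast(buf) reads as "this group broadcasts", while dist.broadcast(buf, pg) threads the group through every call site.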

@kip-cxj kip-cxj force-pushed the main branch 6 times, most recently from b0c6ca0 to 47a2561 Compare January 28, 2026 07:28
@kip-cxj kip-cxj force-pushed the main branch 4 times, most recently from 6ad7671 to 0901e9f Compare January 29, 2026 11:58
@HubertZhang
Collaborator

I've tested the rest and vllm_nccl looks fine. Could you rebase and then squash the commits roughly by feature?

@kip-cxj kip-cxj force-pushed the main branch 3 times, most recently from 75268a4 to d455a21 Compare January 30, 2026 07:52
@kip-cxj
Contributor Author

kip-cxj commented Jan 30, 2026

I've tested the rest and vllm_nccl looks fine. Could you rebase and then squash the commits roughly by feature?

I've also tested the hccl part and squashed everything into a single commit. Please take a look and see whether it can be merged.
