
Commit 8e14e97 (1 parent: 858c3b7)

wip

3 files changed: 21 additions (+), 46 deletions (−)

examples/torch-integration/dsl_with_nccl_api.py
Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@

def allreduce_nvls(spec: mscclpp.AlgoSpec) -> CollectiveProgram:
    gpu_size = spec.world_size
-    with CollectiveProgram(spec) as program:
+    with CollectiveProgram.from_spec(spec) as program:
        # Creating Channels
        nvls_chan = SwitchChannel(rank_list=[gpu for gpu in range(gpu_size)], buffer_type=BufferType.input)
        channels = {}

python/mscclpp/language/default_algos/allreduce_2nodes.py
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size) -> CollectiveProgr
    total_gpus = num_nodes * gpus_per_node
    packets_per_gpu = 2

-    with CollectiveProgram(spec) as prog:
+    with CollectiveProgram.from_spec(spec) as prog:
        # Initialize communication channels and buffers
        intra_node_memory_channels = {}
        inter_node_port_channels = {}

python/mscclpp/language/program.py
Lines changed: 19 additions & 44 deletions

@@ -112,7 +112,8 @@ def __init__(

        self.loop_context = None

-    def __init__(self, spec: AlgoSpec):
+    @classmethod
+    def from_spec(cls, spec: AlgoSpec):
        """Initialize a new CollectiveProgram from an algorithm specification.

        This constructor provides an alternative way to create a CollectiveProgram
@@ -139,52 +140,26 @@ def __init__(self, spec: AlgoSpec):
        ...     protocol="Simple",
        ...     in_place=False
        ... )
-        >>> with CollectiveProgram(spec) as prog:
+        >>> with CollectiveProgram.from_spec(spec) as prog:
        ...     # Define communication operations
        ...     pass
        """
-        self.name = spec.name
-        self.collective = spec.collective
-        self.num_ranks = spec.world_size
-        self.in_place = spec.in_place
-        self.instances = spec.instances
-        self.protocol = spec.protocol
-        self.instr_fusion = spec.instr_fusion
-        self.auto_sync = spec.auto_sync
-        self.replication_policy = spec.replication_policy
-        self.reuse_resources = spec.reuse_resources
-        self.num_threads_per_block = spec.num_threads_per_block
-        self.use_double_scratch_buffer = spec.use_double_scratch_buffer
-        self.buffer_alignment = spec.buffer_alignment
-        self.min_message_size = spec.min_message_size
-        self.max_message_size = spec.max_message_size
-        assert (
-            self.protocol == "Simple" or self.protocol == "LL"
-        ), f"Given protocol: {self.protocol}. Must be either Simple, LL"
-        self.buffers = self.collective.init_buffers()
-        self.gpus: List[Gpu] = []
-        for rank in range(self.num_ranks):
-            self.gpus.append(
-                Gpu(rank, self.buffers[rank][BufferType.input].size, self.buffers[rank][BufferType.output].size, 0)
-            )
-
-        self.loop_context = None
-
-    def _create_collective_from_name(self, collective_name: str, world_size: int, in_place: bool):
-        """Create a collective instance based on the collective name."""
-        chunk_factor = 1
-        if collective_name.lower() == "allgather":
-            return AllGather(world_size, chunk_factor, in_place)
-        elif collective_name.lower() == "allreduce":
-            return AllReduce(world_size, chunk_factor, in_place)
-        elif collective_name.lower() == "reducescatter":
-            return ReduceScatter(world_size, chunk_factor, in_place)
-        elif collective_name.lower() == "alltoall":
-            return AllToAll(world_size, chunk_factor, in_place)
-        else:
-            raise ValueError(
-                f"Unknown collective name: {collective_name}. Supported collectives: allgather, allreduce, reducescatter, alltoall"
-            )
+        return cls(
+            spec.name,
+            spec.collective,
+            spec.world_size,
+            instances=spec.instances,
+            protocol=spec.protocol,
+            instr_fusion=spec.instr_fusion,
+            auto_sync=spec.auto_sync,
+            replication_policy=spec.replication_policy,
+            reuse_resources=spec.reuse_resources,
+            num_threads_per_block=spec.num_threads_per_block,
+            use_double_scratch_buffer=spec.use_double_scratch_buffer,
+            buffer_alignment=spec.buffer_alignment,
+            min_message_size=spec.min_message_size,
+            max_message_size=spec.max_message_size,
+        )

    def __enter__(self):
        """Enter the program context and set this as the active program.
