@@ -112,7 +112,8 @@ def __init__(
112112
113113 self .loop_context = None
114114
115- def __init__ (self , spec : AlgoSpec ):
115+ @classmethod
116+ def from_spec (cls , spec : AlgoSpec ):
116117 """Initialize a new CollectiveProgram from an algorithm specification.
117118
118119 This constructor provides an alternative way to create a CollectiveProgram
@@ -139,52 +140,26 @@ def __init__(self, spec: AlgoSpec):
139140 ... protocol="Simple",
140141 ... in_place=False
141142 ... )
142- >>> with CollectiveProgram(spec) as prog:
143+ >>> with CollectiveProgram.from_spec (spec) as prog:
143144 ... # Define communication operations
144145 ... pass
145146 """
146- self .name = spec .name
147- self .collective = spec .collective
148- self .num_ranks = spec .world_size
149- self .in_place = spec .in_place
150- self .instances = spec .instances
151- self .protocol = spec .protocol
152- self .instr_fusion = spec .instr_fusion
153- self .auto_sync = spec .auto_sync
154- self .replication_policy = spec .replication_policy
155- self .reuse_resources = spec .reuse_resources
156- self .num_threads_per_block = spec .num_threads_per_block
157- self .use_double_scratch_buffer = spec .use_double_scratch_buffer
158- self .buffer_alignment = spec .buffer_alignment
159- self .min_message_size = spec .min_message_size
160- self .max_message_size = spec .max_message_size
161- assert (
162- self .protocol == "Simple" or self .protocol == "LL"
163- ), f"Given protocol: { self .protocol } . Must be either Simple, LL"
164- self .buffers = self .collective .init_buffers ()
165- self .gpus : List [Gpu ] = []
166- for rank in range (self .num_ranks ):
167- self .gpus .append (
168- Gpu (rank , self .buffers [rank ][BufferType .input ].size , self .buffers [rank ][BufferType .output ].size , 0 )
169- )
170-
171- self .loop_context = None
172-
173- def _create_collective_from_name (self , collective_name : str , world_size : int , in_place : bool ):
174- """Create a collective instance based on the collective name."""
175- chunk_factor = 1
176- if collective_name .lower () == "allgather" :
177- return AllGather (world_size , chunk_factor , in_place )
178- elif collective_name .lower () == "allreduce" :
179- return AllReduce (world_size , chunk_factor , in_place )
180- elif collective_name .lower () == "reducescatter" :
181- return ReduceScatter (world_size , chunk_factor , in_place )
182- elif collective_name .lower () == "alltoall" :
183- return AllToAll (world_size , chunk_factor , in_place )
184- else :
185- raise ValueError (
186- f"Unknown collective name: { collective_name } . Supported collectives: allgather, allreduce, reducescatter, alltoall"
187- )
147+ return cls (
148+ spec .name ,
149+ spec .collective ,
150+ spec .world_size ,
151+ instances = spec .instances ,
152+ protocol = spec .protocol ,
153+ instr_fusion = spec .instr_fusion ,
154+ auto_sync = spec .auto_sync ,
155+ replication_policy = spec .replication_policy ,
156+ reuse_resources = spec .reuse_resources ,
157+ num_threads_per_block = spec .num_threads_per_block ,
158+ use_double_scratch_buffer = spec .use_double_scratch_buffer ,
159+ buffer_alignment = spec .buffer_alignment ,
160+ min_message_size = spec .min_message_size ,
161+ max_message_size = spec .max_message_size ,
162+ )
188163
189164 def __enter__ (self ):
190165 """Enter the program context and set this as the active program.
0 commit comments