Commit 94e3cb4

Made `local_rank` a returned value that is passed explicitly as an argument, and added a note on safe tensors in the README (#26)
1 parent: f0029b3

2 files changed: 20 additions & 17 deletions


demos/README.md

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ class LlamaAttention(nn.Module):
         attn_weights = self.dummy(attn_weights)
         ...
 ```
-You may also need to change the module string literal under the `get_module_names` function.
+You may also need to change the module string literal under the `get_module_names` function. If you do not have safe tensors downloaded as part of your HF model, you will need to pass in `use_safetensors=False` as part of the model loading.
 
 To run the file, you can use `torchrun`. We have tested this demo by running Llama-2-70b-hf on 4x A100-80G.
 ```
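
The new README note points at the Hugging Face loading call. As a minimal sketch of where the flag would go (assuming the `LlamaForCausalLM.from_pretrained` call used in the demo; the checkpoint path is a placeholder):

```
# Hypothetical example of the README note above: if the checkpoint directory has
# no .safetensors files, ask transformers to load the legacy .bin weights instead.
import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "/path/to/Llama-2-70b-hf",    # placeholder path to the local HF checkpoint
    torch_dtype=torch.bfloat16,   # matches the dtype used in the demo
    use_safetensors=False,        # fall back to pytorch_model-*.bin weights
)
```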

demos/induction_heads_multigpu.py

Lines changed: 19 additions & 16 deletions
@@ -11,6 +11,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from flex_model.core import FlexModel, HookFunction
 from torch import nn
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -28,16 +29,13 @@
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from flex_model.core import FlexModel, HookFunction
-
-LOCAL_RANK = None
-
 
 def setup() -> None:
     """Instantiate process group."""
     dist.init_process_group("nccl")
-    global LOCAL_RANK
-    LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+    return local_rank
 
 
 def cleanup() -> None:
@@ -53,19 +51,21 @@ def args() -> Namespace:
     return parser.parse_args()
 
 
-def setup_model(model_path: str) -> tuple[nn.Module, LlamaConfig]:
+def setup_model(model_path: str, local_rank: int) -> \
+        tuple[nn.Module, LlamaConfig]:
     """Instantiate model, tokenizer, and config.
 
     Args:
     ----
         model_path: A path to the model being instantiated
+        local_rank: The local rank of the worker
 
     Returns:
     -------
         A tuple of length two containing the model and the config.
     """
     config = LlamaConfig.from_pretrained(model_path)
-    if LOCAL_RANK == 0:
+    if local_rank == 0:
         model = LlamaForCausalLM.from_pretrained(
             model_path,
             torch_dtype=torch.bfloat16,
@@ -79,9 +79,13 @@ def setup_model(model_path: str) -> tuple[nn.Module, LlamaConfig]:
     return model, config
 
 
-def fsdp_config() -> dict[str:Any]:
+def fsdp_config(local_rank: int) -> dict[str:Any]:
     """Return the config to be used by FSDP.
 
+    Args:
+    ----
+        local_rank: The local rank of the worker
+
     Returns:
     -------
         A dictionary containing keyword -> respective configuration.
@@ -103,7 +107,7 @@ def _module_init_fn(module: nn.Module) -> Callable:
     sharding_strategy = ShardingStrategy.FULL_SHARD
     device_id = torch.cuda.current_device()
     sync_module_states = True
-    param_init_fn = _module_init_fn if LOCAL_RANK != 0 else None
+    param_init_fn = _module_init_fn if local_rank != 0 else None
     mp_policy = MixedPrecision(
         param_dtype=torch.bfloat16,
         buffer_dtype=torch.bfloat16,
@@ -267,14 +271,14 @@ def main(args: Namespace) -> None:
     ----
         args: Command-line arguments
     """
-    torch.cuda.set_device(LOCAL_RANK)
+    local_rank = setup()
 
     seq_len = args.seq_length
     batch_size = 4
     min_vocab_idx, max_vocab_idx = 500, 15000
 
     prompt = torch.randint(
-        min_vocab_idx, max_vocab_idx, (batch_size, seq_len)
+        min_vocab_idx, max_vocab_idx, (batch_size, seq_len),
     ).to(
         torch.cuda.current_device(),
     )
@@ -283,8 +287,8 @@ def main(args: Namespace) -> None:
         "batch seq_len -> batch (2 seq_len)",
     )
 
-    model, config = setup_model(args.model_path)
-    fsdp_cfg = fsdp_config()
+    model, config = setup_model(args.model_path, local_rank)
+    fsdp_cfg = fsdp_config(local_rank)
 
     model = FSDP(
         model,
@@ -325,10 +329,9 @@ def main(args: Namespace) -> None:
     # Note: we are only calculating this over the main rank's output
     # for the purpose of demonstration
     calculate_per_token_loss(out, repeated_tokens)
+    cleanup()
 
 
 if __name__ == "__main__":
     parsed_args = args()
-    setup()
     main(parsed_args)
-    cleanup()
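
Taken together, the Python changes drop the module-level `LOCAL_RANK` global: `setup()` now pins the device and returns the local rank, `main()` threads that value into `setup_model` and `fsdp_config`, and `cleanup()` moves inside `main()`. A condensed sketch of the resulting flow (the return annotation on `setup()` is adjusted to `int` here for clarity, `cleanup()` is assumed to destroy the process group, and the rest of the demo is elided):

```
import os

import torch
import torch.distributed as dist


def setup() -> int:
    """Instantiate the process group and pin this worker to its GPU."""
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    return local_rank


def cleanup() -> None:
    """Tear down the process group (assumed implementation)."""
    dist.destroy_process_group()


def main() -> None:
    # local_rank flows through the call graph instead of living in a global.
    local_rank = setup()
    # model, config = setup_model(model_path, local_rank)
    # fsdp_cfg = fsdp_config(local_rank)
    # ... build the FSDP-wrapped model and run the induction-heads demo ...
    cleanup()


if __name__ == "__main__":
    main()
```

Launched under `torchrun`, each worker reads its own `LOCAL_RANK` environment variable, so no cross-rank coordination is needed for this step.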
