Status: Open
Labels: bug
Description
Is there an existing issue for this bug?
- I have searched the existing issues
The bug has not been fixed in the latest main branch
- I have checked the latest main branch
Do you feel comfortable sharing a concise (minimal) script that reproduces the error? :)
Yes, I will share a minimal reproducible script.
🐛 Describe the bug
When using the Gemini plugin together with the `CPUAdam` optimizer, I encountered the following error:

```python
assert div_scale == -1, "div_scale should remain default"
```

However, the Gemini plugin source code suggests that the two should be compatible. The relevant assertion in the code is:

```python
assert type(optim) in _AVAIL_OPTIM_LIST, (
    "You should use an optimizer in the available list:\n" f"{_AVAIL_OPTIM_LIST}"
)
```

and `_AVAIL_OPTIM_LIST` is defined as:

```python
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
```

which explicitly includes `CPUAdam` in the supported list.
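The mismatch can be illustrated without ColossalAI. The sketch below uses two hypothetical stand-in classes (`ToyCPUAdam` and `ToyGeminiOptimizer`; neither is a ColossalAI API) and assumes only what the traceback shows: the wrapper forwards `div_scale=combined_scale` to the inner optimizer's `step()`, and the inner `step()` asserts `div_scale == -1`. Because the allow-list check inspects only the optimizer's type, it passes at construction time, and the incompatibility surfaces only once a non-default scale is forwarded.

```python
class ToyCPUAdam:
    """Hypothetical stand-in for CPUAdam: only accepts the default div_scale."""

    def step(self, closure=None, div_scale: float = -1):
        # Mirrors the guard reported at cpu_adam.py line 193 in the traceback.
        assert div_scale == -1, "div_scale should remain default"
        return "updated"


class ToyGeminiOptimizer:
    """Hypothetical stand-in for the Gemini optimizer wrapper."""

    _AVAIL_OPTIM_LIST = {ToyCPUAdam}

    def __init__(self, optim):
        # Type-only allow-list check: it says nothing about what step() accepts.
        assert type(optim) in self._AVAIL_OPTIM_LIST, (
            "You should use an optimizer in the available list:\n" f"{self._AVAIL_OPTIM_LIST}"
        )
        self.optim = optim

    def step(self, combined_scale: float):
        # Forwards the scale unconditionally, like gemini_optimizer.py line 288.
        return self.optim.step(div_scale=combined_scale)


wrapped = ToyGeminiOptimizer(ToyCPUAdam())  # construction passes the allow-list
print(wrapped.step(combined_scale=-1))      # default value: the guard is satisfied
try:
    wrapped.step(combined_scale=1024.0)     # any other value trips the guard
except AssertionError as err:
    print(f"AssertionError: {err}")
```

In this toy model, passing the allow-list is necessary but not sufficient: the wrapper and the inner optimizer also have to agree on the `div_scale` contract.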
Below is the script `main.py` that reproduces the issue:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import CPUAdam


class RandomDataset(Dataset):
    """Synthetic classification data: random features and integer labels."""

    def __init__(self, num_samples=32 * 100, input_dim=1024, num_classes=10):
        self.x = torch.randn(num_samples, input_dim)
        self.y = torch.randint(0, num_classes, (num_samples,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class MLP(nn.Module):
    """Stack of Linear + ReLU layers followed by a classification head."""

    def __init__(self, input_dim=1024, hidden_dim=512, num_layers=10, num_classes=10):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def main():
    seed = 1024
    colossalai.launch_from_torch(seed=seed)

    plugin = GeminiPlugin()
    booster = Booster(plugin=plugin)

    model = MLP()
    # CPUAdam is listed in Gemini's _AVAIL_OPTIM_LIST, so this pairing should be accepted.
    optimizer = CPUAdam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    dataset = RandomDataset()
    train_dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

    # Cast inputs to the plugin's precision, falling back to fp16 if it is not exposed.
    precision = getattr(plugin, "precision", "fp16")
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map.get(precision, torch.float16)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    for epoch in range(1):
        total_loss = 0
        for step, (x, y) in enumerate(train_dataloader):
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device)

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            booster.backward(loss, optimizer)
            optimizer.step()  # fails here with AssertionError (see log)

            total_loss += loss.item()
            print(f"[Epoch {epoch}] step {step}, loss = {loss.item():.4f}")

        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch} finished, average loss = {avg_loss:.4f}")


if __name__ == "__main__":
    main()
```

Running the following command:

```shell
colossalai run --nproc_per_node 4 main.py
```

produces the following log output:
```
W1112 09:36:44.760368 340847 site-packages/torch/distributed/run.py:793]
W1112 09:36:44.760368 340847 site-packages/torch/distributed/run.py:793] *****************************************
W1112 09:36:44.760368 340847 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1112 09:36:44.760368 340847 site-packages/torch/distributed/run.py:793] *****************************************
[11/12/25 09:36:48] INFO colossalai - colossalai - INFO:
/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/initialize.py:75 launch
INFO colossalai - colossalai - INFO: Distributed environment is initialized, world size: 4
WARNING colossalai - colossalai - WARNING:
/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/booster/plugin/gemini_plugin.py:492 __init__
WARNING colossalai - colossalai - WARNING: enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.
[rank0]:[W1112 09:36:49.581515762 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
[rank1]:[W1112 09:36:49.666183352 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
[rank3]:[W1112 09:36:49.739092293 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
[rank2]:[W1112 09:36:49.019903841 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/utils/common.py:59: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if data.storage().size() > 0:
/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/utils/common.py:59: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if data.storage().size() > 0:
/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/utils/common.py:59: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if data.storage().size() > 0:
/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/utils/common.py:59: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if data.storage().size() > 0:
[Epoch 0] step 0, loss = 2.2969
[Epoch 0] step 0, loss = 2.2969
[Epoch 0] step 0, loss = 2.2969
[Epoch 0] step 0, loss = 2.2969
[Epoch 0] step 1, loss = 2.2988
[Epoch 0] step 1, loss = 2.2988
[Epoch 0] step 1, loss = 2.2988
[Epoch 0] step 1, loss = 2.2988
```
```
[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/yanzhen/distributed_test/colossalAI/test/bug1.py", line 86, in <module>
[rank3]:     main()
[rank3]:   File "/home/yanzhen/distributed_test/colossalAI/test/bug1.py", line 76, in main
[rank3]:     optimizer.step()
[rank3]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/zero/gemini/gemini_optimizer.py", line 288, in step
[rank3]:     ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
[rank3]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/optim/optimizer.py", line 487, in wrapper
[rank3]:     out = func(*args, **kwargs)
[rank3]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/nn/optimizer/cpu_adam.py", line 193, in step
[rank3]:     assert div_scale == -1, "div_scale should remain default"
[rank3]: AssertionError: div_scale should remain default
```

Ranks 2, 1 and 0 print the identical traceback.
```
[rank0]:[W1112 09:36:50.405163289 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
W1112 09:36:51.754890 340847 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 340918 closing signal SIGTERM
W1112 09:36:51.758123 340847 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 340919 closing signal SIGTERM
E1112 09:36:51.922632 340847 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 2 (pid: 340920) of binary: /home/yanzhen/miniconda3/envs/colossal/bin/python3.9
Traceback (most recent call last):
  File "/home/yanzhen/miniconda3/envs/colossal/bin/torchrun", line 7, in <module>
    sys.exit(main())
  File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/run.py", line 919, in main
    run(args)
  File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
bug1.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2025-11-12_09:36:51
  host      : ubuntu
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 340921)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-11-12_09:36:51
  host      : ubuntu
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 340920)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Error: failed to run torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=29505 bug1.py on 127.0.0.1, is localhost: True, exception: Encountered a bad command exit code!
Command: 'cd /home/yanzhen/distributed_test/colossalAI/test && export <shell, editor and locale variables omitted> NCCL_P2P_DISABLE="1" LD_LIBRARY_PATH="/home/yanzhen/.tensornvme/lib:/usr/local/cuda-12.4/lib64:/home/yanzhen/.tensornvme/lib:/usr/local/cuda-12.4/lib64:" CUDA_HOME="/usr/local/cuda-12.4" CUDA_DEVICE_MAX_CONNECTIONS="1" && torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=29505 bug1.py'
Exit code: 1
Stdout: already printed
Stderr: already printed
====== Training on All Nodes =====
127.0.0.1: failure
====== Stopping All Nodes =====
127.0.0.1: finish
```
Environment
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.28.3
Libc version: glibc-2.39
Python version: 3.9.23 (main, Jun 5 2025, 13:40:20) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090
GPU 2: NVIDIA GeForce RTX 4090
GPU 3: NVIDIA GeForce RTX 4090
Nvidia driver version: 580.65.06
cuDNN version: Probably one of the following:
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7773X 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
Stepping: 2
Frequency boost: enabled
CPU max MHz: 3527.7339
CPU min MHz: 1500.0000
BogoMIPS: 4400.15
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 1.5 GiB (16 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63,128-191
NUMA node1 CPU(s): 64-127,192-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] galore-torch==1.0
[pip3] numpy==2.0.2
[pip3] torch==2.5.1
[pip3] triton==3.1.0
[conda] galore-torch 1.0 pypi_0 pypi
[conda] numpy 2.0.2 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi