-
Notifications
You must be signed in to change notification settings - Fork 1.3k
timestep scheduling with np.linspace #8623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
Signed-off-by: ytl0623 <[email protected]>
WalkthroughThe diff changes how DDPM and DDIM compute inference timesteps in set_timesteps: instead of using np.arange * step_ratio, they use numpy.linspace from (num_train_timesteps - 1) down to 0 with num_inference_steps, round and cast to int64, then convert to a torch tensor. DDIM also adds validation for steps_offset to be within [0, num_train_timesteps). Public APIs remain unchanged. The selected discrete timesteps and their ordering are altered and the endpoint is ensured to be included. Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes
Areas to review:
Pre-merge checks and finishing touches✅ Passed checks (4 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/schedulers/ddim.py (1)
120-126: Validation logic is now incorrect with linspace approach.
step_ratiois no longer used in timestep generation (line 130 useslinspace). The validation at lines 121-126 incorrectly assumesstep_ratiospacing. With linspace starting atnum_train_timesteps - 1, the constraint should preventsteps_offsetfrom causing timesteps to exceed valid indices[0, num_train_timesteps - 1].Required fix: replace the validation to ensure
steps_offset < 1when using linspace fromnum_train_timesteps - 1, or adjust the linspace start point to accommodate the offset.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/networks/schedulers/ddim.py(1 hunks)monai/networks/schedulers/ddpm.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/schedulers/ddpm.pymonai/networks/schedulers/ddim.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
🔇 Additional comments (2)
monai/networks/schedulers/ddpm.py (1)
128-129: Fix correctly addresses the endpoint issue and tests pass.The linspace approach generates timesteps from
num_train_timesteps - 1to 0, ensuring both endpoints are included and all values are valid array indices for the alpha arrays used in thestep()method. Existing tests verify functionality and error handling.monai/networks/schedulers/ddim.py (1)
179-179: Spacing assumption needs verification.The concern is valid:
self.timestepsuses linspace with rounding (line 130), creating non-uniform gaps. Fixed offset arithmetic (line 179:timestep - step_ratio) doesn't guarantee alignment with actual previous scheduled timestep. If gaps vary,prev_timestepmay misalign, causing incorrect alpha lookups.However, impact on diffusion quality (whether approximation is acceptable per DDIM paper or requires dynamic lookup from
self.timesteps) requires verification against:
- Reference DDIM implementations (HuggingFace, Stable Diffusion, etc.)
- Whether existing tests catch quality degradation
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/schedulers/ddim.py (1)
121-126: Validation logic is incorrect for the new implementation.With the linspace approach,
max_timestep = (num_train_timesteps - 1 - steps_offset) + steps_offset = num_train_timesteps - 1regardless of the relationship betweensteps_offsetandstep_ratio. The actual constraint is0 <= steps_offset < num_train_timesteps.Apply this diff:
- step_ratio = self.num_train_timesteps // self.num_inference_steps - if self.steps_offset >= step_ratio: + if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps: raise ValueError( - f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " - f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" - f" the max train timestep." + f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps})." )
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/networks/schedulers/ddim.py(1 hunks)monai/networks/schedulers/ddpm.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- monai/networks/schedulers/ddpm.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/schedulers/ddim.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (codeformat)
🔇 Additional comments (1)
monai/networks/schedulers/ddim.py (1)
128-130: Correctly fixes the endpoint issue.The subtract-then-add pattern ensures max timestep =
num_train_timesteps - 1and min timestep =steps_offset, keeping all indices valid. The linspace approach guarantees the endpoint is reached.
… steps_offset) + steps_offset = num_train_timesteps - 1 regardless of the relationship between steps_offset and step_ratio. The actual constraint is 0 <= steps_offset < num_train_timesteps. Signed-off-by: ytl0623 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
monai/networks/schedulers/ddim.py (1)
120-121: Validation logic is correct.The range check ensures the linspace start is non-negative and the final max timestep stays at
num_train_timesteps - 1.Static analysis suggests defining a custom exception or shortening the message to comply with TRY003, though this is a minor style concern.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/schedulers/ddim.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/schedulers/ddim.py
🪛 Ruff (0.14.3)
monai/networks/schedulers/ddim.py
121-121: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.10)
🔇 Additional comments (1)
monai/networks/schedulers/ddim.py (1)
123-129: Timestep computation is correct and validated by tests.The linspace approach with rounding produces nearly-uniform spacing (gaps of 10–11). The
step()method's prev_timestep approximation (gap =num_train_timesteps // num_inference_steps) works correctly despite minor non-uniformity. Thetest_full_timestep_looptest validates numerical correctness end-to-end, confirmingalphas_cumprodindexing is sound.
|
Hi @virginiafdez, @KumoLiu, @Nic-Ma and @ericspod, Sorry to bother. Thanks in advance! |
Fixes #8600
Description
The
np.linspaceapproach generates a descending array that starts exactly at 999 and ends exactly at 0 (after rounding), ensuring the scheduler samples the entire intended trajectory.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.