Skip to content

Commit 90f349f

Browse files
Add res_multistep sampler from the cosmos code.
This sampler should work with all models.
1 parent b9d9bcb commit 90f349f

File tree

3 files changed

+277
-1
lines changed

3 files changed

+277
-1
lines changed

comfy/k_diffusion/res.py

+258
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Copied from Nvidia Cosmos code.
17+
18+
import torch
19+
from torch import Tensor
20+
from typing import Callable, List, Tuple, Optional, Any
21+
import math
22+
from tqdm.auto import trange
23+
24+
25+
def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
    """
    Give two tensors the same number of dimensions.

    The tensor with fewer dimensions is reshaped so that size-1 axes are
    appended after its existing axes; the other tensor is returned as is.

    Args:
        x: First tensor.
        y: Second tensor.

    Returns:
        tuple[Tensor, Tensor]: ``x`` and ``y`` with equal ``ndim``.
    """
    rank_gap = x.ndim - y.ndim

    if rank_gap < 0:
        # x is the lower-rank operand: pad it with trailing singleton axes.
        x = x.reshape(*x.shape, *([1] * (-rank_gap)))
    elif rank_gap > 0:
        # y is the lower-rank operand: pad it with trailing singleton axes.
        y = y.reshape(*y.shape, *([1] * rank_gap))

    return x, y
35+
36+
37+
def batch_mul(x: Tensor, y: Tensor) -> Tensor:
    """
    Multiply two tensors after equalising their number of dimensions.

    The operands are first passed through ``common_broadcast`` and then
    multiplied elementwise.

    Args:
        x: First factor.
        y: Second factor.

    Returns:
        Tensor: The elementwise product of the aligned operands.
    """
    lhs, rhs = common_broadcast(x, y)
    product = lhs * rhs
    return product
40+
41+
42+
def phi1(t: torch.Tensor) -> torch.Tensor:
    """
    Compute the first order phi function: (exp(t) - 1) / t.

    The computation runs in float32 and the result is cast back to the
    dtype of the input. At ``t == 0`` the quotient is 0/0, so the analytic
    limit ``phi1(0) = 1`` is returned there instead of NaN.

    Args:
        t: Input tensor.

    Returns:
        Tensor: Result of phi1 function, with the same dtype as ``t``.
    """
    input_dtype = t.dtype
    t32 = t.to(dtype=torch.float32)

    at_zero = t32 == 0
    # Substitute a harmless denominator where t == 0 so no 0/0 is evaluated;
    # those entries are overwritten with the limit value below.
    safe_t = torch.where(at_zero, torch.ones_like(t32), t32)
    value = torch.expm1(safe_t) / safe_t
    value = torch.where(at_zero, torch.ones_like(value), value)

    return value.to(dtype=input_dtype)
55+
56+
57+
def phi2(t: torch.Tensor) -> torch.Tensor:
    """
    Compute the second order phi function: (phi1(t) - 1) / t.

    Here ``phi1(t) = (exp(t) - 1) / t``. The computation runs in float32 and
    the result is cast back to the dtype of the input. At ``t == 0`` the
    expression is 0/0, so the analytic limit ``phi2(0) = 1/2`` is returned
    there instead of NaN.

    Args:
        t: Input tensor.

    Returns:
        Tensor: Result of phi2 function, with the same dtype as ``t``.
    """
    input_dtype = t.dtype
    t32 = t.to(dtype=torch.float32)

    at_zero = t32 == 0
    # Substitute a harmless denominator where t == 0 so no 0/0 is evaluated;
    # those entries are overwritten with the limit value below.
    safe_t = torch.where(at_zero, torch.ones_like(t32), t32)
    # phi1 is evaluated inline (expm1(t) / t) on the masked float32 input.
    phi1_val = torch.expm1(safe_t) / safe_t
    value = (phi1_val - 1.0) / safe_t
    value = torch.where(at_zero, torch.full_like(value, 0.5), value)

    return value.to(dtype=input_dtype)
70+
71+
72+
def res_x0_rk2_step(
    x_s: torch.Tensor,
    t: torch.Tensor,
    s: torch.Tensor,
    x0_s: torch.Tensor,
    s1: torch.Tensor,
    x0_s1: torch.Tensor,
) -> torch.Tensor:
    """
    Advance the state with a second-order exponential update built from two x0 predictions.

    All three time arguments are mapped through ``-log`` before use, and the
    update weights are derived from ``phi1`` and ``phi2`` evaluated at the
    negated step size in that transformed variable.

    Args:
        x_s: Current state tensor.
        t: Target time tensor.
        s: Current time tensor.
        x0_s: Prediction at current time.
        s1: Time tensor associated with ``x0_s1``.
        x0_s1: Prediction at time ``s1``.

    Returns:
        Tensor: Updated state tensor.

    Raises:
        AssertionError: If the step ``t`` vs ``s`` or the gap ``s1`` vs ``s``
            is within 1e-6 of zero in the ``-log`` variable.
    """
    # Work in the -log(time) variable.
    lam_cur = -torch.log(s)
    lam_next = -torch.log(t)
    lam_aux = -torch.log(s1)

    dt = lam_next - lam_cur
    aux_gap = lam_aux - lam_cur
    zero = torch.zeros_like(dt)
    assert not torch.any(torch.isclose(dt, zero, atol=1e-6)), "Step size is too small"
    assert not torch.any(torch.isclose(aux_gap, zero, atol=1e-6)), "Step size is too small"

    # Relative position of the auxiliary node within the step.
    c2 = aux_gap / dt
    neg_dt = -dt
    phi1_val = phi1(neg_dt)
    phi2_val = phi2(neg_dt)

    # Weight given to the auxiliary prediction; NaNs (degenerate nodes) become 0.
    aux_weight = 1.0 / c2 * phi2_val
    b1 = torch.nan_to_num(phi1_val - aux_weight, nan=0.0)
    b2 = torch.nan_to_num(aux_weight, nan=0.0)

    decayed_state = batch_mul(torch.exp(neg_dt), x_s)
    blended_x0 = batch_mul(b1, x0_s) + batch_mul(b2, x0_s1)
    return decayed_state + batch_mul(dt, blended_x0)
113+
114+
115+
def reg_x0_euler_step(
    x_s: torch.Tensor,
    s: torch.Tensor,
    t: torch.Tensor,
    x0_s: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Take one Euler step from time ``s`` to time ``t`` using an x0 prediction.

    The new state is the combination
    ``((s - t) / s) * x0_s + (t / s) * x_s``.

    Args:
        x_s: Current state tensor.
        s: Current time tensor.
        t: Target time tensor.
        x0_s: Prediction at current time.

    Returns:
        Tuple[Tensor, Tensor]: Updated state tensor and the prediction
        ``x0_s`` passed in, unchanged.
    """
    weight_state = t / s
    weight_pred = (s - t) / s

    pred_term = batch_mul(weight_pred, x0_s)
    state_term = batch_mul(weight_state, x_s)
    x_t = pred_term + state_term
    return x_t, x0_s
136+
137+
138+
def order2_fn(
    x_s: torch.Tensor,
    s: torch.Tensor,
    t: torch.Tensor,
    x0_s: torch.Tensor,
    x0_preds: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
    """
    impl the second order multistep method in https://arxiv.org/pdf/2308.02157
    Adams Bashforth approach!

    When a previous prediction is available, the step is taken with
    ``res_x0_rk2_step`` using that prediction as the second node; otherwise
    (first step, no history) it falls back to ``reg_x0_euler_step``.

    Args:
        x_s: Current state tensor.
        s: Current time tensor.
        t: Target time tensor.
        x0_s: Prediction at current time.
        x0_preds: History from the previous call: ``None`` or an empty list
            when there is none, else a list whose first entry is the
            ``(prediction, time)`` pair stored by the previous step.

    Returns:
        Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: Updated state tensor and
        the new single-entry history ``[(x0_s, s)]`` for the next call.
    """
    if x0_preds:
        prev_x0, prev_s = x0_preds[0]
        x_t = res_x0_rk2_step(x_s, t, s, x0_s, prev_s, prev_x0)
    else:
        x_t, _ = reg_x0_euler_step(x_s, s, t, x0_s)
    return x_t, [(x0_s, s)]
151+
152+
153+
class SolverConfig:
    """
    Plain settings holder consumed by ``differential_equation_solver``.

    Every field is a class-level default; callers create an instance and
    overwrite the attributes they want to change.
    """

    # Selects the multistep code path inside the solver's step function.
    is_multi: bool = True
    # Names of the Runge-Kutta / multistep variants.
    rk: str = "2mid"
    multistep: str = "2ab"
    # Noise-injection ("churn") amount; 0.0 disables it.
    s_churn: float = 0.0
    # Noise is only injected for sigmas strictly between s_t_min and s_t_max.
    s_t_max: float = math.inf
    s_t_min: float = 0.0
    # Multiplier applied to the injected noise.
    s_noise: float = 1.0
161+
162+
163+
def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any, disable=None) -> Any:
    """
    Thread a value through ``body_fun`` for every index in ``[lower, upper)``.

    Iteration is driven by ``trange`` so a progress bar is shown.

    Args:
        lower: Lower bound of the loop (inclusive).
        upper: Upper bound of the loop (exclusive).
        body_fun: Called as ``body_fun(i, carry)``; its return value becomes
            the carry for the next iteration.
        init_val: Carry passed to the first iteration.
        disable: Forwarded to ``trange`` as its ``disable`` argument.

    Returns:
        The carry after the last iteration (``init_val`` if the range is empty).
    """
    carry = init_val
    for index in trange(lower, upper, disable=disable):
        carry = body_fun(index, carry)
    return carry
180+
181+
182+
def differential_equation_solver(
    x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    sigmas_L: torch.Tensor,
    solver_cfg: SolverConfig,
    noise_sampler,
    callback=None,
    disable=None,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Creates a differential equation solver function.

    Args:
        x0_fn: Function to compute x0 prediction. Called as
            ``x0_fn(x, sigma_B)`` where ``sigma_B`` is the current sigma
            repeated once per batch element.
        sigmas_L: Tensor of sigma values with shape [L,].
        solver_cfg: Configuration for the solver.
        noise_sampler: Called as ``noise_sampler(sigma_cur, sigma_next)`` to
            draw the noise added when churn is active.
        callback: Optional function called after every step with a dict
            holding ``'x'`` (the state the model was evaluated on), ``'i'``
            (step index), ``'sigma'`` (schedule sigma of the step),
            ``'sigma_hat'`` (sigma after churn; equals ``'sigma'`` when no
            noise was added) and ``'denoised'`` (the x0 prediction).
        disable: Forwarded to ``fori_loop`` to disable the progress bar.

    Returns:
        A function that solves the differential equation.

    Raises:
        NotImplementedError: Raised from the returned function when
            ``solver_cfg.is_multi`` is false and a step with a non-zero
            target sigma is taken; only the multistep update is implemented.
    """
    num_step = len(sigmas_L) - 1

    # Only the second-order multistep update is wired up; solver_cfg.rk and
    # solver_cfg.multistep are not consulted.
    update_step_fn = order2_fn

    # Relative sigma inflation per step, capped at sqrt(1.2) - 1. The divisor
    # is floored at 1 so an empty schedule does not divide by zero.
    eta = min(solver_cfg.s_churn / max(num_step + 1, 1), math.sqrt(1.2) - 1)

    def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor:
        """
        Samples from the differential equation.

        Args:
            input_xT_B_StateShape: Input tensor with shape [B, StateShape].

        Returns:
            Output tensor with shape [B, StateShape].
        """
        # Per-sample ones used to expand a scalar sigma to shape [B].
        ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float32)

        def step_fn(
            i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
        ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
            input_x_B_StateShape, x0_preds = state
            sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
            # Sigma the model is actually evaluated at; differs from the
            # schedule value only when churn noise is added below.
            sigma_hat_0 = sigma_cur_0

            if sigma_next_0 == 0:
                # Final step: the x0 prediction itself is the output.
                output_x_B_StateShape = x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B)
            else:
                # algorithm 2: line 4-6
                if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max and eta > 0:
                    sigma_hat_0 = sigma_cur_0 + eta * sigma_cur_0
                    input_x_B_StateShape = input_x_B_StateShape + (
                        sigma_hat_0**2 - sigma_cur_0**2
                    ).sqrt() * solver_cfg.s_noise * noise_sampler(sigma_cur_0, sigma_next_0)

                if not solver_cfg.is_multi:
                    # The single-step (Runge-Kutta) update functions are not
                    # part of this module; fail with a clear message.
                    raise NotImplementedError(
                        "Only the multistep solver (solver_cfg.is_multi=True) is implemented"
                    )

                x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_hat_0 * ones_B)
                output_x_B_StateShape, x0_preds = update_step_fn(
                    input_x_B_StateShape, sigma_hat_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds
                )

            if callback is not None:
                callback({'x': input_x_B_StateShape, 'i': i_th, 'sigma': sigma_cur_0, 'sigma_hat': sigma_hat_0, 'denoised': x0_pred_B_StateShape})

            return output_x_B_StateShape, x0_preds

        x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None], disable=disable)
        return x_at_eps

    return sample_fn

comfy/k_diffusion/sampling.py

+18
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from . import utils
1010
from . import deis
11+
from . import res
1112
import comfy.model_patcher
1213
import comfy.model_sampling
1314

@@ -1265,3 +1266,20 @@ def post_cfg_function(args):
12651266
x = denoised + denoised_mix + torch.exp(-h) * x
12661267
old_uncond_denoised = uncond_denoised
12671268
return x
1269+
1270+
@torch.no_grad()
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
    """Sample with the res_multistep solver from ``res.differential_equation_solver``.

    ``model`` is called as ``model(x, sigma, **extra_args)`` to obtain the
    x0 prediction. The churn settings (``s_churn``, ``s_tmin``, ``s_tmax``,
    ``s_noise``) are copied onto a ``res.SolverConfig``. When no
    ``noise_sampler`` is given, one is built with ``default_noise_sampler``
    using the ``"seed"`` entry of ``extra_args`` (if any). ``callback`` and
    ``disable`` are passed through to the solver.
    """
    if extra_args is None:
        extra_args = {}
    seed = extra_args.get("seed", None)
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x, seed=seed)

    def denoise(latent, sigma):
        # x0 prediction of the wrapped model at the given sigma.
        return model(latent, sigma, **extra_args)

    solver_cfg = res.SolverConfig()
    solver_cfg.s_churn = s_churn
    solver_cfg.s_t_max = s_tmax
    solver_cfg.s_t_min = s_tmin
    solver_cfg.s_noise = s_noise

    solve = res.differential_equation_solver(
        denoise, sigmas, solver_cfg, noise_sampler, callback=callback, disable=disable
    )
    return solve(x)

comfy/samplers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ def max_denoise(self, model_wrap, sigmas):
687687
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
688688
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
689689
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
690-
"ipndm", "ipndm_v", "deis"]
690+
"ipndm", "ipndm_v", "deis", "res_multistep"]
691691

692692
class KSAMPLER(Sampler):
693693
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

0 commit comments

Comments
 (0)