Upload FOFPred pipeline

#9
by kahnchana - opened
pipeline_fofpred.py CHANGED
@@ -59,8 +59,8 @@ from einops import repeat
 from huggingface_hub.utils import validate_hf_hub_args
 from transformers import Qwen2_5_VLForConditionalGeneration
 
-from .scheduler_fofpred import FlowMatchEulerDiscreteScheduler
-from .transformer_fofpred import OmniGen2Transformer3DModel
+from .scheduler.scheduler_fofpred import FlowMatchEulerDiscreteScheduler
+from .transformer.transformer_fofpred import OmniGen2Transformer3DModel
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
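The renamed imports match the two module files added below, so the pipeline repository presumably ends up with a layout like:

pipeline_fofpred.py
scheduler/scheduler_fofpred.py
transformer/transformer_fofpred.py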
scheduler/scheduler_fofpred.py ADDED
@@ -0,0 +1,218 @@
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders.lora_base import (  # noqa
+    LoraBaseMixin,
+    _fetch_state_dict,
+)
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model
+            input in the denoising loop.
+    """
+
+    prev_sample: torch.FloatTensor
+
+
+class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Euler scheduler.
+
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
+    generic methods the library implements for all schedulers, such as loading and saving.
+
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        dynamic_time_shift (`bool`, defaults to `True`):
+            Whether to shift the timestep schedule according to the number of image tokens (see `set_timesteps`).
+    """
+
+    _compatibles = []
+    order = 1
+
+    @register_to_config
+    def __init__(
+        self, num_train_timesteps: int = 1000, dynamic_time_shift: bool = True
+    ):
+        timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[
+            :-1
+        ]
+
+        self.timesteps = timesteps
+
+        self._step_index = None
+        self._begin_index = None
+
+    @property
+    def step_index(self):
+        """
+        The index counter for the current timestep. It increases by 1 after each scheduler step.
+        """
+        return self._step_index
+
+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self._timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()
+
+    # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+    #     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[float]] = None,
+        num_tokens: Optional[int] = None,
+    ):
+        """
+        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+        Args:
+            num_inference_steps (`int`):
+                The number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+            timesteps (`List[float]`, *optional*):
+                Custom timesteps in `[0, 1)`. If provided, `num_inference_steps` is ignored.
+            num_tokens (`int`, *optional*):
+                The number of image tokens, used to compute the dynamic time shift.
+        """
+
+        if timesteps is None:
+            self.num_inference_steps = num_inference_steps
+            timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[
+                :-1
+            ]
+            if self.config.dynamic_time_shift and num_tokens is not None:
+                m = (
+                    np.sqrt(num_tokens) / 40
+                )  # when the input resolution is 320 * 320, m = 1; when it is 1024 * 1024, m = 3.2
+                timesteps = timesteps / (m - m * timesteps + timesteps)
+
+        # `as_tensor` also accepts custom timesteps passed as a Python list
+        timesteps = torch.as_tensor(timesteps, dtype=torch.float32, device=device)
+        _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
+
+        self.timesteps = timesteps
+        self._timesteps = _timesteps
+        self._step_index = None
+        self._begin_index = None
+
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: Union[float, torch.FloatTensor],
+        sample: torch.FloatTensor,
+        generator: Optional[torch.Generator] = None,
+        return_dict: bool = True,
+    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
+        """
+        Predict the sample at the next timestep by integrating the flow ODE with an Euler step.
+
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from the learned diffusion model.
+            timestep (`float`):
+                The current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            generator (`torch.Generator`, *optional*):
+                A random number generator.
+            return_dict (`bool`):
+                Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or a tuple.
+
+        Returns:
+            [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
+                If `return_dict` is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a
+                tuple is returned where the first element is the sample tensor.
+        """
+
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
+            raise ValueError(
+                (
+                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                    " one of the `scheduler.timesteps` as a timestep."
+                ),
+            )
+
+        if self.step_index is None:
+            self._init_step_index(timestep)
+
+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+        t = self._timesteps[self.step_index]
+        t_next = self._timesteps[self.step_index + 1]
+
+        # Euler step along the flow: x_{t_next} = x_t + (t_next - t) * v_theta(x_t, t)
+        prev_sample = sample + (t_next - t) * model_output
+
+        # Cast sample back to model compatible dtype
+        prev_sample = prev_sample.to(model_output.dtype)
+
+        # upon completion increase step index by one
+        self._step_index += 1
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+    def __len__(self):
+        return self.config.num_train_timesteps
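For reviewers, here is a minimal sketch of how this scheduler is driven end to end. It is not part of the PR: the latent shape and the zero model output are placeholders, and the import path assumes the layout sketched above. Passing `num_tokens` activates the dynamic time shift; with m = sqrt(num_tokens) / 40 = 3.2 (the 1024 * 1024 case per the comment in `set_timesteps`), the midpoint t = 0.5 maps to 0.5 / (3.2 - 1.6 + 0.5) ≈ 0.24, so larger inputs spend more of the schedule at high noise levels.

import torch

from scheduler.scheduler_fofpred import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

# 128 * 128 latent tokens -> m = sqrt(16384) / 40 = 3.2, i.e. the 1024 * 1024 case.
scheduler.set_timesteps(num_inference_steps=50, device="cpu", num_tokens=128 * 128)

# Assuming the ascending schedule runs from noise (t = 0) toward data (t = 1).
sample = torch.randn(1, 16, 128, 128)  # placeholder latent shape
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # stand-in for the transformer's velocity prediction
    sample = scheduler.step(model_output, t, sample).prev_sample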
transformer/transformer_fofpred.py ADDED
The diff for this file is too large to render. See raw diff