maxholsman committed on
Commit 4a9570a · verified · 1 Parent(s): 3e54cb4

Upload folder using huggingface_hub

custom_generate/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Custom generate function for fuzzy speculative decoding
2
+ from .generate import generate
3
+
4
+ __all__ = ["generate"]
5
+
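For context, a custom generate function packaged in a custom_generate/ folder like this is normally invoked through transformers' Hub custom-generation mechanism. The sketch below is hypothetical: the Hub repo id, the Qwen checkpoints, the prompt, and the threshold are placeholder/example values chosen for illustration, and loading via the `custom_generate` argument assumes a transformers release recent enough to support Hub-hosted generation methods.

# Hypothetical usage sketch for this repo's fuzzy speculative decoding (FSD) generate function.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

target_id = "Qwen/Qwen2.5-1.5B-Instruct"   # example target (verifier) checkpoint
draft_id = "Qwen/Qwen2.5-0.5B-Instruct"    # example draft (assistant) checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(target_id)
model = AutoModelForCausalLM.from_pretrained(target_id, torch_dtype=torch.float16).to(device)
assistant = AutoModelForCausalLM.from_pretrained(draft_id, torch_dtype=torch.float16).to(device)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)
out = model.generate(
    **inputs,
    assistant_model=assistant,                 # draft model used by assisted generation
    do_sample=True,                            # the FSD path runs inside the sampling branch
    max_new_tokens=64,
    custom_generate="<namespace>/<this-repo>", # placeholder Hub id for this repository
    trust_remote_code=True,
    fsd_threshold=0.4,                         # example divergence threshold
    fsd_div_type="kl",                         # "kl", "js", or "draft_tokens"
)
print(tokenizer.decode(out[0], skip_special_tokens=True))

With fsd_threshold=0.0 (the default in generate.py) the fuzzy test essentially never fires, since the divergences are non-negative, so behavior stays close to exact speculative decoding; raising the threshold accepts more draft tokens at the cost of exactness.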
custom_generate/generate.py ADDED
@@ -0,0 +1,1057 @@
1
+ # coding=utf-8
2
+ # Custom generate function for fuzzy speculative decoding
3
+ # Based on transformers.generation.utils with modifications for custom acceptance/rejection logic
4
+
5
+ import copy
6
+ import inspect
7
+ import warnings
8
+ from collections.abc import Callable
9
+ from dataclasses import dataclass
10
+ from typing import TYPE_CHECKING, Any, Optional, Union
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ from torch import nn
15
+
16
+ from torch.nn.functional import kl_div, log_softmax
17
+
18
+ from transformers.cache_utils import Cache
19
+ from transformers.generation.candidate_generator import (
20
+ AssistedCandidateGenerator,
21
+ _prepare_attention_mask,
22
+ _prepare_token_type_ids,
23
+ )
24
+ from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
25
+ from transformers.generation.logits_process import LogitsProcessorList
26
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
27
+ from transformers.utils import ModelOutput, is_sklearn_available
28
+
29
+ if is_sklearn_available():
30
+ from sklearn.metrics import roc_curve
31
+
32
+ if TYPE_CHECKING:
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
35
+ from transformers.generation.streamers import BaseStreamer
36
+
37
+ # Variable names used to hold the cache at generation time
38
+ ALL_CACHE_NAMES = [
39
+ "past_key_values", # default
40
+ "cache_params", # mamba-based models
41
+ "state", # rwkv
42
+ "mems", # xlnet
43
+ "past_buckets_states", # reformer
44
+ ]
45
+
46
+ GENERATION_MODES_MAPPING = {
47
+ GenerationMode.SAMPLE: "_sample",
48
+ GenerationMode.GREEDY_SEARCH: "_sample",
49
+ GenerationMode.BEAM_SEARCH: "_beam_search",
50
+ GenerationMode.BEAM_SAMPLE: "_beam_search",
51
+ GenerationMode.ASSISTED_GENERATION: "_assisted_decoding",
52
+ }
53
+
54
+
55
+ @dataclass
56
+ class GenerateDecoderOnlyOutput(ModelOutput):
57
+ """Outputs of decoder-only generation models, when using non-beam methods."""
58
+
59
+ sequences: torch.LongTensor
60
+ scores: tuple[torch.FloatTensor] | None = None
61
+ logits: tuple[torch.FloatTensor] | None = None
62
+ attentions: tuple[tuple[torch.FloatTensor]] | None = None
63
+ hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
64
+ past_key_values: Cache | None = None
65
+
66
+
67
+ @dataclass
68
+ class GenerateEncoderDecoderOutput(ModelOutput):
69
+ """Outputs of encoder-decoder generation models, when using non-beam methods."""
70
+
71
+ sequences: torch.LongTensor
72
+ scores: tuple[torch.FloatTensor] | None = None
73
+ logits: tuple[torch.FloatTensor] | None = None
74
+ encoder_attentions: tuple[torch.FloatTensor] | None = None
75
+ encoder_hidden_states: tuple[torch.FloatTensor] | None = None
76
+ decoder_attentions: tuple[tuple[torch.FloatTensor]] | None = None
77
+ cross_attentions: tuple[tuple[torch.FloatTensor]] | None = None
78
+ decoder_hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
79
+ past_key_values: Cache | None = None
80
+
81
+
82
+ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
83
+ """
84
+ Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
85
+ where each member corresponds to a single generated token.
86
+ """
87
+ # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
88
+ # prompt.
89
+ if len(outputs) == 0:
90
+ new_tuple = ()
91
+ for layer in new_outputs:
92
+ last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
93
+ new_tuple += (layer[..., :cur_len, :last_dim_size],)
94
+ outputs += (new_tuple,)
95
+ # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
96
+ cur_len += 1
97
+ added_len -= cur_len
98
+
99
+ for i in range(added_len):
100
+ new_tuple = ()
101
+ for layer in new_outputs:
102
+ last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
103
+ new_tuple += (layer[..., i : i + 1, :last_dim_size],)
104
+ outputs += (new_tuple,)
105
+ return outputs
106
+
107
+
108
+ class RawLogitsCandidateGenerator(AssistedCandidateGenerator):
109
+ """
110
+ Custom candidate generator that returns both processed and raw logits from the assistant model.
111
+ Extends AssistedCandidateGenerator to support returning raw logits when output_logits=True.
112
+ """
113
+
114
+ def __init__(self, *args, **kwargs):
115
+ """Initialize the custom candidate generator."""
116
+ super().__init__(*args, **kwargs)
117
+ # Initialize probs list if sklearn is available and confidence threshold is enabled
118
+ if (
119
+ is_sklearn_available()
120
+ and self.assistant_generation_config.assistant_confidence_threshold
121
+ ):
122
+ if not hasattr(self, 'probs'):
123
+ self.probs = []
124
+ if not hasattr(self, 'matches'):
125
+ self.matches = []
126
+
127
+ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None, torch.FloatTensor | None]:
128
+ """
129
+ Fetches the candidates to be tried for the current input.
130
+ Returns: (candidate_ids, candidate_logits_processed, candidate_logits_raw)
131
+ - candidate_logits_processed: Processed logits (scores) from assistant model
132
+ - candidate_logits_raw: Raw logits from assistant model (None if output_logits=False)
133
+ """
134
+ input_ids = input_ids.to(self.assistant_model.device)
135
+ # Calculate new tokens to generate
136
+ min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
137
+ if max_new_tokens == 0:
138
+ return input_ids, None, None
139
+ # Update past key values and masks
140
+ self._update_past_and_masks(input_ids)
141
+ # Generate candidates
142
+ generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
143
+ candidate_ids, candidate_logits_processed, candidate_logits_raw = self._generate_candidates(generation_args)
144
+ return candidate_ids, candidate_logits_processed, candidate_logits_raw
145
+
146
+ def _generate_candidates(self, generation_args: dict) -> tuple[torch.LongTensor, torch.FloatTensor | None, torch.FloatTensor | None]:
147
+ """Generate candidate sequences using the assistant model, returning both processed and raw logits."""
148
+ assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
149
+ self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
150
+
151
+ # Handle sklearn confidence threshold tracking (if enabled)
152
+ if (
153
+ is_sklearn_available()
154
+ and self.assistant_generation_config.assistant_confidence_threshold
155
+ and type(self) is RawLogitsCandidateGenerator
156
+ ):
157
+ scores_tensor = torch.cat(assistant_output.scores, dim=0)
158
+ scores_softmax = torch.softmax(scores_tensor, dim=-1)
159
+ ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
160
+ p = scores_softmax[range(len(ids)), ids]
161
+ self.probs.extend(p.tolist())
162
+
163
+ # Extract processed logits (scores) - always available
164
+ candidate_logits_processed = torch.stack(assistant_output.scores, dim=1)
165
+ candidate_ids = assistant_output.sequences
166
+
167
+ # Extract raw logits if available (when output_logits=True)
168
+ candidate_logits_raw = None
169
+ if self.generation_config.output_logits and hasattr(assistant_output, 'logits') and assistant_output.logits is not None:
170
+ candidate_logits_raw = torch.stack(assistant_output.logits, dim=1)
171
+
172
+ return candidate_ids, candidate_logits_processed, candidate_logits_raw
173
+
174
+
175
+ def _speculative_sampling(
176
+ candidate_input_ids,
177
+ candidate_logits,
178
+ candidate_length,
179
+ new_logits,
180
+ next_token_logits,
181
+ is_done_candidate,
182
+ candidate_logits_raw,
183
+ fsd_threshold: float = 0.0,
184
+ fsd_div_type: str = "kl"
185
+ ):
186
+ """
187
+ Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns
188
+ the selected tokens, as well as the number of candidate matches. In this fuzzy variant, a candidate token is also accepted when the divergence between the raw draft and target distributions (selected via `fsd_div_type`) is at most `fsd_threshold`.
189
+
190
+ NOTE: Unless otherwise stated, the variable names match those in the paper.
191
+
192
+ """
193
+ new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
194
+ # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
195
+ # selected by the assistant, respectively.
196
+ q = candidate_logits.softmax(dim=-1)
197
+ q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
198
+ p = new_logits.softmax(dim=-1)
199
+ p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
200
+ probability_ratio = p_i / q_i
201
+
202
+ target_probs = next_token_logits.softmax(dim=-1)
203
+ cand_probs = candidate_logits_raw.softmax(dim=-1)
204
+
205
+ if fsd_div_type == "kl":
206
+ divs = kl_div(
207
+ cand_probs.log().clamp(min=-1e10), # log-probabilities of candidate distribution
208
+ target_probs[:, :-1, :], # probabilities of target distribution
209
+ reduction='none'
210
+ ).sum(dim=-1)
211
+ elif fsd_div_type == "js":
212
+
213
+ m = 0.5 * (cand_probs + target_probs[:, :-1, :]) # Mixture distribution
214
+
215
+ # Compute KL(P || M) and KL(Q || M)
216
+ kl_pm = kl_div(
217
+ m.log().clamp(min=-1e10), # log-probabilities of mixture
218
+ cand_probs, # probabilities of candidate
219
+ reduction='none'
220
+ )
221
+ kl_qm = kl_div(
222
+ m.log().clamp(min=-1e10), # log-probabilities of mixture
223
+ target_probs[:, :-1, :], # probabilities of target
224
+ reduction='none'
225
+ )
226
+
227
+ divs = 0.5 * (kl_pm + kl_qm).sum(dim=-1)
228
+
229
+ elif fsd_div_type == "draft_tokens":
230
+ draft_token_ids = new_candidate_input_ids # shape: (batch, candidate_length)
231
+ draft_token_probs_candidate = cand_probs[:, torch.arange(candidate_length), draft_token_ids].squeeze(0, 1)
232
+ draft_token_probs_target = target_probs[:, :-1, :][:, torch.arange(candidate_length), draft_token_ids].squeeze(0, 1)
233
+
234
+ divs = (draft_token_probs_candidate - draft_token_probs_target).abs().sum(dim=-1)
235
+ else:
236
+ raise ValueError(f"Invalid fsd_div_type: {fsd_div_type}")
237
+ # print(f"divs: {divs}")
238
+ is_accepted_fsd = divs <= fsd_threshold
239
+
240
+ # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
241
+ # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
242
+ # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
243
+ r_i = torch.rand_like(probability_ratio)
244
+ is_accepted_sd = r_i <= probability_ratio
245
+
246
+ is_accepted = is_accepted_fsd | is_accepted_sd
247
+ n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
248
+ # print(f"is_accepted_fsd: {is_accepted_fsd}\n is_accepted_sd: {is_accepted_sd}\n is_accepted: {is_accepted}")
249
+
250
+ # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
251
+ if is_done_candidate and n_matches == candidate_length:
252
+ # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
253
+ # due to acceptance on EOS we fix `n_matches`
254
+ n_matches -= 1
255
+ valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
256
+ else:
257
+ # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
258
+ gamma = candidate_logits.shape[1]
259
+ p_n_plus_1 = p[:, n_matches, :]
260
+ if n_matches < gamma:
261
+ q_n_plus_1 = q[:, n_matches, :]
262
+ p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
263
+ p_prime.div_(p_prime.sum())
264
+ else:
265
+ p_prime = p_n_plus_1
266
+ t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
267
+
268
+ # The selected tokens include the matches (if any) plus the next sampled tokens
269
+ if n_matches > 0:
270
+ valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
271
+ else:
272
+ valid_tokens = t
273
+
274
+ return valid_tokens, n_matches
275
+
276
+
277
+ def _assisted_decoding(
278
+ model,
279
+ input_ids: torch.LongTensor,
280
+ logits_processor: LogitsProcessorList,
281
+ stopping_criteria: StoppingCriteriaList,
282
+ generation_config: GenerationConfig,
283
+ synced_gpus: bool = False,
284
+ streamer: Optional["BaseStreamer"] = None,
285
+ inputs_tensor: torch.FloatTensor | None = None,
286
+ assistant_model: Optional["PreTrainedModel"] = None,
287
+ assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
288
+ tokenizer: Optional["PreTrainedTokenizerBase"] = None,
289
+ fsd_threshold: float = 0.0,
290
+ fsd_div_type: str = "kl",
291
+ **model_kwargs,
292
+ ) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
293
+ r"""
294
+ Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
295
+ **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
296
+ candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
297
+ models.
298
+ """
299
+ # The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache
300
+ if not model_kwargs["use_cache"]:
301
+ raise ValueError("assisted generate requires `use_cache=True`")
302
+ if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or (
303
+ "past_key_values" in model_kwargs
304
+ and hasattr(model_kwargs["past_key_values"], "layers")
305
+ and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers)
306
+ ):
307
+ raise ValueError("assisted generate is not supported with Static cache classes`")
308
+
309
+ # Create custom candidate generator that supports raw logits
310
+ # Force output_logits=True so the assistant returns raw logits for the FSD divergence computation
311
+ if assistant_model is None:
312
+ raise ValueError("assistant_model is required for assisted generation")
313
+
314
+
315
+ generation_config.output_logits = True
316
+ candidate_generator = RawLogitsCandidateGenerator(
317
+ input_ids=input_ids,
318
+ assistant_model=assistant_model,
319
+ generation_config=generation_config,
320
+ model_kwargs=model_kwargs,
321
+ inputs_tensor=inputs_tensor,
322
+ logits_processor=logits_processor,
323
+ )
324
+ # init values
325
+ do_sample = generation_config.do_sample
326
+ output_attentions = generation_config.output_attentions
327
+ output_hidden_states = generation_config.output_hidden_states
328
+ output_scores = generation_config.output_scores
329
+ output_logits = generation_config.output_logits
330
+ return_dict_in_generate = generation_config.return_dict_in_generate
331
+
332
+ # init attention / hidden states / scores tuples
333
+ scores = () if (return_dict_in_generate and output_scores) else None
334
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
335
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
336
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
337
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
338
+
339
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
340
+ if return_dict_in_generate and model.config.is_encoder_decoder:
341
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
342
+ encoder_hidden_states = (
343
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
344
+ )
345
+
346
+ # keep track of which sequences are already finished
347
+ batch_size, cur_len = input_ids.shape[:2]
348
+ if batch_size > 1:
349
+ raise ValueError("assisted generate is only supported for batch_size = 1")
350
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
351
+ model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
352
+
353
+ this_peer_finished = False
354
+ is_first_iteration = True # to preserve the same API in the output as other generation methods
355
+ while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
356
+ cur_len = input_ids.shape[1]
357
+
358
+ # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
359
+ candidate_input_ids, candidate_logits, candidate_logits_raw = candidate_generator.get_candidates(input_ids)
360
+ candidate_input_ids = candidate_input_ids.to(model.device)
361
+ if candidate_logits is not None:
362
+ candidate_logits = candidate_logits.to(model.device)
363
+ if candidate_logits_raw is not None:
364
+ candidate_logits_raw = candidate_logits_raw.to(model.device)
365
+
366
+ candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
367
+ is_done_candidate = stopping_criteria(candidate_input_ids, None)
368
+
369
+ # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
370
+ # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
371
+ # we use this forward pass to also pick the subsequent logits in the original model.
372
+
373
+ # 2.1. Prepare the model inputs
374
+ candidate_kwargs = copy.copy(model_kwargs)
375
+ candidate_kwargs = _prepare_attention_mask(
376
+ candidate_kwargs, candidate_input_ids.shape[1], model.config.is_encoder_decoder
377
+ )
378
+ candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
379
+ if "cache_position" in candidate_kwargs:
380
+ candidate_kwargs["cache_position"] = torch.cat(
381
+ (
382
+ candidate_kwargs["cache_position"],
383
+ torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
384
+ ),
385
+ dim=0,
386
+ )
387
+
388
+ model_inputs = model.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
389
+ if "logits_to_keep" in model_inputs:
390
+ model_inputs["logits_to_keep"] = candidate_length + 1
391
+
392
+ # 2.2. Run a forward pass on the candidate sequence
393
+ outputs = model(**model_inputs)
394
+
395
+ # 2.3. Process the new logits
396
+ # .float() is needed to retain precision for later logits manipulations
397
+ new_logits = outputs.logits[:, -candidate_length - 1 :].to(
398
+ dtype=torch.float32, device=input_ids.device
399
+ ) # excludes the input prompt if present
400
+ next_token_logits = new_logits.clone()
401
+ if len(logits_processor) > 0:
402
+ for i in range(candidate_length + 1):
403
+ new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
404
+
405
+ # 3. Select the accepted tokens. There are two possible cases:
406
+ # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
407
+ # 👉 Apply algorithm 1 from the speculative decoding paper (https://huggingface.co/papers/2211.17192).
408
+ if do_sample and candidate_logits is not None:
409
+ valid_tokens, n_matches = _speculative_sampling(
410
+ candidate_input_ids,
411
+ candidate_logits,
412
+ candidate_length,
413
+ new_logits,
414
+ next_token_logits,
415
+ is_done_candidate,
416
+ candidate_logits_raw=candidate_logits_raw,
417
+ fsd_threshold=fsd_threshold,
418
+ fsd_div_type=fsd_div_type,
419
+ )
420
+
421
+ # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
422
+ # original model logits with the candidate tokens. We can keep the candidate tokens until the first
423
+ # mismatch, or until the max length is reached.
424
+ else:
425
+ if do_sample:
426
+ probs = new_logits.softmax(dim=-1)
427
+ selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
428
+ else:
429
+ selected_tokens = new_logits.argmax(dim=-1)
430
+
431
+ candidate_new_tokens = candidate_input_ids[:, cur_len:]
432
+ n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
433
+
434
+ # Ensure we don't generate beyond max_len or an EOS token
435
+ if is_done_candidate and n_matches == candidate_length:
436
+ n_matches -= 1
437
+ valid_tokens = selected_tokens[:, : n_matches + 1]
438
+
439
+ # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
440
+ # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
441
+ # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
442
+ # is no match.
443
+
444
+ # 4.1. Get the valid continuation, after the matching tokens
445
+ input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
446
+ if streamer is not None:
447
+ streamer.put(valid_tokens.cpu())
448
+ new_cur_len = input_ids.shape[1]
449
+
450
+ # 4.2. Discard past key values relative to unused assistant tokens
451
+ outputs.past_key_values.crop(new_cur_len - 1)
452
+
453
+ # 5. Update the candidate generation strategy if needed
454
+ candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
455
+
456
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
457
+ model_kwargs = model._update_model_kwargs_for_generation(
458
+ outputs,
459
+ model_kwargs,
460
+ is_encoder_decoder=model.config.is_encoder_decoder,
461
+ num_new_tokens=n_matches + 1,
462
+ )
463
+ if synced_gpus and this_peer_finished:
464
+ continue
465
+
466
+ # Store scores, attentions and hidden_states when required
467
+ # Assistant: modified to append one tuple element per token, as in the other generation methods.
468
+ if return_dict_in_generate:
469
+ newly_added_length = n_matches + 1
470
+ if output_scores:
471
+ scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
472
+ if output_logits:
473
+ raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))
474
+
475
+ newly_added_length = new_cur_len if is_first_iteration else newly_added_length
476
+ if output_attentions:
477
+ if model.config.is_encoder_decoder:
478
+ cross_attentions = _split_model_outputs(
479
+ cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
480
+ )
481
+ decoder_attentions = _split_model_outputs(
482
+ decoder_attentions,
483
+ outputs.decoder_attentions,
484
+ cur_len,
485
+ newly_added_length,
486
+ is_decoder_attention=True,
487
+ )
488
+ # some (V)LLMs have hard requirement on SDPA and thus never return attn
489
+ elif outputs.attentions[0] is not None:
490
+ decoder_attentions = _split_model_outputs(
491
+ decoder_attentions,
492
+ outputs.attentions,
493
+ cur_len,
494
+ newly_added_length,
495
+ is_decoder_attention=True,
496
+ )
497
+ if output_hidden_states:
498
+ if model.config.is_encoder_decoder:
499
+ decoder_hidden_states = _split_model_outputs(
500
+ decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
501
+ )
502
+ else:
503
+ decoder_hidden_states = _split_model_outputs(
504
+ decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
505
+ )
506
+
507
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
508
+ this_peer_finished = unfinished_sequences.max() == 0
509
+ is_first_iteration = False
510
+
511
+ if streamer is not None:
512
+ streamer.end()
513
+
514
+ if (
515
+ hasattr(candidate_generator, "assistant_model")
516
+ and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
517
+ ):
518
+ candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
519
+ candidate_generator.num_assistant_tokens
520
+ )
521
+ if return_dict_in_generate:
522
+ cache = None
523
+ if any(cache_key in model_kwargs for cache_key in ALL_CACHE_NAMES):
524
+ cache_key = next(cache_key for cache_key in ALL_CACHE_NAMES if cache_key in model_kwargs)
525
+ cache = model_kwargs[cache_key]
526
+ if model.config.is_encoder_decoder:
527
+ return GenerateEncoderDecoderOutput(
528
+ sequences=input_ids,
529
+ scores=scores,
530
+ logits=raw_logits,
531
+ encoder_attentions=encoder_attentions,
532
+ encoder_hidden_states=encoder_hidden_states,
533
+ decoder_attentions=decoder_attentions,
534
+ cross_attentions=cross_attentions,
535
+ decoder_hidden_states=decoder_hidden_states,
536
+ past_key_values=cache,
537
+ )
538
+ else:
539
+ return GenerateDecoderOnlyOutput(
540
+ sequences=input_ids,
541
+ scores=scores,
542
+ logits=raw_logits,
543
+ attentions=decoder_attentions,
544
+ hidden_states=decoder_hidden_states,
545
+ past_key_values=cache,
546
+ )
547
+ else:
548
+ return input_ids
549
+
550
+
551
+ def generate(
552
+ model,
553
+ inputs: torch.Tensor | None = None,
554
+ generation_config: GenerationConfig | None = None,
555
+ logits_processor: LogitsProcessorList | None = None,
556
+ stopping_criteria: StoppingCriteriaList | None = None,
557
+ prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
558
+ synced_gpus: bool | None = None,
559
+ assistant_model: Optional["PreTrainedModel"] = None,
560
+ streamer: Optional["BaseStreamer"] = None,
561
+ negative_prompt_ids: torch.Tensor | None = None,
562
+ negative_prompt_attention_mask: torch.Tensor | None = None,
563
+ **kwargs,
564
+ ) -> GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput | torch.LongTensor:
565
+ r"""
566
+ Generates sequences of token ids for models with a language modeling head.
567
+
568
+ This is a custom generate function that replaces the standard one. It supports all standard generation modes
569
+ and, for assisted generation, applies fuzzy speculative decoding (FSD) acceptance/rejection logic controlled by the `fsd_threshold` and `fsd_div_type` kwargs.
570
+ """
571
+ # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
572
+ # Extract custom parameters before validation (they're not standard generation config params)
573
+ fsd_threshold = kwargs.pop("fsd_threshold", 0.0)
574
+ fsd_div_type = kwargs.pop("fsd_div_type", "kl")
575
+
576
+ generation_mode_kwargs = model._extract_generation_mode_kwargs(
577
+ None, # custom_generate
578
+ kwargs,
579
+ synced_gpus,
580
+ assistant_model,
581
+ streamer,
582
+ )
583
+ # Add custom FSD parameters to generation_mode_kwargs so they're passed to _assisted_decoding
584
+ generation_mode_kwargs["fsd_threshold"] = fsd_threshold
585
+ generation_mode_kwargs["fsd_div_type"] = fsd_div_type
586
+
587
+ # Check length values before updating the config with defaults
588
+ has_default_max_length = kwargs.get("max_length") is None and (
589
+ generation_config is None or generation_config.max_length is None
590
+ )
591
+ has_default_min_length = kwargs.get("min_length") is None and (
592
+ generation_config is None or generation_config.min_length is None
593
+ )
594
+ generation_config, model_kwargs = model._prepare_generation_config(generation_config, **kwargs)
595
+
596
+ generation_mode = generation_config.get_generation_mode(assistant_model)
597
+ # type() required to access the unbound class-level method
598
+ decoding_method = getattr(type(model), GENERATION_MODES_MAPPING[generation_mode])
599
+
600
+ model._validate_model_kwargs(model_kwargs.copy())
601
+ model._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
602
+
603
+ # 2. Set generation parameters if not already defined
604
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
605
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
606
+
607
+ accepts_attention_mask = "attention_mask" in set(inspect.signature(model.forward).parameters.keys())
608
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
609
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
610
+
611
+ # 3. Define model inputs
612
+ inputs_tensor, model_input_name, model_kwargs = model._prepare_model_inputs(
613
+ inputs, generation_config.bos_token_id, model_kwargs
614
+ )
615
+ # Some generation modes (e.g. assisted) need `inputs_tensor` to rerun encoder.forward()
616
+ if "inputs_tensor" in inspect.signature(decoding_method).parameters.keys():
617
+ generation_mode_kwargs["inputs_tensor"] = inputs_tensor
618
+ batch_size = inputs_tensor.shape[0]
619
+
620
+ device = inputs_tensor.device
621
+ model._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
622
+
623
+ # decoder-only models must use left-padding for batched generation.
624
+ if not model.config.is_encoder_decoder:
625
+ # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
626
+ if (
627
+ generation_config._pad_token_tensor is not None
628
+ and batch_size > 1
629
+ and len(inputs_tensor.shape) == 2
630
+ and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
631
+ ):
632
+ from transformers.utils import logging
633
+ logger = logging.get_logger(__name__)
634
+ logger.warning(
635
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
636
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
637
+ )
638
+
639
+ # 4. Define other model kwargs
640
+ # decoder-only models with inputs_embeds forwarding must use caching
641
+ if not model.config.is_encoder_decoder and model_input_name == "inputs_embeds":
642
+ generation_config.use_cache = True
643
+
644
+ if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
645
+ model_kwargs["attention_mask"] = model._prepare_attention_mask_for_generation(
646
+ inputs_tensor, generation_config, model_kwargs
647
+ )
648
+ elif kwargs_has_attention_mask:
649
+ # TODO (joao): generalize this check with other types of inputs
650
+ if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
651
+ raise ValueError("`attention_mask` passed to `generate` must be 2D.")
652
+
653
+ if model.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
654
+ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
655
+ model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(
656
+ inputs_tensor, model_kwargs, model_input_name, generation_config
657
+ )
658
+
659
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
660
+ if model.config.is_encoder_decoder:
661
+ input_ids, model_kwargs = model._prepare_decoder_input_ids_for_generation(
662
+ batch_size=batch_size,
663
+ model_input_name=model_input_name,
664
+ model_kwargs=model_kwargs,
665
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
666
+ device=inputs_tensor.device,
667
+ )
668
+ else:
669
+ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
670
+
671
+ # Expand inputs depending on the generation mode
672
+ input_ids, model_kwargs = model._expand_inputs_for_generation(
673
+ input_ids=input_ids,
674
+ expand_size=max(generation_config.num_beams, generation_config.num_return_sequences),
675
+ is_encoder_decoder=model.config.is_encoder_decoder,
676
+ **model_kwargs,
677
+ )
678
+
679
+ if generation_config.token_healing:
680
+ input_ids = model.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
681
+
682
+ if streamer is not None:
683
+ streamer.put(input_ids.cpu())
684
+
685
+ # 6. Prepare `max_length` depending on other stopping criteria.
686
+ input_ids_length = input_ids.shape[1]
687
+ generation_config = model._prepare_generated_length(
688
+ generation_config=generation_config,
689
+ has_default_max_length=has_default_max_length,
690
+ has_default_min_length=has_default_min_length,
691
+ model_input_name=model_input_name,
692
+ inputs_tensor=inputs_tensor,
693
+ input_ids_length=input_ids_length,
694
+ )
695
+
696
+ # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
697
+ # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
698
+ # dynamically overrides this value as it can need more than the last token logits
699
+ if model._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
700
+ model_kwargs["logits_to_keep"] = 1
701
+
702
+ model._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
703
+
704
+ # 7. Prepare the cache.
705
+ max_cache_length = generation_config.max_length - 1
706
+ if (
707
+ inputs_tensor.shape[1] != input_ids_length
708
+ and model_input_name == "inputs_embeds"
709
+ and not model.config.is_encoder_decoder
710
+ ):
711
+ max_cache_length += inputs_tensor.shape[1]
712
+ model._prepare_cache_for_generation(
713
+ generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
714
+ )
715
+
716
+ if model.device.type != input_ids.device.type:
717
+ warnings.warn(
718
+ "You are calling .generate() with the `input_ids` being on a device type different"
719
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
720
+ f" is on {model.device.type}. You may experience unexpected behaviors or slower generation."
721
+ " Please make sure that you have put `input_ids` to the"
722
+ f" correct device by calling for example input_ids = input_ids.to('{model.device.type}') before"
723
+ " running `.generate()`.",
724
+ UserWarning,
725
+ )
726
+
727
+ # 8. Prepare logits processors and stopping criteria
728
+ prepared_logits_processor = model._get_logits_processor(
729
+ generation_config=generation_config,
730
+ input_ids_seq_length=input_ids_length,
731
+ encoder_input_ids=inputs_tensor,
732
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
733
+ logits_processor=logits_processor,
734
+ device=inputs_tensor.device,
735
+ model_kwargs=model_kwargs,
736
+ negative_prompt_ids=negative_prompt_ids,
737
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
738
+ )
739
+ prepared_stopping_criteria = model._get_stopping_criteria(
740
+ generation_config=generation_config,
741
+ stopping_criteria=stopping_criteria,
742
+ tokenizer=generation_mode_kwargs.get("tokenizer"),
743
+ )
744
+
745
+ # Set model_kwargs `use_cache` so we can use it later in forward runs
746
+ model_kwargs["use_cache"] = generation_config.use_cache
747
+
748
+ # 9. Call generation mode
749
+ # For assisted generation, use our custom function
750
+ if generation_mode == GenerationMode.ASSISTED_GENERATION:
751
+ result = _assisted_decoding(
752
+ model,
753
+ input_ids,
754
+ logits_processor=prepared_logits_processor,
755
+ stopping_criteria=prepared_stopping_criteria,
756
+ generation_config=generation_config,
757
+ **generation_mode_kwargs,
758
+ **model_kwargs,
759
+ )
760
+ else:
761
+ # For other modes, use the model's standard methods
762
+ result = decoding_method(
763
+ model,
764
+ input_ids,
765
+ logits_processor=prepared_logits_processor,
766
+ stopping_criteria=prepared_stopping_criteria,
767
+ generation_config=generation_config,
768
+ **generation_mode_kwargs,
769
+ **model_kwargs,
770
+ )
771
+
772
+ return result
773
+
774
+
775
+ # def _speculative_backoff_sampling(
776
+ # candidate_input_ids,
777
+ # candidate_logits,
778
+ # candidate_logits_unprocessed,
779
+ # eos_position_logits,
780
+ # candidate_length,
781
+ # new_logits,
782
+ # new_logits_unprocessed,# NOTE: these are unprocessed, unwarped logits
783
+ # is_done_candidate,
784
+ # div_threshold,
785
+ # div_type,
786
+ # do_sample, # this is also passed in new
787
+ # logits_processor: LogitsProcessorList, # these two must be passed in because we want to work with the logits before they are processed and warped
788
+ # logits_warper: Optional[LogitsProcessorList], # these two must be passed in because we want to work with the logits before they are processed and warped
789
+ # div_logits_processor: Optional[LogitsProcessorList],
790
+ # cur_len,
791
+ # eos_token_id,
792
+ # candidate_generator_type='classifier',
793
+ # ):
794
+ # # valid_tokens, n_matches, new_logits = _speculative_backoff_sampling(
795
+ # # candidate_input_ids,
796
+ # # candidate_logits,
797
+ # # candidate_logits_unprocessed,
798
+ # # candidate_length,
799
+ # # new_logits,
800
+ # # is_done_candidate,
801
+ # # kl_div_threshold,
802
+ # # do_sample,
803
+ # # logits_processor,
804
+ # # logits_warper,
805
+ # # cur_len,
806
+ # # )
807
+ # """
808
+ # Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
809
+ # the selected tokens, as well as the number of candidate matches.
810
+
811
+ # NOTE: Unless otherwise stated, the variable names match those in the paper.
812
+ # """
813
+
814
+ # '''
815
+ # NOTE: Implementation plan -
816
+ # 1. implement custom assistant model class with a classifier that terminates generation as soon as the last generated logit is predicted to exceed the distribution
817
+ # Is there an issue with using EOS token to terminate sequence? since large model will simply reject this token once it checks.
818
+ # I think this would work, since we can then use distribution generated by large model to generate next token (the position deemed as large model-necessary by classifier)
819
+ # 2. implement custom candidate_generator that uses this model to generate a series of candidates - DONE (other than question about do_sample - will set to sample for now)
820
+
821
+ # 3. implement this speculative_backoff_sampling class to backtrack, checking all candidates to see if they exceed the threshold. If they do, sample from large_model logits at this position (have to adjust logits as would is regular sampling)
822
+ # Need to make sure logit processing and warping are correct - both in terms of warping before calling this function (so that M_L sampling is correct) and in terms of having the warping not throw off the KL divergence calculation
823
+ # Probably will pass original + processed logits into speculative_backoff_decoding function
824
+ # 4. Update cache of both assistant and target model to discard all KV values past first rejected token using cache.crop()
825
+ # 5. Make sure this is properly implemented within a loop, such that following all this candidate_generator is called again to generate the next batch of tokens
826
+
827
+ # '''
828
+
829
+ # initial_start_time = time.time()
830
+ # new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
831
+ # correction_term = 0
832
+
833
+ # if div_type != 'sd':
834
+
835
+ # if div_type == 'kl_div_processed' or div_type == 'js_div_processed' or div_type == 'tv_div_processed':
836
+ # epsilon = 1e-10
837
+ # q = candidate_logits.softmax(dim=-1)
838
+ # p = new_logits[:, :candidate_length, :].softmax(dim=-1) # need to be cropped because M_L logits include logits for ungenerated position
839
+
840
+ # q_nonzero = (p > 0).int()
841
+ # p_nonzero = (q > 0).int()
842
+ # both_nonzero = (q_nonzero & p_nonzero).int()
843
+
844
+ # # print(f"nonzero q: {q_nonzero.sum(dim=-1)}")
845
+ # # print(f"nonzero p: {p_nonzero.sum(dim=-1)}")
846
+ # # print(f"both nonzero: {both_nonzero.sum(dim=-1)}")
847
+
848
+ # q = q + epsilon
849
+ # p = p + epsilon
850
+
851
+ # p = p / p.sum(dim=-1, keepdim=True)
852
+ # q = q / q.sum(dim=-1, keepdim=True)
853
+
854
+
855
+ # else:
856
+ # q = candidate_logits_unprocessed.softmax(dim=-1)
857
+ # p = new_logits_unprocessed[:, :candidate_length, :].softmax(dim=-1) # need to be cropped because M_L logits include logits for ungenerated position
858
+
859
+ # if len(div_logits_processor) > 0:
860
+ # epsilon = 1e-10
861
+ # q = q + epsilon
862
+ # p = p + epsilon
863
+
864
+ # p = p / p.sum(dim=-1, keepdim=True)
865
+ # q = q / q.sum(dim=-1, keepdim=True)
866
+
867
+ # if div_type == 'kl_div' or div_type == 'kl_div_processed':
868
+ # divs = torch.nn.functional.kl_div(torch.log(p), q, reduction='none').sum(dim=-1) # shape = [bs, seq_len]
869
+ # elif div_type == 'kl_div_reversed' or div_type == 'kl_div_reversed_processed':
870
+ # divs = torch.nn.functional.kl_div(torch.log(q), p, reduction='none').sum(dim=-1) # shape = [bs, seq_len]
871
+ # elif div_type == 'js_div' or div_type == 'js_div_processed':
872
+ # m = 0.5 * (p + q) # Midpoint distribution
873
+ # divs = (0.5 * torch.nn.functional.kl_div(torch.log(p), m, reduction='none') + 0.5 * torch.nn.functional.kl_div(torch.log(q), m, reduction='none')).sum(dim=-1)
874
+ # elif div_type == 'tv_div' or div_type == 'tv_div_processed':
875
+ # divs = 0.5 * torch.abs(p - q).sum(dim=-1)
876
+
877
+ # elif div_type == 'top_p_kl_div' or div_type == 'top_p_js_div' or div_type == 'top_p_tv_div':
878
+ # p_sorted, p_sorted_indexes = torch.sort(p, descending=True)
879
+ # q_sorted = q[p_sorted_indexes]
880
+
881
+ # cum_p = torch.cumsum(p_sorted, dim=-1)
882
+
883
+ # # Identify the top-p (nucleus) indices
884
+ # top_p_mask = cum_p <= top_val
885
+ # top_p_mask[torch.argmax(cum_p > top_val)] = True # Include the first value exceeding p
886
+ # top_p = p_sorted[top_p_mask]
887
+ # top_q = q_sorted[top_p_mask]
888
+
889
+ # # Normalize the nucleus probabilities
890
+ # top_p = top_p / top_p.sum()
891
+ # top_q = top_q / top_q.sum()
892
+
893
+ # if div_type == 'top_p_kl_div':
894
+ # divs = torch.nn.functional.kl_div(torch.log(top_p), top_q, reduction='none').sum(dim=-1)
895
+
896
+ # if div_type == 'top_p_js_div':
897
+ # m = 0.5 * (top_p + top_q) # Midpoint distribution
898
+ # divs = (0.5 * torch.nn.functional.kl_div(torch.log(top_p), m, reduction='none') + 0.5 * torch.nn.functional.kl_div(torch.log(top_q), m, reduction='none')).sum(dim=-1)
899
+
900
+ # if div_type == 'top_p_tv_div':
901
+ # divs = 0.5 * torch.abs(top_p - top_q).sum(dim=-1)
902
+
903
+ # elif div_type == 'top_k_kl_div' or div_type == 'top_k_js_div' or div_type == 'top_k_tv_div':
904
+ # top_val = 50
905
+
906
+ # # print(f"p distr: {p}")
907
+ # # print(f"q distr: {q}")
908
+
909
+ # p_top_k, p_top_k_indices = torch.topk(p, top_val, dim=-1)
910
+ # q_top_k = torch.gather(q, -1, p_top_k_indices)
911
+
912
+ # top_k_mask = torch.zeros_like(p, dtype=torch.bool).scatter_(-1, p_top_k_indices, True)
913
+
914
+ # non_top_k_mask = ~top_k_mask # Invert the mask
915
+ # p_non_top_k_values = p * non_top_k_mask # Zero out the top_k values
916
+ # q_non_top_k_values = q * non_top_k_mask # Zero out the top_k values
917
+
918
+ # # Sum over the non-top_k positions
919
+ # p_non_top_k_sum = p_non_top_k_values.sum(dim=-1, keepdim=True)
920
+ # q_non_top_k_sum = q_non_top_k_values.sum(dim=-1, keepdim=True)
921
+ # # print(f"p_non_top_k_sum: {p_non_top_k_sum}")
922
+
923
+ # # p_non_top_k_sum = 1 - p_top_k.sum(dim=-1, keepdim=True)
924
+ # # q_non_top_k_sum = 1 - q_top_k.sum(dim=-1, keepdim=True)
925
+
926
+ # p_top_k = torch.cat((p_top_k, p_non_top_k_sum), dim=-1)
927
+ # q_top_k = torch.cat((q_top_k, q_non_top_k_sum), dim=-1)
928
+
929
+ # # print(f"p_top_k.shape: {p_top_k.shape}")
930
+ # # print(f"q_top_k.shape: {q_top_k.shape}")
931
+
932
+ # # p_top_k, p_top_k_indices = torch.topk(p, top_val, dim=-1)
933
+ # # q_top_k = q[:, :, p_top_k_indices]
934
+
935
+ # if div_type == 'top_k_kl_div':
936
+ # divs = torch.nn.functional.kl_div(torch.log(p_top_k), q_top_k, reduction='none').sum(dim=-1)
937
+
938
+ # if div_type == 'top_k_js_div':
939
+ # m = 0.5 * (p_top_k + q_top_k) # Midpoint distribution
940
+ # divs = (0.5 * torch.nn.functional.kl_div(torch.log(p_top_k), m, reduction='none') + 0.5 * torch.nn.functional.kl_div(torch.log(q_top_k), m, reduction='none')).sum(dim=-1)
941
+
942
+ # if div_type == 'top_k_tv_div':
943
+ # divs = 0.5 * torch.abs(p_top_k - q_top_k).sum(dim=-1)
944
+
945
+ # print(f"divs: {divs}")
946
+
947
+ # is_accepted = divs <= div_threshold
948
+
949
+
950
+ # print(f"divs: {divs.tolist()} threshold: {div_threshold} div_type: {div_type}")
951
+
952
+ # else:
953
+ # q = candidate_logits_unprocessed.softmax(dim=-1) # depends on whether processing candidate_logits or not
954
+ # q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
955
+ # p = new_logits.softmax(dim=-1)
956
+ # p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
957
+ # # print(f"SD in SBD - q: {q}, \np: {p}")
958
+ # probability_ratio = p_i / q_i
959
+
960
+ # # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
961
+ # # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
962
+ # # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
963
+ # r_i = torch.rand_like(probability_ratio)
964
+ # divs = r_i
965
+ # is_accepted = r_i <= probability_ratio
966
+
967
+ # # print(f"kl_div: {kl_div_threshold}")
968
+ # acceptance_time = time.time() - initial_start_time
969
+ # start_time = time.time()
970
+ # # print(f"acceptance time: {acceptance_time}")
971
+ # # print(f"divs: {divs}")
972
+
973
+ # # true_kl_divs = kl_divs.clone()
974
+ # if eos_position_logits != None:
975
+ # true_divs = divs.clone()
976
+ # eos_position_probs = eos_position_logits.softmax(dim=-1)
977
+ # eos_position_div = torch.nn.functional.kl_div(torch.log(p[:, -1, :].unsqueeze(1)), eos_position_probs, reduction='none').sum(dim=-1)
978
+ # true_divs[:, -1] = eos_position_div
979
+ # else:
980
+ # true_divs = divs
981
+
982
+ # # print(f"divs: {true_divs.tolist()}")
983
+ # # print(f"div_threshold: {div_threshold}")
984
+
985
+ # # labels = (kl_divs <= kl_div_threshold).int()
986
+
987
+ # n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 -
988
+ # # Process and warp the logits before sampling
989
+ # # if len(logits_processor) > 0:
990
+ # # for i in range(n_matches + 1):
991
+ # # new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
992
+ # # if do_sample and len(logits_warper) > 0:
993
+ # # for i in range(n_matches + 1):
994
+ # # new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
995
+ # logit_processing_time = time.time() - start_time
996
+ # start_time = time.time()
997
+ # # print(f"new_logits shape inside: {new_logits.shape}")
998
+ # # print(f"logit_processing_time: {logit_processing_time}")
999
+ # # print(f"candidate_generator_type: {candidate_generator_type}")
1000
+
1001
+ # if candidate_length == n_matches and new_candidate_input_ids[0, -1] == eos_token_id and candidate_generator_type != 'regular' and div_type != 'sd':
1002
+ # # print(f"Accepted an eos_token")
1003
+ # is_done_candidate = True
1004
+
1005
+ # is_done_time = time.time() - start_time
1006
+ # start_time = time.time()
1007
+ # # print(f"is_done_time: {is_done_time}")
1008
+ # if is_done_candidate and n_matches == candidate_length:
1009
+ # backoff_count = n_matches
1010
+ # total = candidate_length
1011
+ # n_matches -= 1
1012
+ # correction_term = 1
1013
+ # valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
1014
+
1015
+ # else:
1016
+ # if div_type != 'sd':
1017
+ # p_n_plus_1 = new_logits.softmax(dim=-1)[:, n_matches, :] # need to reuse new_logits because want to do post processing
1018
+ # p_prime = p_n_plus_1 # this is the distribution at the position we must sample from to replace the first rejection
1019
+
1020
+ # # token selection
1021
+ # if do_sample:
1022
+ # next_tokens = torch.multinomial(p_prime, num_samples=1)# .squeeze(1) # check that distributions are adjusted accordingly before being passed into this.
1023
+ # else:
1024
+ # next_tokens = torch.argmax(p_prime, dim=-1)
1025
+ # # The selected tokens include the matches (if any) plus the next sampled tokens
1026
+ # if n_matches > 0:
1027
+ # valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], next_tokens), dim=-1)
1028
+ # else:
1029
+ # valid_tokens = next_tokens
1030
+ # else:
1031
+ # gamma = candidate_logits.shape[1]
1032
+ # p_n_plus_1 = p[:, n_matches, :]
1033
+ # if n_matches < gamma:
1034
+ # q_n_plus_1 = q[:, n_matches, :]
1035
+ # p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
1036
+ # p_prime.div_(p_prime.sum())
1037
+ # else:
1038
+ # p_prime = p_n_plus_1
1039
+ # # print(f"p_prime: {p_prime}")
1040
+ # t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
1041
+
1042
+ # # The selected tokens include the matches (if any) plus the next sampled tokens
1043
+ # if n_matches > 0:
1044
+ # valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
1045
+ # else:
1046
+ # valid_tokens = t
1047
+
1048
+ # print(f"SBD: candidate_length: {candidate_length}, n_matches: {n_matches}")
1049
+ # # if candidate_length != 5:
1050
+ # # print(f"prediction: {true_divs[:, -1].item() > div_threshold}")
1051
+ # # spec_sampling_time = (time.time() - start_time) + acceptance_time
1052
+ # spec_sampling_time = time.time() - start_time
1053
+ # # print(f"spec_sampling_time: {spec_sampling_time}")
1054
+ # total_time = time.time() - initial_start_time
1055
+ # # print(f"total_time: {total_time} == {acceptance_time + logit_processing_time + is_done_time + spec_sampling_time}")
1056
+ # # print(f"total_time without processing: {total_time - logit_processing_time}")
1057
+ # return valid_tokens, n_matches, new_logits, correction_term, true_divs, acceptance_time, spec_sampling_time
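To make the acceptance logic in _speculative_sampling easier to follow, here is a small self-contained sketch of the combined rule on toy tensors (random logits and an illustrative threshold, nothing taken from the repo): a drafted token is kept if either the exact speculative-sampling ratio test passes or the divergence between the raw draft and target distributions at that position is at most fsd_threshold.

import torch
from torch.nn.functional import kl_div

torch.manual_seed(0)
candidate_length, vocab_size = 4, 8
fsd_threshold = 0.5                                           # illustrative value

q_logits = torch.randn(1, candidate_length, vocab_size)       # draft (assistant) logits
p_logits = torch.randn(1, candidate_length + 1, vocab_size)   # target logits (one extra position)
draft_ids = torch.randint(vocab_size, (1, candidate_length))  # tokens the draft model proposed

q = q_logits.softmax(-1)
p = p_logits.softmax(-1)

# Exact speculative sampling: keep the drafted token with probability min(1, p_i / q_i).
q_i = q[0, torch.arange(candidate_length), draft_ids[0]]
p_i = p[0, :candidate_length][torch.arange(candidate_length), draft_ids[0]]
accepted_sd = torch.rand(candidate_length) <= p_i / q_i

# Fuzzy test ("kl" branch): per-position KL(target || draft) compared with the threshold.
divs = kl_div(q.log(), p[:, :-1, :], reduction="none").sum(-1)[0]
accepted_fsd = divs <= fsd_threshold

# A position is kept if either test passes; the accepted prefix ends at the first rejection.
accepted = accepted_sd | accepted_fsd
n_matches = int(((~accepted).cumsum(-1) < 1).sum())
print("divergences:", divs.tolist())
print("accepted   :", accepted.tolist())
print("n_matches  :", n_matches)

In the full implementation the accepted prefix is then extended by one more token, sampled from the target distribution (or from the clamped residual p - q when a drafted token was rejected), and the key/value caches of both models are cropped back to the new sequence length before the next round of drafting.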
custom_generate/requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ torch>=2.0.0
2
+ transformers>=4.40.0
3
+
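Note that requirements.txt pins only torch and transformers; scikit-learn, which generate.py imports behind an is_sklearn_available() guard for the optional assistant-confidence bookkeeping, is optional and therefore not pinned here. As a reference for choosing fsd_div_type, the sketch below (toy hand-written distributions, illustrative values only) mirrors how the three divergence branches in generate.py score a single position:

import torch
from torch.nn.functional import kl_div

# Toy draft (q) and target (p) distributions over a 5-token vocabulary at one position.
q = torch.tensor([[0.50, 0.20, 0.15, 0.10, 0.05]])  # draft model
p = torch.tensor([[0.45, 0.25, 0.15, 0.10, 0.05]])  # target model
drafted_token = 0                                    # token the draft model proposed here

# "kl": KL(p || q), as in the kl branch of _speculative_sampling.
kl = kl_div(q.log(), p, reduction="none").sum(-1)

# "js": Jensen-Shannon divergence, 0.5 * (KL(q || m) + KL(p || m)) with m the mixture.
m = 0.5 * (q + p)
js = 0.5 * (kl_div(m.log(), q, reduction="none") + kl_div(m.log(), p, reduction="none")).sum(-1)

# "draft_tokens": absolute gap between the probabilities each model assigns to the drafted token.
gap = (q[0, drafted_token] - p[0, drafted_token]).abs()

print(f"kl={kl.item():.4f}  js={js.item():.4f}  draft_token_gap={gap.item():.4f}")

Smaller values mean the draft distribution is closer to the target at that position, so for a fixed fsd_threshold more drafted tokens are accepted; the three options differ mainly in how sensitive they are to probability mass placed on tokens other than the drafted one.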