update streamer generation
generation_utils.py (CHANGED: +24, −94)
```diff
@@ -35,9 +35,23 @@ from transformers.cache_utils import (
     DynamicCache,
 )
 from transformers.generation.utils import GenerationMixin
+from transformers import TextIteratorStreamer

 logger = logging.get_logger("Dimple."+__name__)

+class FullSequenceStreamer(TextIteratorStreamer):
+    def __init__(self, tokenizer, **kwargs):
+        super().__init__(tokenizer, **kwargs)
+
+    def put(self, value, stream_end=False):
+        # Assume full token_ids are passed in every time
+        decoded = self.tokenizer.batch_decode(value, **self.decode_kwargs)
+        self.text_queue.put(decoded)
+        if stream_end:
+            self.text_queue.put(self.stop_signal, timeout=self.timeout)
+
+    def end(self):
+        self.text_queue.put(self.stop_signal, timeout=self.timeout)

 def top_p_logits(logits, top_p=None):
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
```
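The class above streams full-sequence snapshots: every `put` re-decodes the entire tensor it is handed, because the diffusion sampler revises masked positions across the whole answer window instead of appending tokens left to right. A minimal consumption sketch, assuming a loaded `model`/`tokenizer` pair, an `input_ids` prompt, and the `diffusion_generate` entry point changed in the hunks below; the threading wiring is illustrative, not part of this commit:

```python
from threading import Thread

# Assumed to exist already: `model` (a Dimple model), `tokenizer`, `input_ids`.
streamer = FullSequenceStreamer(tokenizer, skip_special_tokens=True)

thread = Thread(
    target=model.diffusion_generate,
    kwargs=dict(
        inputs=input_ids,
        max_new_tokens=64,
        steps=64,
        streamer=streamer,  # new argument added in this commit
    ),
)
thread.start()

# Each item is one snapshot per denoising step: a list of decoded strings,
# one entry per batch row, replacing (not extending) the previous snapshot.
for snapshot in streamer:
    print(snapshot[0])

thread.join()
```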
```diff
@@ -417,98 +431,9 @@ class DimpleGenerationMixin:
         self,
         inputs: Optional[torch.Tensor] = None,
         generation_config: Optional[DimpleGenerationConfig] = None,
-
+        streamer: Optional[FullSequenceStreamer] = None,
         **kwargs,
     ) -> Union[DimpleModelOutput, torch.LongTensor]:
-
-        """
-        Generates sequences using a diffusion-based masked token denoising algorithm.
-
-        This method replaces masked tokens in `inputs` through iterative refinement, based on a denoising process
-        inspired by diffusion models. It uses intermediate confidence-based sampling to progressively fill in masked tokens.
-
-        Args:
-            inputs (torch.Tensor):
-                Input token IDs.
-            generation_config (DimpleGenerationConfig, optional):
-                An instance of `DimpleGenerationConfig` containing generation hyperparameters. If not provided,
-                the default generation config from the model is used.
-            **kwargs:
-                Additional generation parameters that override those in `generation_config`.
-
-        Returns:
-            DimpleModelOutput if `return_dict_in_generate=True`, else `torch.LongTensor` of generated token IDs.
-
-        Key Parameters (either in `generation_config` or passed via kwargs):
-
-        - `max_new_tokens` (int, default=None):
-            The number of new tokens to generate or fill in. This sets the target length of the generated sequence beyond
-            the prompt. It is added to the input length to determine the total sequence length.
-
-        - `output_history` (bool, default=False):
-            If `True`, returns the full sequence history at each denoising step. This is useful for visualization or debugging
-            purposes. Only returned if `return_dict_in_generate=True`.
-
-        - `return_dict_in_generate` (bool, default=False):
-            If `True`, returns a `DimpleModelOutput` dictionary containing the final sequences and, optionally, the stepwise history.
-            If `False`, returns a plain tensor of token IDs.
-
-        - `steps` (int, default=512):
-            The number of denoising steps to perform during generation. Each step progressively refines the sequence by replacing
-            some masked tokens based on a sampling algorithm.
-
-        - `temperature` (float, default=0.0):
-            Sampling temperature applied to logits before softmax. Lower values make outputs more deterministic,
-            while higher values allow for more randomness in token selection.
-
-        - `top_p` (float, default=None):
-            Nucleus sampling parameter. If set, only the most probable tokens whose cumulative probability exceeds `top_p`
-            are considered during sampling.
-
-        - `alg` (str, default="origin"):
-            The denoising algorithm to use for determining which tokens to replace at each step. Options include:
-            - `"origin"`: random token selection based on a probability ratio.
-            - `"origin-ratio"`: like `"origin"` but uses continuous transfer ratio.
-            - `"autoregressive"`: always fills the left-most masked token.
-            - `"maskgit_plus"`: confidence-based selection similar to Google's MaskGIT.
-            - `"topk_margin"`: token selection based on margin (top1 - top2 probability).
-            - `"entropy"`: prioritizes tokens with high negative entropy (uncertainty).
-
-        - `use_cache` (bool, default=False):
-            Enables prefilling of past key values (past KV) for efficient decoding.
-
-        - `alg_p_threshold` (float, optional, default=None):
-            A confidence threshold used to determine whether a token is confident enough to be selected. If the token's
-            confidence is above this value, it is unmasked and committed to the sequence. Helps stabilize generation.
-
-        - `use_original_confidence` (bool, default=True):
-            If `True`, confidence scores are computed using the original (pre-sampled) probability distribution.
-            If `False`, uses the current step's softmaxed logits. Enables more stable token selection in some cases.
-
-        - `decoding_pipeline` (str, default="dim"):
-            The generation decoding pipeline to use:
-            - `"dim"`: Dimple decoding pipeline.
-            - `"dream"`: Original DREAM token selection pipeline.
-
-        Example:
-        ```python
-        output = model.diffusion_generate(
-            inputs=input_ids,
-            max_new_tokens=64,
-            output_history=True,
-            return_dict_in_generate=True,
-            steps=64,
-            temperature=0.2,
-            top_p=0.95,
-            alg="origin",
-            use_cache=True,
-            alg_p_threshold=0.95,
-            use_original_confidence=True,
-            decoding_pipeline="dim"
-        )
-        ```
-        """
-
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
         generation_tokens_hook_func = model_kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
```
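The call pattern from the deleted docstring still applies; the only signature change is the optional `streamer` argument. A sketch of that same Example with the new argument attached (parameter values copied from the removed docstring; `tokenizer` is assumed to be in scope):

```python
output = model.diffusion_generate(
    inputs=input_ids,
    max_new_tokens=64,
    output_history=True,
    return_dict_in_generate=True,
    steps=64,
    temperature=0.2,
    top_p=0.95,
    alg="origin",
    use_cache=True,
    alg_p_threshold=0.95,
    use_original_confidence=True,
    decoding_pipeline="dim",
    streamer=FullSequenceStreamer(tokenizer, skip_special_tokens=True),  # new in this commit
)
```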
```diff
@@ -588,7 +513,7 @@ class DimpleGenerationMixin:
             generation_config=generation_config,
             generation_tokens_hook_func=generation_tokens_hook_func,
             generation_logits_hook_func=generation_logits_hook_func,
-
+            streamer = streamer,
             **model_kwargs,
         )
         return result
```
```diff
@@ -599,7 +524,7 @@ class DimpleGenerationMixin:
         generation_config: DimpleGenerationConfig,
         generation_tokens_hook_func,
         generation_logits_hook_func,
-
+        streamer: Optional[FullSequenceStreamer] = None,
         **model_kwargs,
     ) -> Union[DimpleModelOutput, torch.LongTensor]:
         # init values
```
```diff
@@ -618,7 +543,6 @@ class DimpleGenerationMixin:
         top_p = generation_config.top_p
         top_k = generation_config.top_k
         attention_mask = model_kwargs.get("attention_mask", None)
-        attention_mask_4d = model_kwargs.get("attention_mask_4d", None)

         histories = [] if (return_dict_in_generate and output_history) else None

```
```diff
@@ -784,9 +708,15 @@ class DimpleGenerationMixin:
             if histories is not None:
                 histories.append(input_ids.clone())

+            if streamer is not None:
+                streamer.put(input_ids[:, -answer_token_length+1:])
+
             if decoding_pipeline == 'dim' and torch.all(input_ids != mask_token_id):
                 break
-
+
+        if streamer is not None:
+            streamer.end()
+
         if return_dict_in_generate:
             return DimpleModelOutput(
                 sequences=input_ids,
```
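The sampler calls `put` with the answer region of `input_ids` on every denoising step and `end()` exactly once, either after the early `break` or when the step loop finishes. The queue behaviour can be checked in isolation, without running the model; the tokenizer below is a stand-in assumption (any Hugging Face tokenizer will do), and the two `put` calls merely simulate successive denoising steps:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer, assumption only
s = FullSequenceStreamer(tok, skip_special_tokens=True)

# Simulate two denoising steps: each put receives the full (partially denoised) ids.
s.put(tok(["the cat [MASK] on the mat"], return_tensors="pt").input_ids)
s.put(tok(["the cat sat on the mat"], return_tensors="pt").input_ids)
s.end()

for snapshot in s:
    print(snapshot)  # one list of decoded strings per simulated step
```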