Update modeling_llada.py
update generate_function with streaming output
- modeling_llada.py +81 -31
modeling_llada.py
CHANGED
@@ -1181,7 +1181,8 @@ class LLaDAModel(nn.Module):
         attention_bias: Optional[torch.Tensor] = None,
         past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
         use_cache: bool = False,
-        last_logits_only: bool = False,
+        last_block_logits_only: bool = False,
+        block_length: int = 64,
         output_hidden_states: Optional[bool] = None,
     ) -> LLaDAOutput:
         """
@@ -1351,10 +1352,9 @@ class LLaDAModel(nn.Module):
                 assert cache is not None
                 attn_key_values.extend(cache)
 
-        if last_logits_only:
-            # shape: (batch_size, 1, d_model)
-            x = x[:, -1:, :]
-
+        if last_block_logits_only:
+            # shape: (batch_size, block_length, d_model)
+            x = x[:, -block_length:, :]
         # Apply final layer norm.
         # shape: (batch_size, seq_len or 1, d_model)
         x = self.transformer.ln_f(x) # type: ignore
@@ -1406,6 +1406,7 @@ class LLaDAModelLM(PreTrainedModel):
             self.model = LLaDAModel(model_config, init_params=init_params)
         else:
             self.model = model
+        self.mask_id = model_config.mask_token_id
 
     def forward(
         self,
@@ -1419,7 +1420,8 @@ class LLaDAModelLM(PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-
+        last_block_logits_only: bool = False,
+        block_length: int = 64,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         if use_cache is None:
             use_cache = self.config.use_cache
@@ -1438,6 +1440,8 @@ class LLaDAModelLM(PreTrainedModel):
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
+            last_block_logits_only=last_block_logits_only,
+            block_length=block_length,
         )
 
         logits = outputs.logits
@@ -1457,31 +1461,6 @@ class LLaDAModelLM(PreTrainedModel):
             hidden_states=hidden_states,
         )
 
-    def can_generate(self) -> bool:
-        return True
-
-    def prepare_inputs_for_generation(
-        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
-    ):
-        if past_key_values:
-            # This is because we want the model to only process the last generated token.
-            input_ids = input_ids[:, -1:]
-        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
-
-        model_inputs.update(kwargs)
-        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
-        return model_inputs
-
-    # TODO: these are required to make the implementation complete.
-    # def resize_position_embeddings(self, new_num_position_embeddings: int):
-    #     pass
-    #
-    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
-    #     pass
-    #
-    # def _reorder_cache(self, past_key_values, beam_idx):
-    #     pass
-
     def get_input_embeddings(self) -> torch.nn.Module:
         return self.model.transformer.wte
 
@@ -1504,5 +1483,76 @@ class LLaDAModelLM(PreTrainedModel):
         if self.config.weight_tying:
             self.model.transformer.ff_out = self.model.transformer.wte
 
+
+    def prefill_phase(self, input_ids, block_length):
+        """Prefill phase: Process initial prompt and generate KV cache."""
+        with torch.no_grad():
+            outputs = self(
+                input_ids=input_ids,
+                use_cache=True,
+                return_dict=True,
+                last_block_logits_only=True,
+                block_length=block_length
+            )
+            output_past_key_values = []
+            for i in range(len(outputs.past_key_values)):
+                k,v = outputs.past_key_values[i]
+                new_k,new_v = k[:,:,:-block_length,:],v[:,:,:-block_length,:]
+                output_past_key_values.append((new_k,new_v))
+            output_past_key_values = tuple(output_past_key_values)
+        return {
+            'input_ids': input_ids,
+            'logits': outputs.logits,
+            'past_key_values': output_past_key_values,
+        }
+
+    def unmask_function_greedy(self, logits, x, threshold=0.9):
+        """Greedy unmasking function with confidence threshold."""
+        mask_index = x == self.mask_id
+        x_top_0 = torch.argmax(logits, dim=-1)
+        p = F.softmax(logits, dim=-1)
+        confidence = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x_top_0, -1)), -1)
+        transfer_index = torch.zeros_like(x_top_0, dtype=torch.bool, device=x_top_0.device)
+        confidence = torch.where(mask_index, confidence, -torch.inf)
+        for j in range(confidence.shape[0]):
+            mask = confidence[j] > threshold
+            if mask.sum() == 0:
+                max_conf_idx = torch.argmax(confidence[j])
+                mask[max_conf_idx] = True
+            transfer_index[j] = mask
+        x[transfer_index] = x_top_0[transfer_index]
+        return x
+
+    @torch.no_grad()
+    def generate(self, input_ids, attention_mask, max_gen_length=1024, block_length=64, threshold=0.9,streaming=False,eos_token_id=126081):
+        batchsize, prompt_length = input_ids.shape
+        max_num_blocks = max_gen_length // block_length
+        output_ids = input_ids
+        block_x = torch.full((batchsize, block_length), self.mask_id, dtype=torch.long).to(self.device)
+        output_ids = torch.cat([output_ids, block_x], dim=-1)
+        # prefilling block loop
+        prefill_outputs = self.prefill_phase(output_ids, block_length)
+        past_key_values = prefill_outputs['past_key_values']
+        logits = prefill_outputs['logits']
+        output_ids[:,-block_length:] = self.unmask_function_greedy(logits=logits, x=output_ids[:,-block_length:], threshold=threshold)
+        # decoding block loop
+        for j in range(max_num_blocks):
+            while (output_ids[:,-block_length:] == self.mask_id).sum():
+                outputs = self(
+                    input_ids=output_ids[:,-block_length:],
+                    past_key_values=past_key_values,
+                    use_cache=True,
+                    return_dict=True
+                )
+                output_ids[:,-block_length:] = self.unmask_function_greedy(logits=outputs.logits, x=output_ids[:,-block_length:], threshold=threshold)
+                past_key_values = outputs.past_key_values
+            if streaming:
+                yield output_ids[:,-block_length:]
+            if (output_ids == eos_token_id).any():
+                return output_ids[:, prompt_length:]
+            block_x = torch.full((batchsize, block_length), self.mask_id, dtype=torch.long).to(self.device)
+            output_ids = torch.cat([output_ids, block_x], dim=-1)
+        return output_ids[:, prompt_length:]
+
 # Register the model so that it is available for transformer pipelines, auto-loading, etc.
 AutoModel.register(LLaDAConfig, LLaDAModelLM)
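
The new `generate` decodes block by block: it appends `block_length` mask tokens, refines them with `unmask_function_greedy` until no mask remains in the current block, and, when `streaming=True`, yields each finished block before starting the next one. A minimal usage sketch follows, assuming the checkpoint is loaded with `trust_remote_code=True`; the repo id, tokenizer, prompt, device, and dtype are illustrative and not part of this commit:

```python
import torch
from transformers import AutoModel, AutoTokenizer

repo = "GSAI-ML/LLaDA-8B-Instruct"  # assumed repo id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to("cuda").eval()

inputs = tokenizer("Explain diffusion language models in one paragraph.", return_tensors="pt").to(model.device)

# With streaming=True, generate yields one fully unmasked block of block_length tokens at a time.
for block in model.generate(
    inputs.input_ids,
    inputs.attention_mask,
    max_gen_length=256,
    block_length=64,
    threshold=0.9,
    streaming=True,
):
    print(tokenizer.decode(block[0], skip_special_tokens=True), flush=True)
```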
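Because the body of `generate` contains a `yield`, Python treats it as a generator function even when `streaming=False`: calling it returns a generator immediately, and the final tensor surfaces as the generator's return value. Continuing the sketch above, one way to consume the non-streaming path (plain Python generator semantics, not an API defined in this file):

```python
# With streaming=False nothing is yielded; driving the generator once runs the whole
# block-by-block decode and the returned ids arrive via StopIteration.value.
gen = model.generate(inputs.input_ids, inputs.attention_mask, max_gen_length=256, streaming=False)
try:
    next(gen)
except StopIteration as stop:
    output_ids = stop.value  # shape: (batch_size, generated_length), prompt already stripped
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```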
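One detail worth noting: `prefill_phase` runs the prompt plus one appended mask block through the model with `use_cache=True`, then slices the last `block_length` positions off every cached key and value so that the stored cache covers only the prompt; the mask block itself is re-fed as `input_ids` on each refinement pass. A toy shape check of that slicing, with made-up dimensions chosen only for illustration (the cache layout `(batch, heads, seq, head_dim)` is inferred from the indexing in the diff):

```python
import torch

batch, n_heads, head_dim = 1, 2, 8
prompt_length, block_length = 10, 4

# A cache entry as produced while prefilling prompt + one masked block.
k = torch.randn(batch, n_heads, prompt_length + block_length, head_dim)

# prefill_phase keeps only the prompt portion of the cache.
k_prompt = k[:, :, :-block_length, :]
assert k_prompt.shape == (batch, n_heads, prompt_length, head_dim)
```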