from __future__ import annotations
from typing import Optional
import torch
from torch import nn
from transformers.cache_utils import Cache # kept for potential future use
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3Attention,
Qwen3DecoderLayer,
Qwen3MLP,
Qwen3RMSNorm,
Qwen3Model,
Qwen3ForCausalLM,
Qwen3PreTrainedModel,
)
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
try:
from peft import PeftModel
except ImportError:  # soft dependency: the PEFT helpers below require `peft`
    PeftModel = None
logger = logging.get_logger(__name__)
# ---------------------------------------------------------------------------
# 1) Bidirectional attention: disable causal masking & sliding window
# ---------------------------------------------------------------------------
class ModifiedQwen3Attention(Qwen3Attention):
"""Full-context self-attention (no causal mask)."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
self.sliding_window = None
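        # Note: with is_causal=False and no sliding window, attention is computed
        # over the full sequence, so every position can attend to positions both
        # before and after it; padding is still excluded via the additive bias
        # built in Qwen3BiModel._build_pad_bias below.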
# ---------------------------------------------------------------------------
# 2) Decoder layer using the bidirectional attention module
# ---------------------------------------------------------------------------
class ModifiedQwen3DecoderLayer(Qwen3DecoderLayer):
"""Decoder layer with full-context attention."""
def __init__(self, config: PretrainedConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.self_attn = ModifiedQwen3Attention(config=config, layer_idx=layer_idx)
self.attention_type = "full_attention"
self.sliding_window = None
# ---------------------------------------------------------------------------
# 3) Backbone: Qwen-3 with bidirectional self-attention
# ---------------------------------------------------------------------------
class Qwen3BiModel(Qwen3Model):
"""Qwen-3 backbone whose self-attention is bidirectional."""
_no_split_modules = ["ModifiedQwen3DecoderLayer"]
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[ModifiedQwen3DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
)
self.has_sliding_layers = False
@staticmethod
def _build_pad_bias(pad_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""[B,L] -> additive bias [B,1,1,L] with -inf on padding."""
neg_inf = torch.finfo(dtype).min
bias = (~pad_mask.bool()).to(dtype) * neg_inf
return bias[:, None, None, :]
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
# Default to keep-all if no mask is provided
if attention_mask is None:
if input_ids is None:
raise ValueError("Either attention_mask or input_ids must be provided.")
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
pad_bias = self._build_pad_bias(attention_mask, self.embed_tokens.weight.dtype)
# Dict mask tells parent to skip causal-mask generation
attn_mask_dict = {"full_attention": pad_bias}
return super().forward(
input_ids=input_ids,
attention_mask=attn_mask_dict,
**kwargs,
)
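# Illustrative example (not executed): for attention_mask = [[1, 1, 0]] and
# dtype torch.float32, _build_pad_bias returns a [1, 1, 1, 3] bias of roughly
# [0, 0, -3.4e38]; added to the attention scores, it lets every query attend to
# the first two tokens while ignoring the padded third position.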
# ---------------------------------------------------------------------------
# 4) Task head: MNTP (masked next-token) — no generation API
# ---------------------------------------------------------------------------
class Qwen3BiForMNTP(Qwen3ForCausalLM):
"""Bidirectional Qwen-3 with LM head for masked-token objectives."""
def __init__(self, config: PretrainedConfig):
# Bypass parent __init__ to wire a custom backbone
Qwen3PreTrainedModel.__init__(self, config)
self.model = Qwen3BiModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def generate(self, *args, **kwargs): # type: ignore[override]
"""Disabled: bidirectional backbone is not autoregressive."""
raise NotImplementedError(
"generate() is disabled: this backbone is bidirectional and not autoregressive."
)
# -------- PEFT helpers --------
def get_model_for_peft(self):
return self.model
def set_model_for_peft(self, model: PeftModel): # type: ignore[override]
self.model = model
def save_peft_model(self, path: str):
        if PeftModel is not None and isinstance(self.model, PeftModel):
self.model.save_pretrained(path)
else:
raise ValueError("Backbone is not a PEFT model; nothing to save.")
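# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): load Qwen-3 weights into the bidirectional
# MNTP model and run a single forward pass. The checkpoint name below is an
# assumption; substitute whichever Qwen-3 checkpoint you actually train from.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "Qwen/Qwen3-0.6B"  # assumed checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = Qwen3BiForMNTP.from_pretrained(checkpoint)
    model.eval()

    batch = tokenizer(
        ["Bidirectional attention sees the whole sentence."],
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(**batch)
    # Every position receives logits over the vocabulary; with bidirectional
    # attention these are suitable for masked next-token (MNTP) training,
    # not for autoregressive generation.
    print(outputs.logits.shape)  # torch.Size([1, seq_len, vocab_size])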