from __future__ import annotations

from typing import Optional

import torch
from torch import nn
from transformers.cache_utils import Cache  # kept for potential future use
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3Attention,
    Qwen3DecoderLayer,
    Qwen3Model,
    Qwen3ForCausalLM,
    Qwen3PreTrainedModel,
)
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

try:
    from peft import PeftModel
except ImportError:  # soft dependency: the PEFT helpers below degrade gracefully
    PeftModel = None

logger = logging.get_logger(__name__)

# ---------------------------------------------------------------------------
# 1) Bidirectional attention: disable causal masking & sliding window
# ---------------------------------------------------------------------------
class ModifiedQwen3Attention(Qwen3Attention):
    """Full-context self-attention (no causal mask)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False  # attention kernels read this flag: no causal mask
        self.sliding_window = None  # and never restrict attention to a local window


# ---------------------------------------------------------------------------
# 2) Decoder layer using the bidirectional attention module
# ---------------------------------------------------------------------------
class ModifiedQwen3DecoderLayer(Qwen3DecoderLayer):
    """Decoder layer with full-context attention."""

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = ModifiedQwen3Attention(config=config, layer_idx=layer_idx)
        self.attention_type = "full_attention"
        self.sliding_window = None


# ---------------------------------------------------------------------------
# 3) Backbone: Qwen-3 with bidirectional self-attention
# ---------------------------------------------------------------------------
class Qwen3BiModel(Qwen3Model):
    """Qwen-3 backbone whose self-attention is bidirectional."""

    _no_split_modules = ["ModifiedQwen3DecoderLayer"]

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [ModifiedQwen3DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.has_sliding_layers = False
        # Re-run weight init so the freshly rebuilt layers are covered too.
        self.post_init()

    @staticmethod
    def _build_pad_bias(pad_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Turn a [B, L] keep-mask into an additive bias of shape [B, 1, 1, L]:
        0 at real tokens, dtype-min (effectively -inf) at padding positions."""
        neg_inf = torch.finfo(dtype).min
        bias = (~pad_mask.bool()).to(dtype) * neg_inf
        return bias[:, None, None, :]
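    # Worked example (values shown for float32, where finfo.min ~ -3.4e38):
    #   pad_mask = [[1, 1, 0]]  ->  bias[0, 0, 0] == [0., 0., -3.4e38]
    # so softmax assigns the padded slot ~0 weight at every query position.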

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Default to a keep-all mask when none is provided.
        if attention_mask is None:
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
            elif (embeds := kwargs.get("inputs_embeds")) is not None:
                attention_mask = torch.ones(embeds.shape[:2], dtype=torch.bool, device=embeds.device)
            else:
                raise ValueError("Provide attention_mask, input_ids, or inputs_embeds.")

        pad_bias = self._build_pad_bias(attention_mask, self.embed_tokens.weight.dtype)
        # Dict mask tells parent to skip causal-mask generation
        attn_mask_dict = {"full_attention": pad_bias}

        return super().forward(
            input_ids=input_ids,
            attention_mask=attn_mask_dict,
            **kwargs,
        )
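

# A quick sanity check for bidirectionality (a sketch; the token ids and the
# helper name are illustrative, not part of any public API): in a causal model,
# the hidden state at position 0 cannot depend on later tokens, so editing only
# the last token must leave it unchanged. Here it should change.
def _assert_bidirectional(model: Qwen3BiModel) -> None:
    a = torch.tensor([[1, 2, 3]])
    b = torch.tensor([[1, 2, 4]])  # differs from `a` only in the last position
    with torch.no_grad():
        ha = model(input_ids=a).last_hidden_state[0, 0]
        hb = model(input_ids=b).last_hidden_state[0, 0]
    assert not torch.allclose(ha, hb), "position 0 did not see later tokens"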


# ---------------------------------------------------------------------------
# 4) Task head: MNTP (masked next-token) — no generation API
# ---------------------------------------------------------------------------
class Qwen3BiForMNTP(Qwen3ForCausalLM):
    """Bidirectional Qwen-3 with LM head for masked-token objectives."""

    def __init__(self, config: PretrainedConfig):
        # Bypass parent __init__ to wire a custom backbone
        Qwen3PreTrainedModel.__init__(self, config)

        self.model = Qwen3BiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def generate(self, *args, **kwargs):  # type: ignore[override]
        """Disabled: bidirectional backbone is not autoregressive."""
        raise NotImplementedError(
            "generate() is disabled: this backbone is bidirectional and not autoregressive."
        )

    # -------- PEFT helpers --------
    def get_model_for_peft(self):
        return self.model

    def set_model_for_peft(self, model: PeftModel):  # type: ignore[override]
        self.model = model

    def save_peft_model(self, path: str):
        # PeftModel is None when peft is not installed; guard the isinstance.
        if PeftModel is not None and isinstance(self.model, PeftModel):
            self.model.save_pretrained(path)
        else:
            raise ValueError("Backbone is not a PEFT model; nothing to save.")
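

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint name and LoRA settings are
# assumptions for demonstration, not requirements of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from peft import LoraConfig, get_peft_model
    from transformers import AutoTokenizer

    name = "Qwen/Qwen3-0.6B"  # assumed checkpoint; any Qwen3 causal LM fits
    tok = AutoTokenizer.from_pretrained(name)
    model = Qwen3BiForMNTP.from_pretrained(name)

    _assert_bidirectional(model.model)  # backbone now attends both directions

    # Adapt only the backbone with LoRA; the LM head is left unwrapped.
    lora = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
    model.set_model_for_peft(get_peft_model(model.get_model_for_peft(), lora))

    batch = tok(["Paris is the capital of France."], return_tensors="pt")
    logits = model(**batch).logits  # [B, L, vocab]; feed an MNTP-style loss
    print(tuple(logits.shape))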