File size: 11,101 Bytes
73bcbf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 |
# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
"""
Author : Fabien FURFARO
"""
import logging
import os
import re
from typing import Any, Dict, List, Optional, Union
from jinja2 import Environment, FileSystemLoader
import psutil
import torch
from transformers import AutoConfig, PretrainedConfig
logger = logging.getLogger(__name__) # monitoring
# Constants
BYTES_IN_GB = 1024**3
def convert_sets_to_lists(obj):
    """Recursively convert sets (and frozensets) to lists for JSON-serializable
    LoRA config dicts.

    Tuples are also normalized to lists so the result round-trips through JSON.
    Any other value is returned unchanged.
    """
    # frozenset is NOT an isinstance of set, so it must be listed explicitly.
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    if isinstance(obj, dict):
        return {k: convert_sets_to_lists(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_sets_to_lists(x) for x in obj]
    return obj
class TpttConfig(PretrainedConfig):
    """
    Configuration class for the TPTT model.

    This class merges the backbone config (e.g., Llama) with custom TPTT
    parameters: every field of the backbone config is copied onto this
    instance, then the TPTT-specific attributes are set on top of it.
    """

    model_type = "tptt"
    # Lets transformers' Auto* machinery resolve the remote-code classes.
    auto_map = {
        "AutoModelForCausalLM": "modeling_tptt.TpttModel",
        "AutoConfig": "configuration_tptt.TpttConfig",
    }
    architectures = ["TpttModel"]
    RECURRENT_MODES = {
        "delta_rule": {
            "order": 1,
            "gate_type": "k",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_v": {
            "order": 1,
            "gate_type": "v",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_kv": {
            "order": 1,
            "gate_type": "kv",
            "linear": True,
            "trick": "derivative",
        },
        "delta_rule_gelu": {
            "order": 1,
            "gate_type": "k",
            "linear": False,
            "trick": "derivative",
        },
        "delta_product": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "derivative",
        },
        "delta_product_r": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "rotative",
        },
        "delta_product_c": {
            "order": 2,
            "gate_type": "k",
            "linear": True,
            "trick": "combined",
        },
    }  # Tested modes, see parse_mode_name if you want to add more

    def __init__(
        self,
        base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
        base_model_name: str = "meta-llama/Llama-3.2-1B",
        base_model_subfolder: Optional[str] = None,
        name_or_path: Optional[str] = None,
        model_task: str = "causal_lm",
        target_modules_names: Optional[List[str]] = None,
        operator_mode: str = "delta_rule",
        use_linear_checkpoint: Optional[bool] = None,
        max_self_attn_length: Optional[
            int
        ] = None,  # unnecessary if SWA, else, standards 8192
        base_scale_attn: bool = False,
        mag_weight: float = 0.5,  # if 1.0, use only linear operator
        cross_gate: bool = False,  # unlinear mixing strategy
        max_chunk_size: int = 64,  # 128 if adaptive chunking (longest)
        linear_precision: Union[str, torch.dtype] = "float32",
        lora_config: Optional[dict] = None,  # only serialized accepted
        padding_side: Optional[str] = None,  # for tokenizer, default "right"
        bidirectional: bool = False,  # if True, use bidirectional attention
        pooling_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Build a TPTT config on top of a backbone config.

        Args:
            base_model_config: Backbone config (dict or PretrainedConfig).
                When None, it is fetched with AutoConfig from base_model_name.
            base_model_name: Hub id or local path of the backbone model.
            name_or_path: Explicit model id; defaults to "Titans-<backbone>".
            operator_mode: Either a key of RECURRENT_MODES or a parseable
                mode name (see parse_mode_name).
            use_linear_checkpoint: Force gradient checkpointing of the
                linear operator; when None it is enabled on devices with
                less than 16 GB of memory.
            Remaining parameters are stored verbatim as attributes.
        """
        # If base_model_config is provided, load it and merge with this config
        if base_model_config is not None:
            if isinstance(base_model_config, PretrainedConfig):
                base_model_config = base_model_config.to_dict()
        else:
            # Load config from Hugging Face Hub or a local path
            base_model_config = AutoConfig.from_pretrained(
                base_model_name, **kwargs
            ).to_dict()
        # Merge all backbone fields into this config
        for k, v in base_model_config.items():
            setattr(self, k, v)
        self.base_model_name = base_model_name
        self.base_model_subfolder = base_model_subfolder
        self.model_task = model_task
        if name_or_path is not None:
            self._name_or_path = name_or_path
        else:
            # Derive a display name from the backbone id, dropping the org.
            if "/" in base_model_name:
                self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
            else:
                self._name_or_path = "Titans-" + base_model_name
        self.target_modules_names = target_modules_names or [
            "attn",
            "self_attn",
            "attention",
        ]
        self.operator_mode = operator_mode
        # Detect available memory on accelerator device
        if torch.cuda.is_available():
            _, total_mem = torch.cuda.mem_get_info()
        else:
            total_mem = psutil.virtual_memory().total
        total_mem_gb = total_mem / BYTES_IN_GB
        # Default to checkpointing on low-memory (<16 GB) devices.
        self.use_linear_checkpoint = (
            total_mem_gb < 16
            if use_linear_checkpoint is None
            else use_linear_checkpoint
        )
        self.base_scale_attn = base_scale_attn
        self.mag_weight = mag_weight
        self.cross_gate = cross_gate
        self.max_chunk_size = max_chunk_size
        self.max_self_attn_length = max_self_attn_length
        # Store dtype as a plain string ("float32") so the config serializes.
        if isinstance(linear_precision, torch.dtype):
            linear_precision = str(linear_precision).replace("torch.", "")
        self.linear_precision = linear_precision
        self.lora_config = lora_config
        if lora_config is not None:
            # peft_type may be an enum; keep only its serializable value.
            if hasattr(self.lora_config.get("peft_type"), "value"):
                self.lora_config["peft_type"] = self.lora_config["peft_type"].value
            self.lora_config = convert_sets_to_lists(self.lora_config)
        self.padding_side = padding_side
        self.bidirectional = bidirectional
        if self.bidirectional:
            # Use the module logger for consistency (was a bare print).
            logger.info("Bidirectional is enabled, need to be uncausal and unpadded.")
        self.pooling_config = pooling_config
        super().__init__(**kwargs)  # flush unconsistend pretrained parameters (?)
        # Copy class attributes to instance for serialization (save dict)
        self.model_type = self.__class__.model_type
        self.auto_map = self.__class__.auto_map
        self.architectures = self.__class__.architectures
        # Padding side configuration if not set
        if self.padding_side is None:
            self.padding_side = "right"
            # Warning-level event, logged at the proper level.
            logger.warning("padding_side is None, defaulting to 'right'.")
        # set recurrent configuration from operator mode
        if operator_mode not in self.__class__.RECURRENT_MODES:
            self.recurrent_config = parse_mode_name(operator_mode)
        else:
            self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
        logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))


# Register so AutoConfig can resolve this class via trust_remote_code.
TpttConfig.register_for_auto_class()
def parse_mode_name(name: str) -> dict:
    """Translate a recurrent-mode name into its configuration dict.

    Handles two families:
      * "delta_product[_<order>][_<gate>][_gelu][_r|_c]"
      * "delta_rule[_k|_v|_kv][_gelu]"

    Raises:
        ValueError: if the name matches neither family.
    """
    if name.startswith("delta_product"):
        # Drop the two-word "delta_product" prefix and consume the rest.
        tokens = name.split("_")[2:]
        config = {"order": 2, "gate_type": "k", "linear": True, "trick": "derivative"}
        # Optional numeric order sits right after the prefix.
        if tokens and tokens[0].isdigit():
            config["order"] = int(tokens.pop(0))
        # Optional trick abbreviation is always the last token.
        if tokens and tokens[-1] in ("r", "c"):
            config["trick"] = "rotative" if tokens.pop() == "r" else "combined"
        # Optional "gelu" marker precedes the trick.
        if tokens and tokens[-1] == "gelu":
            tokens.pop()
            config["linear"] = False
        # Whatever is left forms the gate type.
        if tokens:
            config["gate_type"] = "_".join(tokens)
        return config
    # delta_rule[_gate][_gelu]
    match = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
    if match is None:
        raise ValueError(f"Unknown mode: {name}")
    return {
        "order": 1,
        "gate_type": match.group(1) or "k",
        "linear": match.group(2) is None,
        "trick": "derivative",
    }
def get_mode_name(
    order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
) -> str:
    """Build the canonical recurrent-mode name from its parameters.

    Inverse of parse_mode_name: defaults ("k" gate, linear, derivative
    trick) are omitted from the generated name.
    """
    if order == 1:
        name = "delta_rule"
    elif order == 2:
        name = "delta_product"
    else:
        name = f"delta_product_{order}"
    suffixes = []
    if gate_type != "k":
        suffixes.append(gate_type)
    if not linear:
        suffixes.append("gelu")
    # The trick abbreviation only applies to product (order >= 2) modes.
    if order >= 2 and trick != "derivative":
        suffixes.append({"rotative": "r", "combined": "c"}.get(trick, trick))
    for suffix in suffixes:
        name += "_" + suffix
    return name
def render_template(template_path: str, variables: dict) -> str:
    """Load and render a Jinja2 template from any file path.

    Args:
        template_path: Path to the template file; its directory becomes
            the Jinja2 search path.
        variables: Mapping passed as the template render context.

    Returns:
        The rendered template content as a string.
    """
    # Fall back to "." so a bare filename (empty dirname) still resolves.
    template_dir = os.path.dirname(template_path) or "."
    env = Environment(loader=FileSystemLoader(template_dir))
    template = env.get_template(os.path.basename(template_path))
    return template.render(**variables)
def write_model_card(output_path: str, content: str):
    """Persist the generated model card as README.md under output_path."""
    os.makedirs(output_path, exist_ok=True)
    target = os.path.join(output_path, "README.md")
    with open(target, "w", encoding="utf-8") as handle:
        handle.write(content)
def generate_model_card(
    output_path: str,
    config: Union[dict, object],
    template: Optional[
        str
    ],  # can be "model_card" OR an absolute/relative path to a .md file
    extra_variables: Optional[Dict] = None,
):
    """
    Generate a README.md file from a Jinja2 template and a configuration.
    - template can be either:
        * a full path to a template file
        * a short name (e.g., "model_card") -> will be looked up inside default_templates_dir
    """
    if template is None:
        template = "model_card_template"  # default template name
    # Locate the template: a real path wins over a short name.
    if os.path.exists(template):
        template_path = template
    else:
        # Short names resolve against the bundled "templates" directory.
        templates_dir = os.path.join(os.path.dirname(__file__), "templates")
        template_path = os.path.join(templates_dir, f"{template}.md")
        if not os.path.exists(template_path):
            raise FileNotFoundError(f"Template not found: {template_path}")
    # Build the render context, letting extras override nothing by default.
    context = {
        "model_id": os.path.basename(output_path),
        "config": config,
    }
    context.update(extra_variables or {})
    write_model_card(output_path, render_template(template_path, context))
|