# Hypernetwork for zero-shot tokenizer transfer (ZeTT): given the surface form
# of each target token as a sequence of source-vocabulary ids, it predicts
# input (and optionally output) embeddings plus an optional logit bias.
from functools import partial

import torch
from torch import nn
from torch.nn import functional as F

from transformers import PreTrainedModel, RobertaConfig, RobertaModel

from .configuration_hypernet import ZettHypernetConfig


class Rescaler(nn.Module):
    """Frozen elementwise affine rescaling of embedding vectors."""

    def __init__(self, dim: int):
        super().__init__()

        self.dim = dim

        # Non-trainable scale and shift; expected to be set externally
        # (e.g. loaded from a checkpoint) rather than learned.
        self.w = nn.Parameter(torch.ones((1, self.dim)), requires_grad=False)
        self.b = nn.Parameter(torch.ones((1, self.dim)), requires_grad=False)

    def forward(self, x):
        return self.w * x + self.b
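
# Minimal sketch of Rescaler in use (values illustrative): with frozen w and b
# it applies a fixed affine map y = w * x + b.
#
#     scaler = Rescaler(4)
#     with torch.no_grad():
#         scaler.w.fill_(0.5)
#         scaler.b.zero_()
#     y = scaler(torch.ones(2, 4))  # every entry 0.5, shape (2, 4)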
					
						
class ProjectorBlock(nn.Module):
    """Two-layer MLP with tanh-approximated GELUs, a residual connection,
    and LayerNorm; the residual requires input_dim == dim."""

    def __init__(self, input_dim: int, dim: int, intermediate_dim: int):
        super().__init__()

        self.input_dim = input_dim
        self.dim = dim
        self.intermediate_dim = intermediate_dim

        self.dense1 = nn.Linear(self.input_dim, self.intermediate_dim)
        self.dense2 = nn.Linear(self.intermediate_dim, self.dim)

        self.ln = nn.LayerNorm(self.dim, eps=1e-6)

    def forward(self, x):
        h = F.gelu(
            self.dense2(F.gelu(self.dense1(x), approximate="tanh")),
            approximate="tanh",
        )
        # Residual connection followed by LayerNorm.
        return self.ln(h + x)
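
# Shape sketch (sizes illustrative): the block is shape-preserving, so it can
# be stacked inside nn.Sequential projections.
#
#     block = ProjectorBlock(input_dim=768, dim=768, intermediate_dim=3072)
#     out = block(torch.randn(2, 16, 768))  # shape stays (2, 16, 768)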
					
						
class ZettHypernet(PreTrainedModel):
    config_class = ZettHypernetConfig

    def __init__(self, config: ZettHypernetConfig):
        super().__init__(config)

        self.config = config
        self.has_separate_out_embeddings = getattr(
            self.config, "separate_out_embeddings", False
        )

        if self.config.hn_embed_lang_id:
            self.lang_embeddings = nn.Embedding(
                self.config.n_langs, self.config.hn_hidden_size
            )

        # With separate output embeddings, the hypernetwork consumes the source
        # model's input and output embeddings concatenated along the feature
        # dimension, hence the doubled input width.
        if self.has_separate_out_embeddings:
            n_in_embd = self.config.n_embd * 2
            n_out_embd = self.config.n_embd
        else:
            n_in_embd = self.config.n_embd
            n_out_embd = self.config.n_embd

        if self.config.hn_model_type == "roberta":
            # Note: `config` shadows the ZettHypernetConfig argument from here on.
            config = RobertaConfig.from_pretrained(self.config.hn_model_name_or_path)
            config.num_hidden_layers = self.config.hn_n_layers
            config.hidden_size = self.config.hn_hidden_size
            config.intermediate_size = self.config.hn_intermediate_size
            if getattr(self.config, "hn_num_attention_heads", None) is None:
                # Default to one attention head per 64 hidden dimensions.
                self.config.hn_num_attention_heads = self.config.hn_hidden_size // 64
            config.num_attention_heads = self.config.hn_num_attention_heads
            self.embed_init_range = config.initializer_range
            module_class = partial(RobertaModel, add_pooling_layer=False)
        elif self.config.hn_model_type == "t5":
            raise NotImplementedError()
        else:
            raise ValueError(f"unknown hn_model_type: {self.config.hn_model_type}")

        if self.config.hn_embed_using_source_embeddings:
            # The backbone always receives inputs_embeds, so its own vocabulary
            # only needs to cover ids up to the pad token.
            config.vocab_size = self.config.pad_token_id + 1

        if (
            self.config.hn_add_inter_token_attention
            or self.config.hn_embed_target_priors
        ):
            raise NotImplementedError()

        self.pad_token_id = self.config.pad_token_id
        assert self.pad_token_id is not None
        self.model = module_class(config)

        # Learned embeddings for ids at or beyond the original vocabulary
        # (e.g. extra special tokens the source embeddings do not cover).
        self.fallback_embeddings = nn.Embedding(
            max(self.config.hn_n_extra_tokens, 1), n_in_embd
        )

        if self.config.hn_embed_using_source_embeddings:
            self.input_projection = nn.Sequential(
                nn.Linear(n_in_embd, self.config.hn_hidden_size),
                ProjectorBlock(
                    self.config.hn_hidden_size,
                    self.config.hn_hidden_size,
                    self.config.hn_intermediate_size,
                ),
            )

        if self.config.hn_single_head:
            # A single head predicts input and (if separate) output embeddings
            # as one vector of width n_in_embd, split apart in forward().
            self.output_projection = nn.Sequential(
                ProjectorBlock(
                    self.config.hn_hidden_size,
                    self.config.hn_hidden_size,
                    self.config.hn_intermediate_size,
                ),
                nn.Linear(self.config.hn_hidden_size, n_in_embd),
            )
        else:
            self.output_projection = nn.Sequential(
                ProjectorBlock(
                    self.config.hn_hidden_size,
                    self.config.hn_hidden_size,
                    self.config.hn_intermediate_size,
                ),
                nn.Linear(self.config.hn_hidden_size, n_out_embd),
            )
            if self.has_separate_out_embeddings:
                self.output_projection_out = nn.Sequential(
                    ProjectorBlock(
                        self.config.hn_hidden_size,
                        self.config.hn_hidden_size,
                        self.config.hn_intermediate_size,
                    ),
                    nn.Linear(self.config.hn_hidden_size, self.config.n_embd),
                )

        if self.config.hn_rescale_embeddings:
            self.in_scaler = Rescaler(n_in_embd)
            self.scaler = Rescaler(n_out_embd)

            if self.has_separate_out_embeddings:
                self.out_scaler = Rescaler(self.config.n_embd)

        if getattr(self.config, "hn_predict_bias", False):
            self.bias_projection = nn.Linear(self.config.hn_hidden_size, 1)
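
    # Head layout sketch (illustrative, n_embd = 768 with separate output
    # embeddings, so n_in_embd = 1536):
    #   hn_single_head=True:  output_projection -> 1536; forward() splits
    #       [..., :768] (input embeddings) from [..., 768:] (output embeddings).
    #   hn_single_head=False: output_projection -> 768 (input embeddings) and
    #       output_projection_out -> 768 (output embeddings).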
					
						
    def forward(
        self,
        target_surface_forms,
        target_priors=None,
        source_embeddings=None,
        lang_index=None,
        deterministic: bool = True,
    ):
        """Predict embeddings for target tokens given their surface forms.

        Args:
            target_surface_forms: (n_tokens, seq_len) decomposition of each
                target token into source-vocabulary ids; ids at or beyond
                original_vocab_size index the fallback embeddings.
            target_priors: not supported; must be None.
            source_embeddings: (original_vocab_size, n_in_embd) source
                embedding matrix used to embed the surface forms.
            lang_index: index of the language embedding to append (only used
                if hn_embed_lang_id is set).
            deterministic: unused; kept for signature compatibility.

        Returns:
            (input embeddings, output embeddings or None, per-token bias).
        """
        if target_priors is not None:
            raise NotImplementedError()

        if not self.config.hn_embed_using_source_embeddings:
            raise NotImplementedError()

        use_fallback = target_surface_forms >= self.config.original_vocab_size

        # Clamp ids into the source vocabulary; out-of-range positions are
        # overwritten with fallback embeddings below.
        main_ids = torch.minimum(
            target_surface_forms,
            torch.tensor(self.config.original_vocab_size - 1, device=self.device),
        )
        fallback_ids = torch.maximum(
            target_surface_forms - self.config.original_vocab_size,
            torch.tensor(0, device=self.device),
        )

        source_embeds = F.embedding(main_ids, weight=source_embeddings)

        if self.config.hn_rescale_embeddings:
            source_embeds = self.in_scaler(source_embeds)

        inputs_embeds = torch.where(
            use_fallback[..., None],
            self.fallback_embeddings(fallback_ids),
            source_embeds,
        )
        inputs_embeds = self.input_projection(inputs_embeds)
        attention_mask = target_surface_forms != self.pad_token_id

        if self.config.hn_embed_lang_id:
            # Append a language embedding as one extra token. The position and
            # token type embeddings the backbone will add at that position are
            # subtracted here, so the language embedding enters the encoder
            # unchanged. A single language per call is assumed.
            lang_embedding = self.lang_embeddings(lang_index).squeeze()

            lang_embedding -= self.model.embeddings.token_type_embeddings(
                torch.tensor(0, device=self.device)
            ) + self.model.embeddings.position_embeddings(
                torch.tensor(attention_mask.shape[1], device=self.device)
            )

            lang_embedding = lang_embedding[None, None, :].expand(
                inputs_embeds.shape[0], -1, -1
            )

            inputs_embeds = torch.cat([inputs_embeds, lang_embedding], dim=1)
            attention_mask = torch.cat(
                [
                    attention_mask,
                    torch.ones(
                        lang_embedding.shape[:-1], dtype=torch.bool, device=self.device
                    ),
                ],
                dim=1,
            )

        position_ids = torch.broadcast_to(
            torch.arange(attention_mask.shape[-1], device=self.device),
            attention_mask.shape,
        )

        hidden_states = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
        ).last_hidden_state

        if self.config.hn_concat_last_hidden_state:
            hidden_states = hidden_states.reshape(target_surface_forms.shape[0], -1)
        else:
            # Summarize each surface form by its first position's hidden state.
            hidden_states = hidden_states[:, 0]

        predicted_embeddings = self.output_projection(hidden_states)

        if self.config.hn_single_head:
            predicted_embeddings_in = predicted_embeddings[..., : self.config.n_embd]

            if self.has_separate_out_embeddings:
                predicted_embeddings_out = predicted_embeddings[..., self.config.n_embd :]
            else:
                predicted_embeddings_out = None
        else:
            predicted_embeddings_in = predicted_embeddings
            if self.has_separate_out_embeddings:
                predicted_embeddings_out = self.output_projection_out(hidden_states)
            else:
                predicted_embeddings_out = None

        if self.config.hn_rescale_embeddings:
            predicted_embeddings_in = self.scaler(predicted_embeddings_in)

            if predicted_embeddings_out is not None:
                predicted_embeddings_out = self.out_scaler(predicted_embeddings_out)

        if getattr(self.config, "hn_predict_bias", False):
            predicted_bias = self.bias_projection(hidden_states)[..., 0]
        else:
            predicted_bias = torch.zeros_like(
                target_surface_forms[..., 0], dtype=self.dtype
            )

        return predicted_embeddings_in, predicted_embeddings_out, predicted_bias
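

if __name__ == "__main__":
    # Smoke test of the building blocks with illustrative sizes (run as a
    # module, e.g. `python -m <package>.hypernet`, so the relative import
    # resolves). The full hypernetwork needs a populated ZettHypernetConfig,
    # which is checkpoint-specific, so it is only sketched in comments below.
    block = ProjectorBlock(input_dim=32, dim=32, intermediate_dim=64)
    x = torch.randn(2, 5, 32)
    assert block(x).shape == x.shape

    scaler = Rescaler(32)
    assert scaler(x).shape == x.shape

    # Hedged usage sketch for the full model (names and shapes illustrative):
    #
    #     config = ZettHypernetConfig.from_pretrained(...)  # hypothetical path
    #     hypernet = ZettHypernet(config)
    #     # target_surface_forms: (n_target_tokens, seq_len) source-vocab ids
    #     # source_embeddings:    (original_vocab_size, n_in_embd) matrix
    #     emb_in, emb_out, bias = hypernet(
    #         target_surface_forms, source_embeddings=source_embeddings
    #     )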