momergul committed
Commit 3805b84 · verified · 1 Parent(s): 17b25b8

Upload modeling_git.py with huggingface_hub
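
For reference, a minimal sketch of how a single file is typically pushed to the Hub with huggingface_hub. The repo id below is a placeholder, not the actual destination of this commit, and the call assumes a token is already configured via huggingface-cli login or the HF_TOKEN environment variable:

    from huggingface_hub import HfApi

    api = HfApi()  # reuses the locally configured access token
    api.upload_file(
        path_or_fileobj="modeling_git.py",   # local file to push
        path_in_repo="modeling_git.py",      # destination path inside the repo
        repo_id="user/model-repo",           # placeholder repo id
        commit_message="Upload modeling_git.py with huggingface_hub",
    )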

Files changed (1)
  1. modeling_git.py +100 -0
modeling_git.py ADDED
import transformers
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers import ViTFeatureExtractor, ViTModel, ViTConfig
from typing import List, Optional, Tuple, Union
import warnings
import os
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from itertools import product
import numpy as np
import transformers.models.git.modeling_git as modeling_git
import transformers.models.vit.modeling_vit as modeling_vit
from transformers.models.opt.modeling_opt import OPTConfig
import transformers.models.opt.modeling_opt as hg_opt
import transformers.models.clip.modeling_clip as modeling_clip


class GitForCausalLM(modeling_git.GitForCausalLM):
    """GIT causal LM whose CLIP image encoder is replaced with a pretrained DINO ViT-B/16
    (facebook/dino-vitb16); the visual projection and LM head are rebuilt to match."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Rebuild the LM head without a bias term and re-run post-initialization.
        del self.output
        self.output = nn.Linear(
            self.config.hidden_size,
            self.config.vocab_size,
            bias=False)
        self.post_init()

        # Swap the default CLIP image encoder for a pretrained DINO ViT-B/16 and
        # record its hidden size in the vision config.
        del self.git.image_encoder
        self.git.image_encoder = ViTModel.from_pretrained('facebook/dino-vitb16')
        dino_cfg = self.git.image_encoder.config
        config = self.git.config
        config.vision_config.hidden_size = dino_cfg.hidden_size

        # Rebuild the visual projection for the new hidden size and update the
        # image-token count (patches + [CLS]) that forward() reads back when
        # computing the loss.
        del self.git.visual_projection
        self.git.visual_projection = modeling_git.GitProjection(config)
        num_tks = (dino_cfg.image_size // dino_cfg.patch_size) ** 2 + 1
        self.git.encoder.layer[0].attention.self.image_patch_tokens = num_tks

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], modeling_git.CausalLMOutputWithPast]:
        # Mirrors transformers' GitForCausalLM.forward, but handles text-only batches
        # (no pixel_values) by not skipping image positions in the loss.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.git(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.output(sequence_output)

        loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            if pixel_values is not None:
                num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
            else:
                num_image_tokens = 0
            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return modeling_git.CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
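
For context, a minimal usage sketch of the class defined above. This example is not part of the commit and rests on assumptions: the base checkpoint (microsoft/git-base), the test image URL, and the generation settings are placeholders, and the GIT processor applies CLIP-style image normalization, which may differ from what the DINO encoder was pretrained with:

    import requests
    import torch
    from PIL import Image
    from transformers import AutoProcessor

    from modeling_git import GitForCausalLM  # the subclass defined in this file

    # Assumed base checkpoint; its CLIP image-encoder weights will be skipped with
    # warnings, leaving the DINO ViT loaded in __init__ in place.
    processor = AutoProcessor.from_pretrained("microsoft/git-base")
    model = GitForCausalLM.from_pretrained("microsoft/git-base")
    model.eval()

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # placeholder image
    image = Image.open(requests.get(url, stream=True).raw)
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        generated_ids = model.generate(pixel_values=pixel_values, max_length=30)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True))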