LRL committed
Commit c808d93 · 1 Parent(s): f45ad20

add quant.py

Files changed (1)
  1. quant.py +174 -0
quant.py ADDED
@@ -0,0 +1,174 @@
import numpy as np
import torch
import torch.nn as nn
import os
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig


pretrained_model_dir = "/root/Yi-6B/"
quantized_model_dir = "/root/Yi-6B/quant/"


os.makedirs(quantized_model_dir, exist_ok=True)


def get_wikitext2(nsamples, seed, seqlen, model):
    # Build a WikiText-2 calibration set: nsamples random windows of seqlen tokens
    # for quantization, plus the full tokenized test split for perplexity evaluation.
    from datasets import load_dataset

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    import random

    random.seed(seed)
    np.random.seed(0)
    torch.random.manual_seed(0)

    traindataset = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
    return traindataset, testenc


@torch.no_grad()
def opt_eval(model, testenc, dev, seqlen=2048):
    # Layer-by-layer perplexity evaluation on WikiText-2.
    # Note: this routine follows the OPT-style decoder layout (model.model.decoder.*).
    print("Evaluating ...")

    testenc = testenc.input_ids
    nsamples = testenc.numel() // seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.decoder.layers

    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
    if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((nsamples, seqlen, model.config.hidden_size), dtype=dtype, device=dev)
    cache = {"i": 0, "attention_mask": None}

    # Catcher intercepts the input to the first decoder layer so hidden states can be
    # collected sample by sample without keeping the whole model on the GPU.
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            raise ValueError

    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        batch = testenc[:, (i * seqlen) : ((i + 1) * seqlen)].to(dev)
        try:
            model(batch)
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
    if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.cpu()
    if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache["attention_mask"]

    # Run each decoder layer over all samples, moving one layer to the GPU at a time.
    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        inps, outs = outs, inps

    if model.model.decoder.final_layer_norm is not None:
        model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
    if model.model.decoder.project_out is not None:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.decoder.final_layer_norm is not None:
            hidden_states = model.model.decoder.final_layer_norm(hidden_states)
        if model.model.decoder.project_out is not None:
            hidden_states = model.model.decoder.project_out(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[:, (i * seqlen) : ((i + 1) * seqlen)][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
    print(ppl.item())

    model.config.use_cache = use_cache


def main():
    traindataset, testenc = get_wikitext2(128, 0, 2048, pretrained_model_dir)

    quantize_config = BaseQuantizeConfig(
        bits=4,  # quantize the model to 4-bit
        group_size=128,  # 128 is the recommended group size
        desc_act=False,  # desc_act and group_size only work with the Triton kernels
    )

    # load the un-quantized model; it is always loaded onto CPU first
    model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

    # quantize the model; the calibration examples must be a list of dicts whose only keys
    # are "input_ids" and "attention_mask", with torch.LongTensor values
    model.quantize(traindataset, use_triton=False)

    # save the quantized model
    model.save_quantized(quantized_model_dir)

    # save the quantized model using safetensors
    model.save_quantized(quantized_model_dir, use_safetensors=True)

    # load the quantized model; currently only CPU or a single GPU is supported
    model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_triton=False)

    opt_eval(model.model, testenc, "cuda:0")


if __name__ == "__main__":
    import logging

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    main()
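
After the script finishes, the saved 4-bit checkpoint can be loaded back for text generation. Below is a minimal sketch using the same paths and loading options as quant.py; the prompt and generation parameters are illustrative assumptions, not part of this commit.

# sketch: generate text with the quantized checkpoint produced by quant.py
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

tokenizer = AutoTokenizer.from_pretrained("/root/Yi-6B/", use_fast=False)
model = AutoGPTQForCausalLM.from_quantized("/root/Yi-6B/quant/", device="cuda:0", use_triton=False)

# prompt and sampling settings are placeholders chosen for illustration
inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda:0")
output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=0.7)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))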