sky-2002 committed
Commit 13a7e6f · verified · 1 parent: 162c0d8

Upload utils.py

Files changed (1): utils.py +45 -0
utils.py ADDED
from typing import Optional

import torch
import torch.nn.functional as F


def load_model(checkpoint_path, model):
    """Load weights from a checkpoint into `model` and switch it to eval mode."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


def generate_text(
    model,
    data_processor,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    device: str = "cpu",
):
    """Autoregressively sample up to `max_new_tokens` tokens after `prompt`."""
    model.eval()
    tokens = data_processor.tokenize(prompt)
    input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # crop input_ids if it exceeds the context size
            if input_ids.size(1) > model.config.max_token_len:
                input_ids = input_ids[:, -model.config.max_token_len :]

            logits = model(input_ids)
            logits = logits[:, -1, :] / temperature  # get the logits for the last token

            # top-k filtering: mask everything below the k-th largest logit
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")

            # sample the next token from the (temperature-scaled) distribution
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat((input_ids, next_token), dim=1)

    # decode the full sequence (prompt + generated tokens)
    output_tokens = input_ids[0].tolist()
    generated_text = data_processor.detokenize(output_tokens)
    return generated_text
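
For quick verification, here is a minimal smoke test of generate_text(). TinyModel and CharProcessor are illustrative stand-ins written for this example, not part of this commit; they only need to satisfy the interfaces utils.py relies on: model(input_ids) returning logits of shape (batch, seq, vocab), a model.config.max_token_len attribute, and tokenize()/detokenize() on the processor.

from types import SimpleNamespace

import torch
import torch.nn as nn

from utils import generate_text


class CharProcessor:
    """Byte-level tokenizer: one token per byte (vocab of 256)."""

    def tokenize(self, text):
        return list(text.encode("utf-8"))

    def detokenize(self, tokens):
        return bytes(tokens).decode("utf-8", errors="replace")


class TinyModel(nn.Module):
    """Untrained embedding + linear head, just enough to exercise the sampling loop."""

    def __init__(self, vocab_size=256, dim=32, max_token_len=64):
        super().__init__()
        self.config = SimpleNamespace(max_token_len=max_token_len)
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # (batch, seq, vocab_size)


model = TinyModel()
text = generate_text(
    model,
    CharProcessor(),
    prompt="Once upon a time",
    max_new_tokens=20,
    temperature=0.8,  # <1.0 sharpens the distribution
    top_k=40,         # sample only from the 40 most likely bytes
)
print(text)

With an untrained model the output is gibberish, but the run exercises the context cropping, temperature scaling, top-k filtering, and multinomial sampling paths end to end.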