kouki321 committed (verified)
Commit 7858597 · Parent(s): d73fe5d

Create app.py

Files changed (1)
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers.cache_utils import DynamicCache
+ import os
+
+ # Minimal generate function for token-by-token generation
+ def generate(model, input_ids: torch.Tensor, past_key_values, max_new_tokens: int = 50) -> torch.Tensor:
+     device = model.model.embed_tokens.weight.device
+     origin_len = input_ids.shape[-1]
+     input_ids = input_ids.to(device)
+     output_ids = input_ids.clone()
+     next_token = input_ids
+
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             out = model(
+                 input_ids=next_token,
+                 past_key_values=past_key_values,
+                 use_cache=True
+             )
+             logits = out.logits[:, -1, :]
+             token = torch.argmax(logits, dim=-1, keepdim=True)
+             output_ids = torch.cat([output_ids, token], dim=-1)
+             past_key_values = out.past_key_values
+             next_token = token.to(device)
+
+             if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
+                 break
+
+     # Return just the newly generated part
+     return output_ids[:, origin_len:]
+ torch.serialization.add_safe_globals([DynamicCache])
+ torch.serialization.add_safe_globals([set])
+
+ def get_kv_cache(model, tokenizer, prompt: str) -> DynamicCache:
+     # Encode prompt
+     device = model.model.embed_tokens.weight.device
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+     cache = DynamicCache()  # it grows as text is generated
+     # Run the model to populate the KV cache
+     with torch.no_grad():
+         _ = model(
+             input_ids=input_ids,
+             past_key_values=cache,
+             use_cache=True
+         )
+     return cache
+
+ def clean_up(cache: DynamicCache, origin_len: int):
+     # Remove any tokens appended to the original knowledge
+     for i in range(len(cache.key_cache)):
+         cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
+         cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
+ model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                           # token=HF_TOKEN,
+                                           trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto",
+     trust_remote_code=True,
+     # token=HF_TOKEN
+ )
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ print(f"Loaded {model_name}.")
+ if not os.path.exists("/kaggle/input/delice/delice.txt"):
+     raise FileNotFoundError("Please provide `/kaggle/input/delice/delice.txt`.")
+
+ with open("/kaggle/input/delice/delice.txt", "r", encoding="utf-8") as f:
+     doc_text = f.read()
+
+ system_prompt = f"""
+ <|system|>
+ Answer concisely and precisely. You are an assistant who provides concise factual answers.
+ <|user|>
+ Context:
+ {doc_text}
+ Question:
+ """.strip()
+
+ # Build the cache
+ ronan_cache = get_kv_cache(model, tokenizer, system_prompt)
+ torch.save(ronan_cache, "/kaggle/working/ronan_caches.pth")
+ origin_len = ronan_cache.key_cache[0].shape[-2]
+ print("KV cache built.")