from typing import Optional

import torch
import torch.nn.functional as F


def load_model(checkpoint_path, model):
    """Load weights from a checkpoint file into `model` and switch it to eval mode."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model
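
# Note: the checkpoint is assumed to be a dict holding the state dict under a
# "model" key, i.e. saved by the training script (not shown here) roughly as:
#   torch.save({"model": model.state_dict()}, checkpoint_path)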


def generate_text(
    model,
    data_processor,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    device: str = "cpu",
):
    """Autoregressively sample up to `max_new_tokens` tokens from `model`,
    starting from `prompt`. `temperature` rescales the logits (must be > 0);
    `top_k`, if given, restricts sampling to the k most likely tokens. The
    model is assumed to already live on `device`; the returned string
    includes the prompt.
    """
    model.eval()
    tokens = data_processor.tokenize(prompt)
    input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop input_ids if it exceeds the model's context size.
            if input_ids.size(1) > model.config.max_token_len:
                input_ids = input_ids[:, -model.config.max_token_len :]
            logits = model(input_ids)
            # Keep only the logits for the last position and apply temperature.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Mask out (set to -inf) everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat((input_ids, next_token), dim=1)
    output_tokens = input_ids[0].tolist()
    generated_text = data_processor.detokenize(output_tokens)
    return generated_text
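

if __name__ == "__main__":
    # Minimal usage sketch. The imported class names, config, and checkpoint
    # path below are hypothetical placeholders for illustration; substitute
    # the actual model and data-processor classes from this repo.
    from deepseek_tinystories.model import DeepSeekModel, ModelConfig  # hypothetical import
    from deepseek_tinystories.data import DataProcessor  # hypothetical import

    config = ModelConfig()
    model = load_model("checkpoints/ckpt.pt", DeepSeekModel(config))  # assumed path
    data_processor = DataProcessor()

    story = generate_text(
        model,
        data_processor,
        prompt="Once upon a time",
        max_new_tokens=200,
        temperature=0.8,
        top_k=50,
    )
    print(story)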