from typing import Optional

import torch
import torch.nn.functional as F


def load_model(checkpoint_path, model):
    """Load model weights from a checkpoint file and switch to eval mode."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


def generate_text(
    model,
    data_processor,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    device: str = "cpu",
):
    """Autoregressively sample up to max_new_tokens tokens from the model,
    starting from the tokenized prompt, and return the decoded text."""
    model.eval()
    tokens = data_processor.tokenize(prompt)
    input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the context so it never exceeds the model's maximum length.
            if input_ids.size(1) > model.config.max_token_len:
                input_ids = input_ids[:, -model.config.max_token_len:]

            # Forward pass; keep only the last position and apply temperature.
            logits = model(input_ids)
            logits = logits[:, -1, :] / temperature

            # Optionally restrict sampling to the top-k most likely tokens.
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")

            # Sample the next token and append it to the running sequence.
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat((input_ids, next_token), dim=1)

    output_tokens = input_ids[0].tolist()
    generated_text = data_processor.detokenize(output_tokens)
    return generated_text
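

# Minimal usage sketch (not part of the original module). The model class,
# DataProcessor, and checkpoint path below are hypothetical names used only
# to illustrate how load_model and generate_text fit together:
#
#     model = MyTransformer(config)           # hypothetical model class
#     model = load_model("checkpoint.pt", model)
#     processor = DataProcessor()             # hypothetical tokenize/detokenize wrapper
#     text = generate_text(
#         model,
#         processor,
#         prompt="Once upon a time",
#         max_new_tokens=100,
#         temperature=0.8,
#         top_k=50,
#         device="cpu",
#     )
#     print(text)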