import concurrent.futures

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer (using GPT-2 as an example).
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
torch.set_num_threads(2)


def min_p_sampling(logits, pbase=0.1):
    """
    Perform min-p sampling on the logits, as described in
    https://arxiv.org/abs/2407.01082

    Args:
        logits (torch.Tensor): 1D tensor of logits for the next token.
        pbase (float): Base probability used to scale pmax into a threshold.

    Returns:
        int: The sampled token index.
    """
    # Convert logits to probabilities.
    probs = torch.softmax(logits, dim=-1)
    # 1. Find the maximum probability.
    pmax = probs.max()
    # 2. Compute the dynamic threshold.
    pscaled = pbase * pmax
    # 3. Create a mask of tokens with probability >= pscaled.
    mask = probs >= pscaled
    # In the unlikely event that no token meets the threshold, fall back to the full distribution.
    if mask.sum() == 0:
        mask = torch.ones_like(probs, dtype=torch.bool)
    probs_filtered = probs * mask.float()
    # 4. Renormalize and sample.
    probs_normalized = probs_filtered / probs_filtered.sum()
    sampled_index = torch.multinomial(probs_normalized, num_samples=1)
    return sampled_index.item()
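
# Quick illustrative check of the thresholding above (not part of the original
# app; uncomment to experiment). With pbase=0.1 the probabilities are roughly
# [0.66, 0.24, 0.09, 0.002] and the threshold is ~0.066, so only the last
# token is filtered out before renormalizing and sampling.
# toy_logits = torch.tensor([3.0, 2.0, 1.0, -3.0])
# print(min_p_sampling(toy_logits, pbase=0.1))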


def generate_laconic_completion(prompt: str, n: int = 5, max_length: int = 100):
    """Generate n sampled completions and return the shortest one ("laconic" decoding)."""
    with torch.no_grad():
        # Encode the prompt and get the attention mask.
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        # Generate n sampled completions.
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=n,
            do_sample=True,
        )
        completions = [
            tokenizer.decode(output, skip_special_tokens=True) for output in outputs
        ]
        # Return the shortest completion.
        return min(completions, key=len)
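
# Illustrative usage (not part of the original app): pick the shortest of five
# sampled continuations for a short factual prompt.
# print(generate_laconic_completion("The capital of France is", n=5, max_length=30))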


def generate_with_confidence(input_ids, max_length):
    """
    Generate a sequence using greedy decoding while returning the per-step scores.
    """
    outputs = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=False,
        output_scores=True,
        return_dict_in_generate=True,
    )
    return outputs


def compute_answer_confidence(outputs):
    """
    Compute the answer confidence over the generated tokens.

    For each generated token, take the difference between the top-1 and top-2
    logits (a simple proxy for the top-two probability gap used in the
    CoT-decoding paper) and return the average difference.
    """
    diffs = []
    for score in outputs.scores:
        # Get the top-2 logit values at this decoding step.
        top2 = torch.topk(score[0], 2)
        diff = top2.values[0] - top2.values[1]
        diffs.append(diff.item())
    return sum(diffs) / len(diffs) if diffs else 0.0
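
# Worked example (illustrative): if the top-2 logits at three decoding steps are
# (5.1, 3.0), (2.2, 2.0), and (4.0, 1.0), the per-step gaps are 2.1, 0.2, and 3.0,
# so the answer confidence is (2.1 + 0.2 + 3.0) / 3 = ~1.77.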


def cot_decoding(prompt, k=5, max_length=100):
    """
    Perform Chain-of-Thought (CoT) decoding by exploring the top-k alternative
    first tokens and returning the highest-confidence continuation.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Get the logits for the next token.
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits[0, -1, :]
    # Get the top-k candidate first tokens.
    topk = torch.topk(logits, k)
    candidate_tokens = topk.indices
    paths = []
    for token in candidate_tokens:
        # Append the candidate token to the prompt.
        new_input_ids = torch.cat([input_ids, token.view(1, 1)], dim=1)
        # Greedily generate a full sequence with output scores.
        gen_outputs = generate_with_confidence(
            new_input_ids, max_length=new_input_ids.shape[1] + max_length
        )
        # Decode the generated sequence.
        generated_text = tokenizer.decode(
            gen_outputs.sequences[0], skip_special_tokens=True
        )
        # Compute the answer confidence for this path.
        confidence = compute_answer_confidence(gen_outputs)
        paths.append({"text": generated_text, "confidence": confidence})
    # Return the highest-confidence path.
    return max(paths, key=lambda x: x["confidence"])["text"]
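
# Illustrative usage (not part of the original app): explore 5 first-token
# branches for a small arithmetic prompt and keep the most confident one.
# answer = cot_decoding("Q: I have 3 apples and eat one. How many are left? A:", k=5, max_length=40)
# print(answer)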


def generate_completion(prompt, strategy, params):
    """
    Generate a complete answer using model.generate with the specified parameters.
    """
    with torch.no_grad():
        # Encode the prompt and get the attention mask.
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        # Generate the output.
        output_ids = model.generate(
            input_ids, attention_mask=attention_mask, max_length=100, **params
        )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def generate_min_p_completion(prompt, pbase=0.1, max_length=100):
    """Generate a completion token by token using min-p sampling."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    past = None
    with torch.no_grad():
        for _ in range(max_length - input_ids.size(1)):
            # Only pass the last token when a past key/value cache is available.
            outputs = (
                model(input_ids[:, -1:], past_key_values=past)
                if past is not None
                else model(input_ids)
            )
            past = outputs.past_key_values
            # Take the 1D logits for the next token, as expected by min_p_sampling.
            logits = outputs.logits[0, -1, :]
            next_token = min_p_sampling(logits, pbase=pbase)
            input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
            if next_token == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
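
# Illustrative usage (not part of the original app):
# print(generate_min_p_completion("Once upon a time", pbase=0.1, max_length=50))
# Note: recent transformers releases also expose a built-in `min_p` argument to
# model.generate(), which could replace this manual sampling loop.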


def generate_all(prompt):
    """
    Run multiple decoding strategies concurrently and yield updates as each completes.
    """
    # Define each decoding strategy and its parameters.
    methods = {
        "Greedy": {"type": "default", "params": {"do_sample": False}},
        "Top-k Sampling": {
            "type": "default",
            "params": {"do_sample": True, "top_k": 100},
        },
        "Top-p Sampling": {
            "type": "default",
            "params": {"do_sample": True, "top_p": 0.95},
        },
        "Beam Search": {
            "type": "default",
            "params": {"num_beams": 5, "early_stopping": True},
        },
        "Eta Sampling": {
            "type": "default",
            "params": {"do_sample": True, "eta_cutoff": 0.3},
        },
        "Epsilon Sampling": {
            "type": "default",
            "params": {"do_sample": True, "epsilon_cutoff": 0.2},
        },
        "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
        "laconic": {"type": "laconic"},
        "COT Decoding": {
            "type": "cot_decoding",
            "params": {"k": 5, "max_length": 100},
        },
    }
    # Define the order for display.
    method_order = [
        "Greedy",
        "Top-k Sampling",
        "Top-p Sampling",
        "Beam Search",
        "Min-p Sampling",
        "Eta Sampling",
        "Epsilon Sampling",
        "laconic",
        "COT Decoding",
    ]
    results = {method: None for method in methods}

    # Yield an initial placeholder state.
    yield tuple("Processing..." for _ in method_order)

    # Use a thread pool to run each generation concurrently.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_method = {}
        for method, info in methods.items():
            if info["type"] == "default":
                future = executor.submit(
                    generate_completion, prompt, method, info["params"]
                )
            elif info["type"] == "min_p":
                future = executor.submit(
                    generate_min_p_completion, prompt, info["pbase"]
                )
            elif info["type"] == "laconic":
                future = executor.submit(generate_laconic_completion, prompt)
            elif info["type"] == "cot_decoding":
                future = executor.submit(cot_decoding, prompt, **info["params"])
            future_to_method[future] = method

        # As each future completes, update its result and yield the current state.
        for future in concurrent.futures.as_completed(future_to_method):
            method = future_to_method[future]
            try:
                result = future.result()
            except Exception as exc:
                result = f"Error: {exc}"
            results[method] = result
            # Yield the results in the pre-defined order; pending methods show "Processing...".
            yield tuple(
                results[m] if results[m] is not None else "Processing..."
                for m in method_order
            )
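

# Because generate_all is a generator, Gradio streams each yielded tuple to the
# output textboxes, so results appear as soon as each strategy finishes.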
# Create the Gradio interface.
interface = gr.Interface(
    fn=generate_all,
    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        gr.Textbox(label="Greedy"),
        gr.Textbox(label="Top-k Sampling"),
        gr.Textbox(label="Top-p Sampling"),
        gr.Textbox(label="Beam Search"),
        gr.Textbox(label="Min-p Sampling (as in https://arxiv.org/abs/2407.01082)"),
        gr.Textbox(label="Eta Sampling"),
        gr.Textbox(label="Epsilon Sampling"),
        gr.Textbox(
            label="Laconic decoding (proposed by Alex Dimakis, 2025, in a Twitter thread)"
        ),
        gr.Textbox(
            label="CoT Decoding (Chain-of-Thought Reasoning without Prompting, Wang & Zhou, 2024)"
        ),
    ],
    title="Decoding Methods Comparison",
    description="Each decoding method's answer is displayed as soon as it finishes. Model used: GPT-2.",
)

if __name__ == "__main__":
    interface.launch()