import re
from functools import partial

import nltk  # nltk.sent_tokenize below needs the Punkt model: nltk.download("punkt")


def get_len(tokenizer, text):
    """Return the token count of `text`, excluding special tokens."""
    return len(tokenizer.encode(text, add_special_tokens=False))


class Truncater:
    """Truncate text to at most `max_length` tokens via an encode/decode round trip."""

    def __init__(self, tokenizer, *, max_length):
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __call__(self, text):
        return self.truncate(text)

    def truncate(self, text):
        # Note: the decoded text may differ slightly from the original
        # (whitespace, casing), depending on the tokenizer.
        input_ids = self.tokenizer.encode(text, add_special_tokens=False, truncation=True, max_length=self.max_length)
        return self.tokenizer.decode(input_ids)
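

# Usage sketch (illustrative addition, not from the original file): clip a
# document to a model's input budget. Assumes a Hugging Face `transformers`
# tokenizer; the checkpoint name is an arbitrary example.
def _truncater_example(text):
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    truncater = Truncater(tokenizer, max_length=512)
    return truncater(text)  # decoded text, at most 512 tokens when re-encoded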


class Refiner:
    """Iterative refinement over chunks: yields one model input per chunk and
    expects the caller to push each new summary back via `set_current_summary`."""

    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size
        self.chunk_size = chunk_size
        self.max_chunk_size = max_chunk_size
        self.tokenizer = tokenizer
        self.get_len = partial(get_len, tokenizer)
        self.current_summary = None
        self.chunks = []
        self.initial_prompt = ""
        self.chunk_prefix = ""
        self.summary_prefix = ""
        self.refinement_prompt = ""

    def set_prompts(self, *, initial_prompt="", chunk_prefix="", summary_prefix="", refinement_prompt=""):
        self.initial_prompt = initial_prompt
        self.chunk_prefix = chunk_prefix
        self.summary_prefix = summary_prefix
        self.refinement_prompt = refinement_prompt

    def current_prompt(self):
        # Initial prompt until a first summary exists, refinement prompt afterwards.
        if self.current_summary is None:
            return self.initial_prompt
        return self.refinement_prompt

    def __call__(self, text):
        # Note: current_summary is not reset here; call set_current_summary(None)
        # to start fresh on a new document.
        self.chunks = Chunker.chunk_text(text, self.chunk_size, self.max_chunk_size, self.get_len)
        return self.refine()

    def __len__(self):
        return len(self.chunks)

    def refine(self):
        for chunk in self.chunks:
            if self.current_summary is None:
                # First chunk: there is no summary to refine yet.
                yield chunk
            else:
                # Later chunks: pair the running summary with the next chunk.
                summary = self.summary_prefix + self.current_summary
                chunk = self.chunk_prefix + chunk
                yield summary + "\n\n" + chunk

    def set_current_summary(self, summary):
        self.current_summary = summary
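

# Usage sketch (illustrative addition): drive the refine loop end to end.
# `summarize` is a hypothetical callable standing in for an actual model call;
# it is not defined in this file.
def _refiner_example(tokenizer, summarize, document):
    refiner = Refiner(tokenizer, chunk_size=768, max_chunk_size=1024)
    refiner.set_prompts(
        initial_prompt="Summarize the following text.",
        chunk_prefix="Text:\n",
        summary_prefix="Current summary:\n",
        refinement_prompt="Refine the current summary using the additional text.",
    )
    # Each yielded item pairs the running summary with the next chunk, so
    # feeding the model's output back in makes the next iteration a refinement.
    for model_input in refiner(document):
        prompt = refiner.current_prompt()
        summary = summarize(prompt, model_input)  # hypothetical model call
        refiner.set_current_summary(summary)
    return refiner.current_summary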


class Chunker:
    """Split text into sentence-aligned chunks of roughly `chunk_size` tokens,
    never exceeding `max_chunk_size`."""

    def __init__(self, tokenizer, *, chunk_size, max_chunk_size):
        assert chunk_size <= max_chunk_size
        self.chunk_size = chunk_size  # target chunk size
        self.max_chunk_size = max_chunk_size  # hard limit
        self.tokenizer = tokenizer
        self.get_len = partial(get_len, tokenizer)

    def __call__(self, text):
        return Chunker.chunk_text(text, self.chunk_size, self.max_chunk_size, self.get_len)

    @staticmethod
    def chunk_text(text, chunk_size, max_chunk_size, len_fn):
        # Flatten paragraph breaks to spaces, then split into sentences.
        paragraphs = re.split("\n\n|\n(?=[^\n])", text)
        text = " ".join(paragraphs)
        sentences = [s.strip() for s in nltk.sent_tokenize(text)]
        chunks = []
        Chunker._chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn)
        return chunks

    @staticmethod
    def _chunk_text(sentences, chunks, chunk_size, max_chunk_size, len_fn):
        if not sentences:
            return
        # If everything that remains fits under the hard limit, emit it whole;
        # this avoids a trailing chunk far below the target size.
        remaining_text = " ".join(sentences)
        if len_fn(remaining_text) <= max_chunk_size:
            chunks.append(remaining_text)
            return
        # Greedily take whole sentences while staying within the target size.
        index = 0
        length_so_far = 0
        while index < len(sentences) and length_so_far + len_fn(sentences[index]) <= chunk_size:
            length_so_far += len_fn(sentences[index])
            index += 1
        if index == 0:
            # A single sentence already exceeds chunk_size; sentence-level
            # chunking cannot make progress.
            raise ValueError("No chunking possible")
        chunks.append(" ".join(sentences[:index]))
        # Recurse on the remaining sentences.
        Chunker._chunk_text(sentences[index:], chunks, chunk_size, max_chunk_size, len_fn)
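

# Usage sketch (illustrative addition): chunk a document with a Hugging Face
# tokenizer. The checkpoint is an arbitrary example, and downloading the Punkt
# model here is this sketch's choice, not something the original file does.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    nltk.download("punkt", quiet=True)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    chunker = Chunker(tokenizer, chunk_size=128, max_chunk_size=160)
    sample = "First sentence. Second sentence.\n\nA new paragraph with more text."
    for i, chunk in enumerate(chunker(sample)):
        print(i, len(tokenizer.encode(chunk, add_special_tokens=False)), chunk)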