| """ | |
| aggregate.py - module for 'reducing' multiple 'summary chunks' into one | |
| an overly complicated class for legacy compatibility reasons, for usage of the | |
| 2024 map-reduce models see hf.co/pszemraj/bart-large-summary-map-reduce#usage | |
| """ | |
import logging
import pprint as pp
import time

import torch
from transformers import GenerationConfig, Pipeline, pipeline

# Set up module-level logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
    """
    BatchAggregator is a class for aggregating text from multiple sources.

    Usage:
        from aggregate import BatchAggregator

        aggregator = BatchAggregator()
        agg = aggregator.infer_aggregate(["This is a test", "This is another test"])
        print(agg)
    """

    GENERIC_CONFIG = GenerationConfig(
        max_new_tokens=512,
        num_beams=4,
        early_stopping=True,
        do_sample=False,
        truncation=True,
    )

    def __init__(
        self,
        model_name: str = "pszemraj/bart-large-summary-map-reduce",
        force_cpu: bool = False,
        **kwargs,
    ):
        """
        __init__ initializes the BatchAggregator class.

        :param str model_name: model name to use, default: "pszemraj/bart-large-summary-map-reduce"
        :param bool force_cpu: force the model to run on CPU, default: False
        """
        self.device = None
        self.is_compiled = False
        self.model_name = None
        self.aggregator = None
        self.force_cpu = force_cpu
        self.logger = logging.getLogger(__name__)
        self.init_model(model_name)

    def init_model(self, model_name: str) -> None:
        """
        Initialize the model.

        :param model_name: The name of the model to use.
        """
        # Free up memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.info(f"Setting model to {model_name}")
        self.model_name = model_name
        self.aggregator = self._create_pipeline(model_name)
        self._configure_model()

    def _create_pipeline(
        self, model_name: str = "pszemraj/bart-large-summary-map-reduce"
    ) -> Pipeline:
        """
        _create_pipeline creates a text2text-generation pipeline for the model.

        :param str model_name: model name to use
        :return Pipeline: the pipeline for the model
        :raises Exception: if the pipeline cannot be created
        """
        device_map = (
            "auto" if torch.cuda.is_available() and not self.force_cpu else "cpu"
        )
        try:
            self.logger.info(
                f"Creating pipeline with model {model_name} on device {device_map}"
            )
            return pipeline(
                "text2text-generation",
                model=model_name,
                device_map=device_map,
                torch_dtype=torch.float32,
            )
        except Exception as e:
            self.logger.error(f"Failed to create pipeline: {e}")
            raise

    def _configure_model(self):
        """
        Configure the model for generation.
        """
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
            self.is_compiled = True
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")

        self._set_default_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def _set_default_generation_config(self):
        """
        Set the default generation configuration for the model.
        """
        self.aggregator.model.generation_config.update(
            **self.GENERIC_CONFIG.to_diff_dict()
        )

    def update_generation_config(self, **kwargs):
        """
        Update the generation configuration with the specified parameters.

        Args:
            **kwargs: The parameters to update in the generation configuration.
        """
        self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
        self.aggregator.model.generation_config.update(**kwargs)

    def get_generation_config(self) -> dict:
        """
        Get the current generation configuration.

        Returns:
            dict: The current generation configuration.
        """
        return self.aggregator.model.generation_config.to_dict()

    def update_loglevel(self, level: str = "INFO"):
        """
        Update the log level.

        Args:
            level (str): The log level to set. Defaults to "INFO".
        """
        self.logger.setLevel(level)

    def infer_aggregate(
        self,
        text_list: list,
        instruction: str = None,  # kept for backward compatibility but not used
        **kwargs,
    ) -> str:
        """
        infer_aggregate - infers a consolidated summary from a list of texts.

        Args:
            text_list (list): The texts to summarize.
            instruction (str): Not used by this model, kept for compatibility.
            **kwargs: Additional parameters to update in the generation configuration.

        Returns:
            The generated summary.
        """
        joined_text = "\n\n".join(text_list)
        if kwargs:
            self.update_generation_config(**kwargs)

        st = time.perf_counter()
        self.logger.info(f"inference on {len(text_list)} texts ...")
        result = self.aggregator(
            joined_text,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(joined_text)}. Output tokens:\t{self.count_tokens(result)}"
        )
        self.logger.debug(f"Generated text:\n{result}")
        return result

    def count_tokens(self, text: str) -> int:
        """count the number of tokens in a text"""
        return (
            len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
            if text
            else 0
        )
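

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API):
    # build the aggregator with the default model and reduce two overlapping
    # summary chunks into one consolidated summary. The example strings and the
    # num_beams override are hypothetical placeholders chosen for demonstration;
    # any generation parameter passed as a kwarg is applied via
    # update_generation_config() before inference.
    aggregator = BatchAggregator()
    chunk_summaries = [
        "The report describes rising energy costs in 2023 and their impact on households.",
        "Rising energy costs in 2023 forced many households to cut discretionary spending.",
    ]
    consolidated = aggregator.infer_aggregate(chunk_summaries, num_beams=2)
    print(consolidated)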