diff --git "a/tools/dedup_summaries.py" "b/tools/dedup_summaries.py" new file mode 100644--- /dev/null +++ "b/tools/dedup_summaries.py" @@ -0,0 +1,2905 @@ +import os +import re +import time +from typing import List, Tuple + +import boto3 +import gradio as gr +import markdown +import pandas as pd +import spaces +from rapidfuzz import fuzz, process +from tqdm import tqdm + +from tools.aws_functions import connect_to_bedrock_runtime +from tools.config import ( + BATCH_SIZE_DEFAULT, + CHOSEN_LOCAL_MODEL_TYPE, + DEDUPLICATION_THRESHOLD, + DEFAULT_SAMPLED_SUMMARIES, + LLM_CONTEXT_LENGTH, + LLM_MAX_NEW_TOKENS, + LLM_SEED, + MAX_COMMENT_CHARS, + MAX_GROUPS, + MAX_SPACES_GPU_RUN_TIME, + MAX_TIME_FOR_LOOP, + NUMBER_OF_RETRY_ATTEMPTS, + OUTPUT_DEBUG_FILES, + OUTPUT_FOLDER, + REASONING_SUFFIX, + RUN_LOCAL_MODEL, + TIMEOUT_WAIT, + model_name_map, +) +from tools.helper_functions import ( + clean_column_name, + convert_reference_table_to_pivot_table, + create_batch_file_path_details, + create_topic_summary_df_from_reference_table, + ensure_model_in_map, + generate_zero_shot_topics_df, + get_basic_response_data, + get_file_name_no_ext, + load_in_data_file, + read_file, + wrap_text, +) +from tools.llm_funcs import ( + calculate_tokens_from_metadata, + call_llm_with_markdown_table_checks, + construct_azure_client, + construct_gemini_generative_model, + get_assistant_model, + get_model, + get_tokenizer, + process_requests, +) +from tools.prompts import ( + comprehensive_summary_format_prompt, + comprehensive_summary_format_prompt_by_group, + llm_deduplication_prompt, + llm_deduplication_prompt_with_candidates, + llm_deduplication_system_prompt, + summarise_everything_prompt, + summarise_everything_system_prompt, + summarise_topic_descriptions_prompt, + summarise_topic_descriptions_system_prompt, + summary_assistant_prefill, + system_prompt, +) + +max_tokens = LLM_MAX_NEW_TOKENS +timeout_wait = TIMEOUT_WAIT +number_of_api_retry_attempts = NUMBER_OF_RETRY_ATTEMPTS +max_time_for_loop = MAX_TIME_FOR_LOOP +batch_size_default = BATCH_SIZE_DEFAULT +deduplication_threshold = DEDUPLICATION_THRESHOLD +max_comment_character_length = MAX_COMMENT_CHARS +reasoning_suffix = REASONING_SUFFIX +output_debug_files = OUTPUT_DEBUG_FILES +default_number_of_sampled_summaries = DEFAULT_SAMPLED_SUMMARIES +max_text_length = 500 + + +# DEDUPLICATION/SUMMARISATION FUNCTIONS +def deduplicate_categories( + category_series: pd.Series, + join_series: pd.Series, + reference_df: pd.DataFrame, + general_topic_series: pd.Series = None, + merge_general_topics="No", + merge_sentiment: str = "No", + threshold: float = 90, +) -> pd.DataFrame: + """ + Deduplicates similar category names in a pandas Series based on a fuzzy matching threshold, + merging smaller topics into larger topics. + + Parameters: + category_series (pd.Series): Series containing category names to deduplicate. + join_series (pd.Series): Additional series used for joining back to original results. + reference_df (pd.DataFrame): DataFrame containing the reference data to count occurrences. + threshold (float): Similarity threshold for considering two strings as duplicates. + + Returns: + pd.DataFrame: DataFrame with columns ['old_category', 'deduplicated_category']. 
+ """ + # Count occurrences of each category in the reference_df + category_counts = reference_df["Subtopic"].value_counts().to_dict() + + # Initialize dictionaries for both category mapping and scores + deduplication_map = {} + match_scores = {} # New dictionary to store match scores + + # First pass: Handle exact matches + for category in category_series.unique(): + if category in deduplication_map: + continue + + # Find all exact matches + exact_matches = category_series[ + category_series.str.lower() == category.lower() + ].index.tolist() + if len(exact_matches) > 1: + # Find the variant with the highest count + match_counts = { + match: category_counts.get(category_series[match], 0) + for match in exact_matches + } + most_common = max(match_counts.items(), key=lambda x: x[1])[0] + most_common_category = category_series[most_common] + + # Map all exact matches to the most common variant and store score + for match in exact_matches: + deduplication_map[category_series[match]] = most_common_category + match_scores[category_series[match]] = ( + 100 # Exact matches get score of 100 + ) + + # Second pass: Handle fuzzy matches for remaining categories + # Create a DataFrame to maintain the relationship between categories and general topics + categories_df = pd.DataFrame( + {"category": category_series, "general_topic": general_topic_series} + ).drop_duplicates() + + for _, row in categories_df.iterrows(): + category = row["category"] + if category in deduplication_map: + continue + + current_general_topic = row["general_topic"] + + # Filter potential matches to only those within the same General topic if relevant + if merge_general_topics == "No": + potential_matches = categories_df[ + (categories_df["category"] != category) + & (categories_df["general_topic"] == current_general_topic) + ]["category"].tolist() + else: + potential_matches = categories_df[(categories_df["category"] != category)][ + "category" + ].tolist() + + matches = process.extract( + category, potential_matches, scorer=fuzz.WRatio, score_cutoff=threshold + ) + + if matches: + best_match = max(matches, key=lambda x: x[1]) + match, score, _ = best_match + + if category_counts.get(category, 0) < category_counts.get(match, 0): + deduplication_map[category] = match + match_scores[category] = score + else: + deduplication_map[match] = category + match_scores[match] = score + else: + deduplication_map[category] = category + match_scores[category] = 100 + + # Create the result DataFrame with scores + result_df = pd.DataFrame( + { + "old_category": category_series + " | " + join_series, + "deduplicated_category": category_series.map( + lambda x: deduplication_map.get(x, x) + ), + "match_score": category_series.map( + lambda x: match_scores.get(x, 100) + ), # Add scores column + } + ) + + # print(result_df) + + return result_df + + +def deduplicate_topics( + reference_df: pd.DataFrame, + topic_summary_df: pd.DataFrame, + reference_table_file_name: str, + unique_topics_table_file_name: str, + in_excel_sheets: str = "", + merge_sentiment: str = "No", + merge_general_topics: str = "No", + score_threshold: int = 90, + in_data_files: List[str] = list(), + chosen_cols: List[str] = "", + output_folder: str = OUTPUT_FOLDER, + deduplicate_topics: str = "Yes", +): + """ + Deduplicate topics based on a reference and unique topics table, merging similar topics. + + Args: + reference_df (pd.DataFrame): DataFrame containing reference data with topics. + topic_summary_df (pd.DataFrame): DataFrame summarizing unique topics. 
+ reference_table_file_name (str): Base file name for the output reference table. + unique_topics_table_file_name (str): Base file name for the output unique topics table. + in_excel_sheets (str, optional): Comma-separated list of Excel sheet names to load. Defaults to "". + merge_sentiment (str, optional): Whether to merge topics regardless of sentiment ("Yes" or "No"). Defaults to "No". + merge_general_topics (str, optional): Whether to merge topics across different general topics ("Yes" or "No"). Defaults to "No". + score_threshold (int, optional): Fuzzy matching score threshold for deduplication. Defaults to 90. + in_data_files (List[str], optional): List of input data file paths. Defaults to []. + chosen_cols (List[str], optional): List of chosen columns from the input data files. Defaults to "". + output_folder (str, optional): Folder path to save output files. Defaults to OUTPUT_FOLDER. + deduplicate_topics (str, optional): Whether to perform topic deduplication ("Yes" or "No"). Defaults to "Yes". + """ + output_files = list() + log_output_files = list() + file_data = pd.DataFrame() + deduplicated_unique_table_markdown = "" + + if (len(reference_df["Response References"].unique()) == 1) | ( + len(topic_summary_df["Topic number"].unique()) == 1 + ): + print( + "Data file outputs are too short for deduplicating. Returning original data." + ) + + # Get file name without extension and create proper output paths + reference_table_file_name_no_ext = get_file_name_no_ext( + reference_table_file_name + ) + unique_topics_table_file_name_no_ext = get_file_name_no_ext( + unique_topics_table_file_name + ) + + # Create output paths with _dedup suffix to match normal path + reference_file_out_path = ( + output_folder + reference_table_file_name_no_ext + "_dedup.csv" + ) + unique_topics_file_out_path = ( + output_folder + unique_topics_table_file_name_no_ext + "_dedup.csv" + ) + + # Save the DataFrames to CSV files + reference_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + reference_file_out_path, index=None, encoding="utf-8-sig" + ) + topic_summary_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + unique_topics_file_out_path, index=None, encoding="utf-8-sig" + ) + + output_files.append(reference_file_out_path) + output_files.append(unique_topics_file_out_path) + + # Create markdown output for display + topic_summary_df_revised_display = topic_summary_df.apply( + lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length)) + ) + deduplicated_unique_table_markdown = ( + topic_summary_df_revised_display.to_markdown(index=False) + ) + + return ( + reference_df, + topic_summary_df, + output_files, + log_output_files, + deduplicated_unique_table_markdown, + ) + + # For checking that data is not lost during the process + initial_unique_references = len(reference_df["Response References"].unique()) + + if topic_summary_df.empty: + topic_summary_df = create_topic_summary_df_from_reference_table(reference_df) + + # Then merge the topic numbers back to the original dataframe + reference_df = reference_df.merge( + topic_summary_df[ + ["General topic", "Subtopic", "Sentiment", "Topic number"] + ], + on=["General topic", "Subtopic", "Sentiment"], + how="left", + ) + + if in_data_files and chosen_cols: + file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file( + in_data_files, chosen_cols, 1, in_excel_sheets + ) + else: + out_message = "No file data found, pivot table output will not be created." 
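# Illustrative sketch (not part of the patch): the core merge rule that
# deduplicate_categories above applies via rapidfuzz. Labels scoring at or above
# the cutoff under fuzz.WRatio are merged, and the less frequent label is mapped
# onto the more frequent one. The labels and counts below are made up.
#
#     from rapidfuzz import fuzz, process
#
#     labels = ["Bus delays", "Bus delay", "Ticket price"]
#     counts = {"Bus delays": 12, "Bus delay": 3, "Ticket price": 5}
#     mapping = {}
#     for label in labels:
#         candidates = [c for c in labels if c != label]
#         matches = process.extract(label, candidates, scorer=fuzz.WRatio, score_cutoff=90)
#         if matches:
#             best, _score, _idx = max(matches, key=lambda m: m[1])
#             # Keep whichever variant has more responses in the reference table
#             loser, winner = sorted((label, best), key=lambda c: counts.get(c, 0))
#             mapping[loser] = winner
#     # mapping == {"Bus delay": "Bus delays"}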
+ print(out_message) + # raise Exception(out_message) + + # Run through this x times to try to get all duplicate topics + if deduplicate_topics == "Yes": + if "Group" not in reference_df.columns: + reference_df["Group"] = "All" + for i in range(0, 8): + if merge_sentiment == "No": + if merge_general_topics == "No": + reference_df["old_category"] = ( + reference_df["Subtopic"] + " | " + reference_df["Sentiment"] + ) + reference_df_unique = reference_df.drop_duplicates("old_category") + + # Create an empty list to store results from each group + results = list() + # Iterate over each group instead of using .apply() + for name, group in reference_df_unique.groupby( + ["General topic", "Sentiment", "Group"] + ): + # Run your function on the 'group' DataFrame + result = deduplicate_categories( + group["Subtopic"], + group["Sentiment"], + reference_df, + general_topic_series=group["General topic"], + merge_general_topics="No", + threshold=score_threshold, + ) + results.append(result) + + # Concatenate all the results into a single DataFrame + deduplicated_topic_map_df = pd.concat(results).reset_index( + drop=True + ) + # --- MODIFIED SECTION END --- + + else: + # This case should allow cross-topic matching but is still grouping by Sentiment + reference_df["old_category"] = ( + reference_df["Subtopic"] + " | " + reference_df["Sentiment"] + ) + reference_df_unique = reference_df.drop_duplicates("old_category") + + results = list() + for name, group in reference_df_unique.groupby("Sentiment"): + result = deduplicate_categories( + group["Subtopic"], + group["Sentiment"], + reference_df, + general_topic_series=None, + merge_general_topics="Yes", + threshold=score_threshold, + ) + results.append(result) + deduplicated_topic_map_df = pd.concat(results).reset_index( + drop=True + ) + + else: + if merge_general_topics == "No": + reference_df["old_category"] = ( + reference_df["Subtopic"] + " | " + reference_df["Sentiment"] + ) + reference_df_unique = reference_df.drop_duplicates("old_category") + + results = list() + for name, group in reference_df_unique.groupby("General topic"): + result = deduplicate_categories( + group["Subtopic"], + group["Sentiment"], + reference_df, + general_topic_series=group["General topic"], + merge_general_topics="No", + merge_sentiment=merge_sentiment, + threshold=score_threshold, + ) + results.append(result) + deduplicated_topic_map_df = pd.concat(results).reset_index( + drop=True + ) + + else: + reference_df["old_category"] = ( + reference_df["Subtopic"] + " | " + reference_df["Sentiment"] + ) + reference_df_unique = reference_df.drop_duplicates("old_category") + + deduplicated_topic_map_df = deduplicate_categories( + reference_df_unique["Subtopic"], + reference_df_unique["Sentiment"], + reference_df, + general_topic_series=None, + merge_general_topics="Yes", + merge_sentiment=merge_sentiment, + threshold=score_threshold, + ).reset_index(drop=True) + + if deduplicated_topic_map_df["deduplicated_category"].isnull().all(): + print("No deduplicated categories found, skipping the following code.") + + else: + # Remove rows where 'deduplicated_category' is blank or NaN + deduplicated_topic_map_df = deduplicated_topic_map_df.loc[ + ( + deduplicated_topic_map_df["deduplicated_category"].str.strip() + != "" + ) + & ~(deduplicated_topic_map_df["deduplicated_category"].isnull()), + ["old_category", "deduplicated_category", "match_score"], + ] + + reference_df = reference_df.merge( + deduplicated_topic_map_df, on="old_category", how="left" + ) + + reference_df.rename( + 
columns={"Subtopic": "Subtopic_old", "Sentiment": "Sentiment_old"}, + inplace=True, + ) + # Extract subtopic and sentiment from deduplicated_category + reference_df["Subtopic"] = reference_df[ + "deduplicated_category" + ].str.extract(r"^(.*?) \|")[ + 0 + ] # Extract subtopic + reference_df["Sentiment"] = reference_df[ + "deduplicated_category" + ].str.extract(r"\| (.*)$")[ + 0 + ] # Extract sentiment + + # Combine with old values to ensure no data is lost + reference_df["Subtopic"] = reference_df[ + "deduplicated_category" + ].combine_first(reference_df["Subtopic_old"]) + reference_df["Sentiment"] = reference_df["Sentiment"].combine_first( + reference_df["Sentiment_old"] + ) + + reference_df = reference_df.rename( + columns={"General Topic": "General topic"}, errors="ignore" + ) + reference_df = reference_df[ + [ + "Response References", + "General topic", + "Subtopic", + "Sentiment", + "Summary", + "Start row of group", + "Group", + ] + ] + + if merge_general_topics == "Yes": + # Replace General topic names for each Subtopic with that for the Subtopic with the most responses + # Step 1: Count the number of occurrences for each General topic and Subtopic combination + count_df = ( + reference_df.groupby(["Subtopic", "General topic"]) + .size() + .reset_index(name="Count") + ) + + # Step 2: Find the General topic with the maximum count for each Subtopic + max_general_topic = count_df.loc[ + count_df.groupby("Subtopic")["Count"].idxmax() + ] + + # Step 3: Map the General topic back to the original DataFrame + reference_df = reference_df.merge( + max_general_topic[["Subtopic", "General topic"]], + on="Subtopic", + suffixes=("", "_max"), + how="left", + ) + + reference_df["General topic"] = reference_df[ + "General topic_max" + ].combine_first(reference_df["General topic"]) + + if merge_sentiment == "Yes": + # Step 1: Count the number of occurrences for each General topic and Subtopic combination + count_df = ( + reference_df.groupby(["Subtopic", "Sentiment"]) + .size() + .reset_index(name="Count") + ) + + # Step 2: Determine the number of unique Sentiment values for each Subtopic + unique_sentiments = ( + count_df.groupby("Subtopic")["Sentiment"] + .nunique() + .reset_index(name="UniqueCount") + ) + + # Step 3: Update Sentiment to 'Mixed' where there is more than one unique sentiment + reference_df = reference_df.merge( + unique_sentiments, on="Subtopic", how="left" + ) + reference_df["Sentiment"] = reference_df.apply( + lambda row: "Mixed" if row["UniqueCount"] > 1 else row["Sentiment"], + axis=1, + ) + + # Clean up the DataFrame by dropping the UniqueCount column + reference_df.drop(columns=["UniqueCount"], inplace=True) + + # print("reference_df:", reference_df) + reference_df = reference_df[ + [ + "Response References", + "General topic", + "Subtopic", + "Sentiment", + "Summary", + "Start row of group", + "Group", + ] + ] + # reference_df.drop(['old_category', 'deduplicated_category', "Subtopic_old", "Sentiment_old"], axis=1, inplace=True, errors="ignore") + + # Update reference summary column with all summaries + reference_df["Summary"] = reference_df.groupby( + ["Response References", "General topic", "Subtopic", "Sentiment"] + )["Summary"].transform("
".join) + + # Check that we have not inadvertantly removed some data during the above process + end_unique_references = len(reference_df["Response References"].unique()) + + if initial_unique_references != end_unique_references: + raise Exception( + f"Number of unique references changed during processing: Initial={initial_unique_references}, Final={end_unique_references}" + ) + + # Drop duplicates in the reference table - each comment should only have the same topic referred to once + reference_df.drop_duplicates( + ["Response References", "General topic", "Subtopic", "Sentiment"], + inplace=True, + ) + + # Remake topic_summary_df based on new reference_df + topic_summary_df = create_topic_summary_df_from_reference_table(reference_df) + + # Then merge the topic numbers back to the original dataframe + reference_df = reference_df.merge( + topic_summary_df[ + ["General topic", "Subtopic", "Sentiment", "Group", "Topic number"] + ], + on=["General topic", "Subtopic", "Sentiment", "Group"], + how="left", + ) + + else: + print("Topics have not beeen deduplicated") + + reference_table_file_name_no_ext = get_file_name_no_ext(reference_table_file_name) + unique_topics_table_file_name_no_ext = get_file_name_no_ext( + unique_topics_table_file_name + ) + + if not file_data.empty: + basic_response_data = get_basic_response_data(file_data, chosen_cols) + reference_df_pivot = convert_reference_table_to_pivot_table( + reference_df, basic_response_data + ) + + reference_pivot_file_path = ( + output_folder + reference_table_file_name_no_ext + "_pivot_dedup.csv" + ) + reference_df_pivot.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + reference_pivot_file_path, index=None, encoding="utf-8-sig" + ) + log_output_files.append(reference_pivot_file_path) + + reference_file_out_path = ( + output_folder + reference_table_file_name_no_ext + "_dedup.csv" + ) + unique_topics_file_out_path = ( + output_folder + unique_topics_table_file_name_no_ext + "_dedup.csv" + ) + reference_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + reference_file_out_path, index=None, encoding="utf-8-sig" + ) + topic_summary_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + unique_topics_file_out_path, index=None, encoding="utf-8-sig" + ) + + output_files.append(reference_file_out_path) + output_files.append(unique_topics_file_out_path) + + # Outputs for markdown table output + topic_summary_df_revised_display = topic_summary_df.apply( + lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length)) + ) + deduplicated_unique_table_markdown = topic_summary_df_revised_display.to_markdown( + index=False + ) + + return ( + reference_df, + topic_summary_df, + output_files, + log_output_files, + deduplicated_unique_table_markdown, + ) + + +def deduplicate_topics_llm( + reference_df: pd.DataFrame, + topic_summary_df: pd.DataFrame, + reference_table_file_name: str, + unique_topics_table_file_name: str, + model_choice: str, + in_api_key: str, + temperature: float, + model_source: str, + bedrock_runtime=None, + local_model=None, + tokenizer=None, + assistant_model=None, + in_excel_sheets: str = "", + merge_sentiment: str = "No", + merge_general_topics: str = "No", + in_data_files: List[str] = list(), + chosen_cols: List[str] = "", + output_folder: str = OUTPUT_FOLDER, + candidate_topics=None, + azure_endpoint: str = "", + output_debug_files: str = "False", + api_url: str = None, +): + """ + Deduplicate topics using LLM semantic understanding to identify and merge similar topics. 
+ + Args: + reference_df (pd.DataFrame): DataFrame containing reference data with topics. + topic_summary_df (pd.DataFrame): DataFrame summarizing unique topics. + reference_table_file_name (str): Base file name for the output reference table. + unique_topics_table_file_name (str): Base file name for the output unique topics table. + model_choice (str): The LLM model to use for deduplication. + in_api_key (str): API key for the LLM service. + temperature (float): Temperature setting for the LLM. + model_source (str): Source of the model (AWS, Gemini, Local, etc.). + bedrock_runtime: AWS Bedrock runtime client (if using AWS). + local_model: Local model instance (if using local model). + tokenizer: Tokenizer for local model. + assistant_model: Assistant model for speculative decoding. + in_excel_sheets (str, optional): Comma-separated list of Excel sheet names to load. Defaults to "". + merge_sentiment (str, optional): Whether to merge topics regardless of sentiment ("Yes" or "No"). Defaults to "No". + merge_general_topics (str, optional): Whether to merge topics across different general topics ("Yes" or "No"). Defaults to "No". + in_data_files (List[str], optional): List of input data file paths. Defaults to []. + chosen_cols (List[str], optional): List of chosen columns from the input data files. Defaults to "". + output_folder (str, optional): Folder path to save output files. Defaults to OUTPUT_FOLDER. + candidate_topics (optional): Candidate topics file for zero-shot guidance. Defaults to None. + azure_endpoint (str, optional): Azure endpoint for the LLM. Defaults to "". + output_debug_files (str, optional): Whether to output debug files. Defaults to "False". + """ + + output_files = list() + log_output_files = list() + file_data = pd.DataFrame() + deduplicated_unique_table_markdown = "" + + # Check if data is too short for deduplication + if (len(reference_df["Response References"].unique()) == 1) | ( + len(topic_summary_df["Topic number"].unique()) == 1 + ): + print( + "Data file outputs are too short for deduplicating. Returning original data." 
+ ) + + # Get file name without extension and create proper output paths + reference_table_file_name_no_ext = get_file_name_no_ext( + reference_table_file_name + ) + unique_topics_table_file_name_no_ext = get_file_name_no_ext( + unique_topics_table_file_name + ) + + # Create output paths with _dedup suffix to match normal path + reference_file_out_path = ( + output_folder + reference_table_file_name_no_ext + "_dedup.csv" + ) + unique_topics_file_out_path = ( + output_folder + unique_topics_table_file_name_no_ext + "_dedup.csv" + ) + + # Save the DataFrames to CSV files + reference_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + reference_file_out_path, index=None, encoding="utf-8-sig" + ) + topic_summary_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + unique_topics_file_out_path, index=None, encoding="utf-8-sig" + ) + + output_files.append(reference_file_out_path) + output_files.append(unique_topics_file_out_path) + + # Create markdown output for display + topic_summary_df_revised_display = topic_summary_df.apply( + lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length)) + ) + deduplicated_unique_table_markdown = ( + topic_summary_df_revised_display.to_markdown(index=False) + ) + + # Return with token counts set to 0 for early return + return ( + reference_df, + topic_summary_df, + output_files, + log_output_files, + deduplicated_unique_table_markdown, + 0, # input_tokens + 0, # output_tokens + 0, # number_of_calls + 0.0, # estimated_time_taken + ) + + # For checking that data is not lost during the process + initial_unique_references = len(reference_df["Response References"].unique()) + + # Create topic summary if it doesn't exist + if topic_summary_df.empty: + topic_summary_df = create_topic_summary_df_from_reference_table(reference_df) + + # Merge topic numbers back to the original dataframe + reference_df = reference_df.merge( + topic_summary_df[ + ["General topic", "Subtopic", "Sentiment", "Topic number"] + ], + on=["General topic", "Subtopic", "Sentiment"], + how="left", + ) + + # Load data files if provided + if in_data_files and chosen_cols: + file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file( + in_data_files, chosen_cols, 1, in_excel_sheets + ) + else: + out_message = "No file data found, pivot table output will not be created." 
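# Illustrative sketch (not part of the patch): roughly how the markdown table
# returned by the LLM is parsed back into a DataFrame further down in this
# function, using pd.read_csv with sep="|". The table text below is a made-up
# example of the expected response format.
#
#     from io import StringIO
#     import pandas as pd
#
#     table_text = (
#         "| Original Subtopic | Merged Subtopic |\n"
#         "|---|---|\n"
#         "| Bus delay | Bus delays |\n"
#     )
#     df = pd.read_csv(StringIO(table_text), sep="|", skipinitialspace=True)
#     df = df.dropna(axis=1, how="all")  # drop empty columns created by the outer pipes
#     df.columns = df.columns.str.strip()
#     df = df[~df["Original Subtopic"].str.contains("---", na=False)]  # drop the separator row
#     # One row remains, mapping "Bus delay" to "Bus delays"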
+ print(out_message) + + # Process candidate topics if provided + candidate_topics_table = "" + if candidate_topics is not None: + try: + + # Read and process candidate topics + # Handle both string paths (CLI) and gr.FileData objects (Gradio) + candidate_topics_path = ( + candidate_topics + if isinstance(candidate_topics, str) + else getattr(candidate_topics, "name", None) + ) + if candidate_topics_path is None: + raise ValueError( + "candidate_topics must be a file path string or a FileData object with a 'name' attribute" + ) + candidate_topics_df = read_file(candidate_topics_path) + candidate_topics_df = candidate_topics_df.fillna("") + candidate_topics_df = candidate_topics_df.astype(str) + + # Generate zero-shot topics DataFrame + zero_shot_topics_df = generate_zero_shot_topics_df( + candidate_topics_df, "No", False + ) + + if not zero_shot_topics_df.empty: + candidate_topics_table = zero_shot_topics_df[ + ["General topic", "Subtopic"] + ].to_markdown(index=False) + print( + f"Found {len(zero_shot_topics_df)} candidate topics to consider during deduplication" + ) + except Exception as e: + print(f"Error processing candidate topics: {e}") + candidate_topics_table = "" + + # Prepare topics table for LLM analysis + topics_table = topic_summary_df[ + ["General topic", "Subtopic", "Sentiment", "Number of responses"] + ].to_markdown(index=False) + + # Format the prompt with candidate topics if available + if candidate_topics_table: + formatted_prompt = llm_deduplication_prompt_with_candidates.format( + topics_table=topics_table, candidate_topics_table=candidate_topics_table + ) + else: + formatted_prompt = llm_deduplication_prompt.format(topics_table=topics_table) + + # Initialise conversation history + conversation_history = list() + whole_conversation = list() + whole_conversation_metadata = list() + + # Set up model clients based on model source + if "Gemini" in model_source: + client, config = construct_gemini_generative_model( + in_api_key, + temperature, + model_choice, + llm_deduplication_system_prompt, + max_tokens, + LLM_SEED, + ) + bedrock_runtime = None + elif "AWS" in model_source: + if not bedrock_runtime: + bedrock_runtime = boto3.client("bedrock-runtime") + client = None + config = None + elif "Azure/OpenAI" in model_source: + client, config = construct_azure_client(in_api_key, azure_endpoint) + bedrock_runtime = None + elif "Local" in model_source: + client = None + config = None + bedrock_runtime = None + elif "inference-server" in model_source: + client = None + config = None + bedrock_runtime = None + # api_url is already passed to call_llm_with_markdown_table_checks + if api_url is None: + raise ValueError( + "api_url is required when model_source is 'inference-server'" + ) + else: + raise ValueError(f"Unsupported model source: {model_source}") + + # Call LLM to get deduplication suggestions + print("Calling LLM for topic deduplication analysis...") + + # Use the existing call_llm_with_markdown_table_checks function + ( + responses, + conversation_history, + whole_conversation, + whole_conversation_metadata, + response_text, + ) = call_llm_with_markdown_table_checks( + batch_prompts=[formatted_prompt], + system_prompt=llm_deduplication_system_prompt, + conversation_history=conversation_history, + whole_conversation=whole_conversation, + whole_conversation_metadata=whole_conversation_metadata, + client=client, + client_config=config, + model_choice=model_choice, + temperature=temperature, + reported_batch_no=1, + local_model=local_model, + tokenizer=tokenizer, + 
bedrock_runtime=bedrock_runtime, + model_source=model_source, + MAX_OUTPUT_VALIDATION_ATTEMPTS=3, + assistant_prefill="", + master=False, + CHOSEN_LOCAL_MODEL_TYPE=CHOSEN_LOCAL_MODEL_TYPE, + random_seed=LLM_SEED, + api_url=api_url, + ) + + # Generate debug files if enabled + if output_debug_files == "True": + try: + # Create batch file path details for debug files + batch_file_path_details = ( + get_file_name_no_ext(reference_table_file_name) + "_llm_dedup" + ) + model_choice_clean_short = ( + model_choice.replace("/", "_").replace(":", "_").replace(".", "_") + ) + + # Create full prompt for debug output + full_prompt = llm_deduplication_system_prompt + "\n" + formatted_prompt + + # Write debug files + ( + current_prompt_content_logged, + current_summary_content_logged, + current_conversation_content_logged, + current_metadata_content_logged, + ) = process_debug_output_iteration( + OUTPUT_DEBUG_FILES, + output_folder, + batch_file_path_details, + model_choice_clean_short, + full_prompt, + response_text, + whole_conversation, + whole_conversation_metadata, + log_output_files, + task_type="llm_deduplication", + ) + + print("Debug files written for LLM deduplication analysis") + + except Exception as e: + print(f"Error writing debug files for LLM deduplication: {e}") + + # Parse the LLM response to extract merge suggestions + merge_suggestions_df = ( + pd.DataFrame() + ) # Initialize empty DataFrame for analysis results + num_merges_applied = 0 + + try: + # Extract the markdown table from the response + table_match = re.search( + r"\|.*\|.*\n\|.*\|.*\n(\|.*\|.*\n)*", response_text, re.MULTILINE + ) + if table_match: + table_text = table_match.group(0) + + # Convert markdown table to DataFrame + from io import StringIO + + merge_suggestions_df = pd.read_csv( + StringIO(table_text), sep="|", skipinitialspace=True + ) + + # Clean up the DataFrame + merge_suggestions_df = merge_suggestions_df.dropna( + axis=1, how="all" + ) # Remove empty columns + merge_suggestions_df.columns = merge_suggestions_df.columns.str.strip() + + # Remove rows where all values are NaN + merge_suggestions_df = merge_suggestions_df.dropna(how="all") + + if not merge_suggestions_df.empty: + print( + f"LLM identified {len(merge_suggestions_df)} potential topic merges" + ) + + # Apply the merges to the reference_df + for _, row in merge_suggestions_df.iterrows(): + original_general = row.get("Original General topic", "").strip() + original_subtopic = row.get("Original Subtopic", "").strip() + original_sentiment = row.get("Original Sentiment", "").strip() + merged_general = row.get("Merged General topic", "").strip() + merged_subtopic = row.get("Merged Subtopic", "").strip() + merged_sentiment = row.get("Merged Sentiment", "").strip() + + if all( + [ + original_general, + original_subtopic, + original_sentiment, + merged_general, + merged_subtopic, + merged_sentiment, + ] + ): + + # Find matching rows in reference_df + mask = ( + (reference_df["General topic"] == original_general) + & (reference_df["Subtopic"] == original_subtopic) + & (reference_df["Sentiment"] == original_sentiment) + ) + + if mask.any(): + # Update the matching rows + reference_df.loc[mask, "General topic"] = merged_general + reference_df.loc[mask, "Subtopic"] = merged_subtopic + reference_df.loc[mask, "Sentiment"] = merged_sentiment + num_merges_applied += 1 + print( + f"Merged: {original_general} | {original_subtopic} | {original_sentiment} -> {merged_general} | {merged_subtopic} | {merged_sentiment}" + ) + else: + print("No merge suggestions found in 
LLM response") + else: + print("No markdown table found in LLM response") + + except Exception as e: + print(f"Error parsing LLM response: {e}") + print("Continuing with original data...") + + # Update reference summary column with all summaries + reference_df["Summary"] = reference_df.groupby( + ["Response References", "General topic", "Subtopic", "Sentiment"] + )["Summary"].transform("
".join) + + # Check that we have not inadvertently removed some data during the process + end_unique_references = len(reference_df["Response References"].unique()) + + if initial_unique_references != end_unique_references: + raise Exception( + f"Number of unique references changed during processing: Initial={initial_unique_references}, Final={end_unique_references}" + ) + + # Drop duplicates in the reference table + reference_df.drop_duplicates( + ["Response References", "General topic", "Subtopic", "Sentiment"], inplace=True + ) + + # Remake topic_summary_df based on new reference_df + topic_summary_df = create_topic_summary_df_from_reference_table(reference_df) + + # Merge the topic numbers back to the original dataframe + reference_df = reference_df.merge( + topic_summary_df[ + ["General topic", "Subtopic", "Sentiment", "Group", "Topic number"] + ], + on=["General topic", "Subtopic", "Sentiment", "Group"], + how="left", + ) + + # Create pivot table if file data is available + if not file_data.empty: + basic_response_data = get_basic_response_data(file_data, chosen_cols) + reference_df_pivot = convert_reference_table_to_pivot_table( + reference_df, basic_response_data + ) + + reference_pivot_file_path = ( + output_folder + + get_file_name_no_ext(reference_table_file_name) + + "_pivot_dedup.csv" + ) + reference_df_pivot.to_csv( + reference_pivot_file_path, index=None, encoding="utf-8-sig" + ) + log_output_files.append(reference_pivot_file_path) + + # Save analysis results CSV if merge suggestions were found + if not merge_suggestions_df.empty: + analysis_results_file_path = ( + output_folder + + get_file_name_no_ext(reference_table_file_name) + + "_dedup_llm_analysis_results.csv" + ) + merge_suggestions_df.to_csv( + analysis_results_file_path, index=None, encoding="utf-8-sig" + ) + log_output_files.append(analysis_results_file_path) + print(f"Analysis results saved to: {analysis_results_file_path}") + + # Save output files + reference_file_out_path = ( + output_folder + get_file_name_no_ext(reference_table_file_name) + "_dedup.csv" + ) + unique_topics_file_out_path = ( + output_folder + + get_file_name_no_ext(unique_topics_table_file_name) + + "_dedup.csv" + ) + reference_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + reference_file_out_path, index=None, encoding="utf-8-sig" + ) + topic_summary_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + unique_topics_file_out_path, index=None, encoding="utf-8-sig" + ) + + output_files.append(reference_file_out_path) + output_files.append(unique_topics_file_out_path) + + # Outputs for markdown table output + topic_summary_df_revised_display = topic_summary_df.apply( + lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length)) + ) + deduplicated_unique_table_markdown = topic_summary_df_revised_display.to_markdown( + index=False + ) + + # Calculate token usage and timing information for logging + total_input_tokens = 0 + total_output_tokens = 0 + number_of_calls = 1 # Single LLM call for deduplication + + # Extract token usage from conversation metadata + if whole_conversation_metadata: + for metadata in whole_conversation_metadata: + if "input_tokens:" in metadata and "output_tokens:" in metadata: + try: + input_tokens = int( + metadata.split("input_tokens: ")[1].split(" ")[0] + ) + output_tokens = int( + metadata.split("output_tokens: ")[1].split(" ")[0] + ) + total_input_tokens += input_tokens + total_output_tokens += output_tokens + except (ValueError, IndexError): + pass + + # Calculate estimated time 
taken (rough estimate based on token usage) + estimated_time_taken = ( + total_input_tokens + total_output_tokens + ) / 1000 # Rough estimate in seconds + + return ( + reference_df, + topic_summary_df, + output_files, + log_output_files, + deduplicated_unique_table_markdown, + total_input_tokens, + total_output_tokens, + number_of_calls, + estimated_time_taken, + ) # , num_merges_applied + + +def sample_reference_table_summaries( + reference_df: pd.DataFrame, + random_seed: int, + no_of_sampled_summaries: int = default_number_of_sampled_summaries, + sample_reference_table_checkbox: bool = False, +): + """ + Sample x number of summaries from which to produce summaries, so that the input token length is not too long. + """ + + if sample_reference_table_checkbox: + + all_summaries = pd.DataFrame( + columns=[ + "General topic", + "Subtopic", + "Sentiment", + "Group", + "Response References", + "Summary", + ] + ) + + if "Group" not in reference_df.columns: + reference_df["Group"] = "All" + + reference_df_grouped = reference_df.groupby( + ["General topic", "Subtopic", "Sentiment", "Group"] + ) + + if "Revised summary" in reference_df.columns: + out_message = "Summary has already been created for this file" + print(out_message) + raise Exception(out_message) + + for group_keys, reference_df_group in reference_df_grouped: + if len(reference_df_group["General topic"]) > 1: + + filtered_reference_df = reference_df_group.reset_index() + + filtered_reference_df_unique = filtered_reference_df.drop_duplicates( + ["General topic", "Subtopic", "Sentiment", "Summary"] + ) + + # Sample n of the unique topic summaries PER GROUP. To limit the length of the text going into the summarisation tool + # This ensures each group gets up to no_of_sampled_summaries summaries, not the total across all groups + filtered_reference_df_unique_sampled = ( + filtered_reference_df_unique.sample( + min(no_of_sampled_summaries, len(filtered_reference_df_unique)), + random_state=random_seed, + ) + ) + + all_summaries = pd.concat( + [all_summaries, filtered_reference_df_unique_sampled] + ) + + # If no responses/topics qualify, just go ahead with the original reference dataframe + if all_summaries.empty: + sampled_reference_table_df = reference_df + # Filter by sentiment only (Response References is a string in original df, not a count) + sampled_reference_table_df = sampled_reference_table_df.loc[ + sampled_reference_table_df["Sentiment"] != "Not Mentioned" + ] + else: + # FIXED: Preserve Group column in aggregation to maintain group-specific summaries + sampled_reference_table_df = ( + all_summaries.groupby( + ["General topic", "Subtopic", "Sentiment", "Group"] + ) + .agg( + { + "Response References": "size", # Count the number of references + "Summary": lambda x: "\n".join( + [s.split(": ", 1)[1] for s in x if ": " in s] + ), # Join substrings after ': ' + } + ) + .reset_index() + ) + # Filter by sentiment and count (Response References is now a numeric count after aggregation) + sampled_reference_table_df = sampled_reference_table_df.loc[ + (sampled_reference_table_df["Sentiment"] != "Not Mentioned") + & (sampled_reference_table_df["Response References"] > 1) + ] + else: + sampled_reference_table_df = reference_df + + summarised_references_markdown = sampled_reference_table_df.to_markdown(index=False) + + return sampled_reference_table_df, summarised_references_markdown + + +def count_tokens_in_text(text: str, tokenizer=None, model_source: str = "Local") -> int: + """ + Count the number of tokens in the given text. 
+ + Args: + text (str): The text to count tokens for + tokenizer (object, optional): Tokenizer object for local models. Defaults to None. + model_source (str): Source of the model to determine tokenization method. Defaults to "Local". + + Returns: + int: Number of tokens in the text + """ + if not text: + return 0 + + try: + if model_source == "Local" and tokenizer and len(tokenizer) > 0: + # Use local tokenizer if available + tokens = tokenizer[0].encode(text, add_special_tokens=False) + return len(tokens) + else: + # Fallback: rough estimation using word count (approximately 1.3 tokens per word) + word_count = len(text.split()) + return int(word_count * 1.3) + except Exception as e: + print(f"Error counting tokens: {e}. Using word count estimation.") + # Fallback: rough estimation using word count + word_count = len(text.split()) + return int(word_count * 1.3) + + +def summarise_output_topics_query( + model_choice: str, + in_api_key: str, + temperature: float, + formatted_summary_prompt: str, + summarise_topic_descriptions_system_prompt: str, + model_source: str, + bedrock_runtime: boto3.Session.client, + local_model=list(), + tokenizer=list(), + assistant_model=list(), + azure_endpoint: str = "", + api_url: str = None, +): + """ + Query an LLM to generate a summary of topics based on the provided prompts. + + Args: + model_choice (str): The name/type of model to use for generation + in_api_key (str): API key for accessing the model service + temperature (float): Temperature parameter for controlling randomness in generation + formatted_summary_prompt (str): The formatted prompt containing topics to summarize + summarise_topic_descriptions_system_prompt (str): System prompt providing context and instructions + model_source (str): Source of the model (e.g. "AWS", "Gemini", "Local") + bedrock_runtime (boto3.Session.client): AWS Bedrock runtime client for AWS models + local_model (object, optional): Local model object if using local inference. Defaults to empty list. + tokenizer (object, optional): Tokenizer object if using local inference. Defaults to empty list. + Returns: + tuple: Contains: + - response_text (str): The generated summary text + - conversation_history (list): History of the conversation with the model + - whole_conversation_metadata (list): Metadata about the conversation + """ + conversation_history = list() + whole_conversation_metadata = list() + client = list() + client_config = {} + + # Combine system prompt and user prompt for token counting + full_input_text = ( + summarise_topic_descriptions_system_prompt + "\n" + formatted_summary_prompt[0] + if isinstance(formatted_summary_prompt, list) + else summarise_topic_descriptions_system_prompt + + "\n" + + formatted_summary_prompt + ) + + # Count tokens in the input text + input_token_count = count_tokens_in_text(full_input_text, tokenizer, model_source) + + # Check if input exceeds context length + if input_token_count > LLM_CONTEXT_LENGTH: + error_message = f"Input text exceeds LLM context length. Input tokens: {input_token_count}, Max context length: {LLM_CONTEXT_LENGTH}. Please reduce the input text size." 
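# Illustrative sketch (not part of the patch): this check relies on
# count_tokens_in_text above, which falls back to a word-count heuristic of
# roughly 1.3 tokens per word when no local tokenizer is available. The numbers
# below are made up.
#
#     text = "one two three four five"
#     approx_tokens = int(len(text.split()) * 1.3)  # 5 words -> 6 estimated tokens
#     if approx_tokens > 4096:  # hypothetical context window
#         raise ValueError("Input exceeds the model context window")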
+ print(error_message) + raise ValueError(error_message) + + print(f"Input token count: {input_token_count} (Max: {LLM_CONTEXT_LENGTH})") + + # Prepare Gemini models before query + if "Gemini" in model_source: + # print("Using Gemini model:", model_choice) + client, config = construct_gemini_generative_model( + in_api_key=in_api_key, + temperature=temperature, + model_choice=model_choice, + system_prompt=system_prompt, + max_tokens=max_tokens, + ) + elif "Azure/OpenAI" in model_source: + client, config = construct_azure_client( + in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""), + endpoint=azure_endpoint, + ) + elif "Local" in model_source: + pass + # print("Using local model: ", model_choice) + elif "AWS" in model_source: + pass + # print("Using AWS Bedrock model:", model_choice) + + whole_conversation = [summarise_topic_descriptions_system_prompt] + + # Process requests to large language model + ( + responses, + conversation_history, + whole_conversation, + whole_conversation_metadata, + response_text, + ) = process_requests( + formatted_summary_prompt, + system_prompt, + conversation_history, + whole_conversation, + whole_conversation_metadata, + client, + client_config, + model_choice, + temperature, + bedrock_runtime=bedrock_runtime, + model_source=model_source, + local_model=local_model, + tokenizer=tokenizer, + assistant_model=assistant_model, + assistant_prefill=summary_assistant_prefill, + api_url=api_url, + ) + + summarised_output = re.sub( + r"\n{2,}", "\n", response_text + ) # Replace multiple line breaks with a single line break + summarised_output = re.sub( + r"^\n{1,}", "", summarised_output + ) # Remove one or more line breaks at the start + summarised_output = re.sub( + r"\n", "
", summarised_output + ) # Replace \n with more html friendly
tags + summarised_output = summarised_output.strip() + + print("Finished summary query") + + # Ensure the system prompt is included in the conversation history + try: + if isinstance(conversation_history, list): + has_system_prompt = False + + if conversation_history: + first_entry = conversation_history[0] + if isinstance(first_entry, dict): + role_is_system = first_entry.get("role") == "system" + parts = first_entry.get("parts") + content_matches = ( + parts == summarise_topic_descriptions_system_prompt + or ( + isinstance(parts, list) + and summarise_topic_descriptions_system_prompt in parts + ) + ) + has_system_prompt = role_is_system and content_matches + elif isinstance(first_entry, str): + has_system_prompt = ( + first_entry.strip().lower().startswith("system:") + ) + + if not has_system_prompt: + conversation_history.insert( + 0, + { + "role": "system", + "parts": [summarise_topic_descriptions_system_prompt], + }, + ) + except Exception as _e: + # Non-fatal: if anything goes wrong, return the original conversation history + pass + + return ( + summarised_output, + conversation_history, + whole_conversation_metadata, + response_text, + ) + + +def process_debug_output_iteration( + output_debug_files: str, + output_folder: str, + batch_file_path_details: str, + model_choice_clean_short: str, + final_system_prompt: str, + summarised_output: str, + conversation_history: list, + metadata: list, + log_output_files: list, + task_type: str, +) -> tuple[str, str, str, str]: + """ + Writes debug files for summary generation if output_debug_files is "True", + and returns the content of the prompt, summary, conversation, and metadata for the current iteration. + + Args: + output_debug_files (str): Flag to indicate if debug files should be written. + output_folder (str): The folder where output files are saved. + batch_file_path_details (str): Details for the batch file path. + model_choice_clean_short (str): Shortened cleaned model choice. + final_system_prompt (str): The system prompt content. + summarised_output (str): The summarised output content. + conversation_history (list): The full conversation history. + metadata (list): The metadata for the conversation. + log_output_files (list): A list to append paths of written log files. This list is modified in-place. + task_type (str): The type of task being performed. + Returns: + tuple[str, str, str, str]: A tuple containing the content of the prompt, + summarised output, conversation history (as string), + and metadata (as string) for the current iteration. 
+ """ + current_prompt_content = final_system_prompt + current_summary_content = summarised_output + + if isinstance(conversation_history, list): + + # Handle both list of strings and list of dicts + if conversation_history and isinstance(conversation_history[0], dict): + # Convert list of dicts to list of strings + conversation_strings = list() + for entry in conversation_history: + if "role" in entry and "parts" in entry: + role = entry["role"].capitalize() + message = ( + " ".join(entry["parts"]) + if isinstance(entry["parts"], list) + else str(entry["parts"]) + ) + conversation_strings.append(f"{role}: {message}") + else: + # Fallback for unexpected dict format + conversation_strings.append(str(entry)) + current_conversation_content = "\n".join(conversation_strings) + else: + # Handle list of strings + current_conversation_content = "\n".join(conversation_history) + else: + current_conversation_content = str(conversation_history) + current_metadata_content = str(metadata) + current_task_type = task_type + + if output_debug_files == "True": + try: + formatted_prompt_output_path = ( + output_folder + + batch_file_path_details + + "_full_prompt_" + + model_choice_clean_short + + "_" + + current_task_type + + ".txt" + ) + final_table_output_path = ( + output_folder + + batch_file_path_details + + "_full_response_" + + model_choice_clean_short + + "_" + + current_task_type + + ".txt" + ) + whole_conversation_path = ( + output_folder + + batch_file_path_details + + "_full_conversation_" + + model_choice_clean_short + + "_" + + current_task_type + + ".txt" + ) + whole_conversation_path_meta = ( + output_folder + + batch_file_path_details + + "_metadata_" + + model_choice_clean_short + + "_" + + current_task_type + + ".txt" + ) + + with open( + formatted_prompt_output_path, + "w", + encoding="utf-8-sig", + errors="replace", + ) as f: + f.write(current_prompt_content) + with open( + final_table_output_path, "w", encoding="utf-8-sig", errors="replace" + ) as f: + f.write(current_summary_content) + with open( + whole_conversation_path, "w", encoding="utf-8-sig", errors="replace" + ) as f: + f.write(current_conversation_content) + with open( + whole_conversation_path_meta, + "w", + encoding="utf-8-sig", + errors="replace", + ) as f: + f.write(current_metadata_content) + + log_output_files.append(formatted_prompt_output_path) + log_output_files.append(final_table_output_path) + log_output_files.append(whole_conversation_path) + log_output_files.append(whole_conversation_path_meta) + except Exception as e: + print(f"Error in writing debug files for summary: {e}") + + # Return the content of the objects for the current iteration. + # The caller can then append these to separate lists if accumulation is desired. 
+ return ( + current_prompt_content, + current_summary_content, + current_conversation_content, + current_metadata_content, + ) + + +@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME) +def summarise_output_topics( + sampled_reference_table_df: pd.DataFrame, + topic_summary_df: pd.DataFrame, + reference_table_df: pd.DataFrame, + model_choice: str, + in_api_key: str, + temperature: float, + reference_data_file_name: str, + summarised_outputs: list = list(), + latest_summary_completed: int = 0, + out_metadata_str: str = "", + in_data_files: List[str] = list(), + in_excel_sheets: str = "", + chosen_cols: List[str] = list(), + log_output_files: list[str] = list(), + summarise_format_radio: str = "Return a summary up to two paragraphs long that includes as much detail as possible from the original text", + output_folder: str = OUTPUT_FOLDER, + context_textbox: str = "", + aws_access_key_textbox: str = "", + aws_secret_key_textbox: str = "", + aws_region_textbox: str = "", + model_name_map: dict = model_name_map, + hf_api_key_textbox: str = "", + azure_endpoint_textbox: str = "", + existing_logged_content: list = list(), + additional_summary_instructions_provided: str = "", + output_debug_files: str = "False", + group_value: str = "All", + reasoning_suffix: str = reasoning_suffix, + local_model: object = None, + tokenizer: object = None, + assistant_model: object = None, + summarise_topic_descriptions_prompt: str = summarise_topic_descriptions_prompt, + summarise_topic_descriptions_system_prompt: str = summarise_topic_descriptions_system_prompt, + do_summaries: str = "Yes", + api_url: str = None, + progress=gr.Progress(track_tqdm=True), +): + """ + Create improved summaries of topics by consolidating raw batch-level summaries from the initial model run. Works on a single group of summaries at a time (called from wrapper function summarise_output_topics_by_group). + + Args: + sampled_reference_table_df (pd.DataFrame): DataFrame containing sampled reference data with summaries + topic_summary_df (pd.DataFrame): DataFrame containing topic summary information + reference_table_df (pd.DataFrame): DataFrame mapping response references to topics + model_choice (str): Name of the LLM model to use + in_api_key (str): API key for model access + temperature (float): Temperature parameter for model generation + reference_data_file_name (str): Name of the reference data file + summarised_outputs (list, optional): List to store generated summaries. Defaults to empty list. + latest_summary_completed (int, optional): Index of last completed summary. Defaults to 0. + out_metadata_str (str, optional): String for metadata output. Defaults to empty string. + in_data_files (List[str], optional): List of input data file paths. Defaults to empty list. + in_excel_sheets (str, optional): Excel sheet names if using Excel files. Defaults to empty string. + chosen_cols (List[str], optional): List of columns selected for analysis. Defaults to empty list. + log_output_files (list[str], optional): List of log file paths. Defaults to empty list. + summarise_format_radio (str, optional): Format instructions for summary generation. Defaults to two paragraph format. + output_folder (str, optional): Folder path for outputs. Defaults to OUTPUT_FOLDER. + context_textbox (str, optional): Additional context for summarization. Defaults to empty string. + aws_access_key_textbox (str, optional): AWS access key. Defaults to empty string. + aws_secret_key_textbox (str, optional): AWS secret key. Defaults to empty string. 
+ model_name_map (dict, optional): Dictionary mapping model choices to their properties. Defaults to model_name_map. + hf_api_key_textbox (str, optional): Hugging Face API key. Defaults to empty string. + azure_endpoint_textbox (str, optional): Azure endpoint. Defaults to empty string. + additional_summary_instructions_provided (str, optional): Additional summary instructions provided by the user. Defaults to empty string. + existing_logged_content (list, optional): List of existing logged content. Defaults to empty list. + output_debug_files (str, optional): Flag to indicate if debug files should be written. Defaults to "False". + group_value (str, optional): Value of the group to summarise. Defaults to "All". + reasoning_suffix (str, optional): Suffix for reasoning. Defaults to reasoning_suffix. + local_model (object, optional): Local model object if using local inference. Defaults to None. + tokenizer (object, optional): Tokenizer object if using local inference. Defaults to None. + assistant_model (object, optional): Assistant model object if using local inference. Defaults to None. + summarise_topic_descriptions_prompt (str, optional): Prompt template for topic summarization. + summarise_topic_descriptions_system_prompt (str, optional): System prompt for topic summarization. + do_summaries (str, optional): Flag to control summary generation. Defaults to "Yes". + progress (gr.Progress, optional): Gradio progress tracker. Defaults to track_tqdm=True. + + Returns: + Multiple outputs including summarized content, metadata, and file paths + """ + out_metadata = list() + summarised_output_markdown = "" + output_files = list() + acc_input_tokens = 0 + acc_output_tokens = 0 + acc_number_of_calls = 0 + time_taken = 0 + out_metadata_str = ( + "" # Output metadata is currently replaced on starting a summarisation task + ) + out_message = list() + task_type = "Topic summarisation" + topic_summary_df_revised = pd.DataFrame() + + all_prompts_content = list() + all_summaries_content = list() + all_metadata_content = list() + all_groups_content = list() + all_batches_content = list() + all_model_choice_content = list() + all_validated_content = list() + all_task_type_content = list() + all_logged_content = list() + all_file_names_content = list() + + tic = time.perf_counter() + + # Ensure custom model_choice is registered in model_name_map + ensure_model_in_map(model_choice, model_name_map) + + model_choice_clean = clean_column_name( + model_name_map[model_choice]["short_name"], + max_length=20, + front_characters=False, + ) + + if context_textbox and "The context of this analysis is" not in context_textbox: + context_textbox = "The context of this analysis is '" + context_textbox + "'." + + if log_output_files is None: + log_output_files = list() + + # Check for data for summarisations + if not topic_summary_df.empty and not reference_table_df.empty: + print("Unique table and reference table data found.") + else: + out_message = "Please upload a unique topic table and reference table file to continue with summarisation." 
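# Illustrative sketch (not part of the patch): the sampled_reference_table_df
# consumed here is typically produced by sample_reference_table_summaries above,
# which keeps at most a fixed number of unique batch-level summaries per topic
# and group. The names n_per_topic and seed below are hypothetical.
#
#     sampled = (
#         reference_df.drop_duplicates(["General topic", "Subtopic", "Sentiment", "Summary"])
#         .groupby(["General topic", "Subtopic", "Sentiment", "Group"], group_keys=False)
#         .apply(lambda g: g.sample(min(n_per_topic, len(g)), random_state=seed))
#     )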
+ print(out_message) + raise Exception(out_message) + + if "Revised summary" in reference_table_df.columns: + out_message = "Summary has already been created for this file" + print(out_message) + raise Exception(out_message) + + # Load in data file and chosen columns if exists to create pivot table later + file_data = pd.DataFrame() + if in_data_files and chosen_cols: + file_data, data_file_names_textbox, total_number_of_batches = load_in_data_file( + in_data_files, chosen_cols, 1, in_excel_sheets=in_excel_sheets + ) + else: + out_message = "No file data found, pivot table output will not be created." + print(out_message) + # Use sys.stdout.write to avoid issues with progress bars + # sys.stdout.write(out_message + "\n") + # sys.stdout.flush() + # Note: file_data will remain empty, pivot tables will not be created + + reference_table_df = reference_table_df.rename( + columns={"General Topic": "General topic"}, errors="ignore" + ) + topic_summary_df = topic_summary_df.rename( + columns={"General Topic": "General topic"}, errors="ignore" + ) + if "Group" not in reference_table_df.columns: + reference_table_df["Group"] = "All" + if "Group" not in topic_summary_df.columns: + topic_summary_df["Group"] = "All" + if "Group" not in sampled_reference_table_df.columns: + sampled_reference_table_df["Group"] = "All" + + # Use the Summary column if it exists, otherwise use the Revised summary column + if "Summary" in sampled_reference_table_df.columns: + all_summaries = sampled_reference_table_df["Summary"].tolist() + else: + all_summaries = sampled_reference_table_df["Revised summary"].tolist() + + all_groups = sampled_reference_table_df["Group"].tolist() + + if not group_value: + group_value = str(all_groups[0]) + else: + group_value = str(group_value) + + length_all_summaries = len(all_summaries) + + model_source = model_name_map[model_choice]["source"] + + if (model_source == "Local") & (RUN_LOCAL_MODEL == "1") & (not local_model): + progress(0.1, f"Using global model: {CHOSEN_LOCAL_MODEL_TYPE}") + local_model = get_model() + tokenizer = get_tokenizer() + assistant_model = get_assistant_model() + + ( + "Revising topic-level summaries. " + + str(latest_summary_completed) + + " summaries completed so far." + ) + summary_loop = progress.tqdm( + range(latest_summary_completed, length_all_summaries), + desc="Revising topic-level summaries", + unit="summaries", + ) + + if do_summaries == "Yes": + + bedrock_runtime = connect_to_bedrock_runtime( + model_name_map, + model_choice, + aws_access_key_textbox, + aws_secret_key_textbox, + aws_region_textbox, + ) + + create_batch_file_path_details(reference_data_file_name) + model_choice_clean_short = clean_column_name( + model_choice_clean, max_length=20, front_characters=False + ) + file_name_clean = f"{clean_column_name(reference_data_file_name, max_length=15)}_{clean_column_name(str(group_value), max_length=15).replace(' ','_')}" + # file_name_clean = clean_column_name(reference_data_file_name, max_length=20, front_characters=True) + in_column_cleaned = clean_column_name(chosen_cols, max_length=20) + + combined_summary_instructions = ( + summarise_format_radio + ". 
" + additional_summary_instructions_provided + ) + + for summary_no in summary_loop: + print("Current summary number is:", summary_no) + + batch_file_path_details = f"{file_name_clean}_batch_{latest_summary_completed + 1}_size_1_col_{in_column_cleaned}" + + summary_text = all_summaries[summary_no] + formatted_summary_prompt = [ + summarise_topic_descriptions_prompt.format( + summaries=summary_text, summary_format=combined_summary_instructions + ) + ] + + formatted_summarise_topic_descriptions_system_prompt = ( + summarise_topic_descriptions_system_prompt.format( + column_name=chosen_cols, consultation_context=context_textbox + ) + ) + + if "Local" in model_source and reasoning_suffix: + formatted_summarise_topic_descriptions_system_prompt = ( + formatted_summarise_topic_descriptions_system_prompt + + "\n" + + reasoning_suffix + ) + + try: + response, conversation_history, metadata, response_text = ( + summarise_output_topics_query( + model_choice, + in_api_key, + temperature, + formatted_summary_prompt, + formatted_summarise_topic_descriptions_system_prompt, + model_source, + bedrock_runtime, + local_model, + tokenizer=tokenizer, + assistant_model=assistant_model, + azure_endpoint=azure_endpoint_textbox, + api_url=api_url, + ) + ) + summarised_output = response_text + except Exception as e: + print("Creating summary failed:", e) + summarised_output = "" + + summarised_outputs.append(summarised_output) + out_metadata.extend(metadata) + out_metadata_str = ". ".join(out_metadata) + + # Call the new function to process and log debug outputs for the current iteration. + # The returned values are the contents of the prompt, summary, conversation, and metadata + + full_prompt = ( + formatted_summarise_topic_descriptions_system_prompt + + "\n" + + formatted_summary_prompt[0] + ) + + # Coerce toggle to string expected by debug writer (accepts True/False or "True"/"False") + output_debug_files_str = ( + "True" + if ( + (isinstance(output_debug_files, bool) and output_debug_files) + or (str(output_debug_files) == "True") + ) + else "False" + ) + + ( + current_prompt_content_logged, + current_summary_content_logged, + current_conversation_content_logged, + current_metadata_content_logged, + ) = process_debug_output_iteration( + output_debug_files_str, + output_folder, + batch_file_path_details, + model_choice_clean_short, + full_prompt, + summarised_output, + conversation_history, + metadata, + log_output_files, + task_type=task_type, + ) + + all_prompts_content.append(current_prompt_content_logged) + all_summaries_content.append(current_summary_content_logged) + # all_conversation_content.append(current_conversation_content_logged) + all_metadata_content.append(current_metadata_content_logged) + all_groups_content.append(all_groups[summary_no]) + all_batches_content.append(f"{summary_no}:") + all_model_choice_content.append(model_choice_clean_short) + all_validated_content.append("No") + all_task_type_content.append(task_type) + all_file_names_content.append(reference_data_file_name) + latest_summary_completed += 1 + + toc = time.perf_counter() + time_taken = toc - tic + + if time_taken > max_time_for_loop: + print( + "Time taken for loop is greater than maximum time allowed. Exiting and restarting loop" + ) + summary_loop.close() + tqdm._instances.clear() + break + + # If all summaries completed, make final outputs + if latest_summary_completed >= length_all_summaries: + print("All summaries completed. 
Creating outputs.") + + sampled_reference_table_df["Revised summary"] = summarised_outputs + + join_cols = ["General topic", "Subtopic", "Sentiment"] + join_plus_summary_cols = [ + "General topic", + "Subtopic", + "Sentiment", + "Revised summary", + ] + + summarised_references_j = sampled_reference_table_df[ + join_plus_summary_cols + ].drop_duplicates(join_plus_summary_cols) + + topic_summary_df_revised = topic_summary_df.merge( + summarised_references_j, on=join_cols, how="left" + ) + + # If no new summary is available, keep the original + # But prefer the version without "Rows X to Y" prefix to avoid duplication + def clean_summary_text(text): + if pd.isna(text): + return text + # Remove "Rows X to Y:" prefix if present (both at start and after
<br> tags) + + # First remove from the beginning + cleaned = re.sub(r"^Rows\s+\d+\s+to\s+\d+:\s*", "", str(text)) + # Then remove from after <br> tags + cleaned = re.sub(r"<br>\s*Rows\s+\d+\s+to\s+\d+:\s*", "<br>
", cleaned) + return cleaned + + topic_summary_df_revised["Revised summary"] = topic_summary_df_revised[ + "Revised summary" + ].combine_first(topic_summary_df_revised["Summary"]) + # Clean the revised summary to remove "Rows X to Y" prefixes + topic_summary_df_revised["Revised summary"] = topic_summary_df_revised[ + "Revised summary" + ].apply(clean_summary_text) + topic_summary_df_revised = topic_summary_df_revised[ + [ + "General topic", + "Subtopic", + "Sentiment", + "Group", + "Number of responses", + "Revised summary", + ] + ] + + # Note: "Rows X to Y:" prefixes are now cleaned by the clean_summary_text function above + topic_summary_df_revised["Topic number"] = range( + 1, len(topic_summary_df_revised) + 1 + ) + + # If no new summary is available, keep the original. Also join on topic number to ensure consistent topic number assignment + reference_table_df_revised = reference_table_df.copy() + reference_table_df_revised = reference_table_df_revised.drop( + "Topic number", axis=1, errors="ignore" + ) + + # Ensure reference table has Topic number column + if ( + "Topic number" not in reference_table_df_revised.columns + or "Revised summary" not in reference_table_df_revised.columns + ): + if ( + "Topic number" in topic_summary_df_revised.columns + and "Revised summary" in topic_summary_df_revised.columns + ): + reference_table_df_revised = reference_table_df_revised.merge( + topic_summary_df_revised[ + [ + "General topic", + "Subtopic", + "Sentiment", + "Group", + "Topic number", + "Revised summary", + ] + ], + on=["General topic", "Subtopic", "Sentiment", "Group"], + how="left", + ) + + reference_table_df_revised["Revised summary"] = reference_table_df_revised[ + "Revised summary" + ].combine_first(reference_table_df_revised["Summary"]) + # Clean the revised summary to remove "Rows X to Y" prefixes + reference_table_df_revised["Revised summary"] = reference_table_df_revised[ + "Revised summary" + ].apply(clean_summary_text) + reference_table_df_revised = reference_table_df_revised.drop( + "Summary", axis=1, errors="ignore" + ) + + # Remove topics that are tagged as 'Not Mentioned' + topic_summary_df_revised = topic_summary_df_revised.loc[ + topic_summary_df_revised["Sentiment"] != "Not Mentioned", : + ] + reference_table_df_revised = reference_table_df_revised.loc[ + reference_table_df_revised["Sentiment"] != "Not Mentioned", : + ] + + # Combine the logged content into a list of dictionaries + all_logged_content = [ + { + "prompt": prompt, + "response": summary, + "metadata": metadata, + "batch": batch, + "model_choice": model_choice, + "validated": validated, + "group": group, + "task_type": task_type, + "file_name": file_name, + } + for prompt, summary, metadata, batch, model_choice, validated, group, task_type, file_name in zip( + all_prompts_content, + all_summaries_content, + all_metadata_content, + all_batches_content, + all_model_choice_content, + all_validated_content, + all_groups_content, + all_task_type_content, + all_file_names_content, + ) + ] + + if isinstance(existing_logged_content, pd.DataFrame): + existing_logged_content = existing_logged_content.to_dict(orient="records") + + out_logged_content = existing_logged_content + all_logged_content + + ### Save output files + + if output_debug_files == "True": + + if not file_data.empty: + basic_response_data = get_basic_response_data(file_data, chosen_cols) + reference_table_df_revised_pivot = ( + convert_reference_table_to_pivot_table( + reference_table_df_revised, basic_response_data + ) + ) + + ### Save pivot file to 
log area + reference_table_df_revised_pivot_path = ( + output_folder + + file_name_clean + + "_summ_reference_table_pivot_" + + model_choice_clean + + ".csv" + ) + reference_table_df_revised_pivot.drop( + ["1", "2", "3"], axis=1, errors="ignore" + ).to_csv( + reference_table_df_revised_pivot_path, + index=None, + encoding="utf-8-sig", + ) + log_output_files.append(reference_table_df_revised_pivot_path) + + # Save to file + topic_summary_df_revised_path = ( + output_folder + + file_name_clean + + "_summ_unique_topics_table_" + + model_choice_clean + + ".csv" + ) + topic_summary_df_revised.drop( + ["1", "2", "3"], axis=1, errors="ignore" + ).to_csv(topic_summary_df_revised_path, index=None, encoding="utf-8-sig") + + reference_table_df_revised_path = ( + output_folder + + file_name_clean + + "_summ_reference_table_" + + model_choice_clean + + ".csv" + ) + reference_table_df_revised.drop( + ["1", "2", "3"], axis=1, errors="ignore" + ).to_csv(reference_table_df_revised_path, index=None, encoding="utf-8-sig") + + log_output_files.extend( + [reference_table_df_revised_path, topic_summary_df_revised_path] + ) + + ### + topic_summary_df_revised_display = topic_summary_df_revised.apply( + lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length)) + ) + summarised_output_markdown = topic_summary_df_revised_display.to_markdown( + index=False + ) + + # Ensure same file name not returned twice + output_files = list(set(output_files)) + log_output_files = list(set(log_output_files)) + + acc_input_tokens, acc_output_tokens, acc_number_of_calls = ( + calculate_tokens_from_metadata( + out_metadata_str, model_choice, model_name_map + ) + ) + + toc = time.perf_counter() + time_taken = toc - tic + + if isinstance(out_message, list): + out_message = "\n".join(out_message) + else: + out_message = out_message + + out_message = ( + out_message + + f"\nTopic summarisation finished processing. 
Total time: {round(float(time_taken), 1)}s" + ) + print(out_message) + + return ( + sampled_reference_table_df, + topic_summary_df_revised, + reference_table_df_revised, + output_files, + summarised_outputs, + latest_summary_completed, + out_metadata_str, + summarised_output_markdown, + log_output_files, + output_files, + acc_input_tokens, + acc_output_tokens, + acc_number_of_calls, + time_taken, + out_message, + out_logged_content, + ) + + +@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME) +def wrapper_summarise_output_topics_per_group( + grouping_col: str, + sampled_reference_table_df: pd.DataFrame, + topic_summary_df: pd.DataFrame, + reference_table_df: pd.DataFrame, + model_choice: str, + in_api_key: str, + temperature: float, + reference_data_file_name: str, + summarised_outputs: list = list(), + latest_summary_completed: int = 0, + out_metadata_str: str = "", + in_data_files: List[str] = list(), + in_excel_sheets: str = "", + chosen_cols: List[str] = list(), + log_output_files: list[str] = list(), + summarise_format_radio: str = "Return a summary up to two paragraphs long that includes as much detail as possible from the original text", + output_folder: str = OUTPUT_FOLDER, + context_textbox: str = "", + aws_access_key_textbox: str = "", + aws_secret_key_textbox: str = "", + aws_region_textbox: str = "", + model_name_map: dict = model_name_map, + hf_api_key_textbox: str = "", + azure_endpoint_textbox: str = "", + existing_logged_content: list = list(), + sample_reference_table: bool = False, + no_of_sampled_summaries: int = default_number_of_sampled_summaries, + random_seed: int = 42, + api_url: str = None, + additional_summary_instructions_provided: str = "", + output_debug_files: str = OUTPUT_DEBUG_FILES, + reasoning_suffix: str = reasoning_suffix, + local_model: object = None, + tokenizer: object = None, + assistant_model: object = None, + summarise_topic_descriptions_prompt: str = summarise_topic_descriptions_prompt, + summarise_topic_descriptions_system_prompt: str = summarise_topic_descriptions_system_prompt, + do_summaries: str = "Yes", + progress=gr.Progress(track_tqdm=True), +) -> Tuple[ + pd.DataFrame, + pd.DataFrame, + pd.DataFrame, + List[str], + List[str], + int, + str, + str, + List[str], + List[str], + int, + int, + int, + float, + str, + List[dict], +]: + """ + A wrapper function that iterates through unique values in a specified grouping column + and calls the `summarise_output_topics` function for each group of summaries. + It accumulates results from each call and returns a consolidated output. + + :param grouping_col: The name of the column to group the data by. 
+ :param sampled_reference_table_df: DataFrame containing sampled reference data with summaries + :param topic_summary_df: DataFrame containing topic summary information + :param reference_table_df: DataFrame mapping response references to topics + :param model_choice: Name of the LLM model to use + :param in_api_key: API key for model access + :param temperature: Temperature parameter for model generation + :param reference_data_file_name: Name of the reference data file + :param summarised_outputs: List to store generated summaries + :param latest_summary_completed: Index of last completed summary + :param out_metadata_str: String for metadata output + :param in_data_files: List of input data file paths + :param in_excel_sheets: Excel sheet names if using Excel files + :param chosen_cols: List of columns selected for analysis + :param log_output_files: List of log file paths + :param summarise_format_radio: Format instructions for summary generation + :param output_folder: Folder path for outputs + :param context_textbox: Additional context for summarization + :param aws_access_key_textbox: AWS access key + :param aws_secret_key_textbox: AWS secret key + :param model_name_map: Dictionary mapping model choices to their properties + :param hf_api_key_textbox: Hugging Face API key + :param azure_endpoint_textbox: Azure endpoint + :param existing_logged_content: List of existing logged content + :param additional_summary_instructions_provided: Additional summary instructions + :param output_debug_files: Flag to indicate if debug files should be written + :param reasoning_suffix: Suffix for reasoning + :param local_model: Local model object if using local inference + :param tokenizer: Tokenizer object if using local inference + :param assistant_model: Assistant model object if using local inference + :param summarise_topic_descriptions_prompt: Prompt template for topic summarization + :param summarise_topic_descriptions_system_prompt: System prompt for topic summarization + :param do_summaries: Flag to control summary generation + :param sample_reference_table: If True, sample the reference table at the top of the function + :param no_of_sampled_summaries: Number of summaries to sample per group (default 100) + :param random_seed: Random seed for reproducible sampling (default 42) + :param progress: Gradio progress tracker + :return: A tuple containing consolidated results, mimicking the return structure of `summarise_output_topics` + """ + + acc_input_tokens = 0 + acc_output_tokens = 0 + acc_number_of_calls = 0 + out_message = list() + + # Logged content + all_groups_logged_content = existing_logged_content + + # Check if we have data to process + # Allow empty sampled_reference_table_df if sample_reference_table is True (it will be created from reference_table_df) + if ( + (sampled_reference_table_df.empty and not sample_reference_table) + or topic_summary_df.empty + or reference_table_df.empty + ): + out_message = "Please upload reference table, topic summary, and sampled reference table files to continue with summarisation." 
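+ # Required input tables are missing, so stop before any group processing starts.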
+ print(out_message) + raise Exception(out_message) + + # Ensure Group column exists + if "Group" not in sampled_reference_table_df.columns: + sampled_reference_table_df["Group"] = "All" + if "Group" not in topic_summary_df.columns: + topic_summary_df["Group"] = "All" + if "Group" not in reference_table_df.columns: + reference_table_df["Group"] = "All" + + # Sample reference table if requested + if sample_reference_table: + print( + f"Sampling reference table with {no_of_sampled_summaries} summaries per group..." + ) + sampled_reference_table_df, _ = sample_reference_table_summaries( + reference_table_df, + random_seed=random_seed, + no_of_sampled_summaries=no_of_sampled_summaries, + sample_reference_table_checkbox=sample_reference_table, + ) + print( + f"Sampling complete. {len(sampled_reference_table_df)} summaries selected." + ) + + # Get unique group values + unique_values = sampled_reference_table_df["Group"].unique() + + if len(unique_values) > MAX_GROUPS: + print( + f"Warning: More than {MAX_GROUPS} unique values found in '{grouping_col}'. Processing only the first {MAX_GROUPS}." + ) + unique_values = unique_values[:MAX_GROUPS] + + # Initialize accumulators for results across all groups + acc_sampled_reference_table_df = pd.DataFrame() + acc_topic_summary_df_revised = pd.DataFrame() + acc_reference_table_df_revised = pd.DataFrame() + acc_output_files = list() + acc_log_output_files = list() + acc_summarised_outputs = list() + acc_latest_summary_completed = latest_summary_completed + acc_out_metadata_str = out_metadata_str + acc_summarised_output_markdown = "" + acc_total_time_taken = 0.0 + acc_logged_content = list() + + if len(unique_values) == 1: + # If only one unique value, no need for progress bar, iterate directly + loop_object = unique_values + else: + # If multiple unique values, use tqdm progress bar + loop_object = progress.tqdm( + unique_values, desc="Summarising group", unit="groups" + ) + + for i, group_value in enumerate(loop_object): + print( + f"\nProcessing summary group: {grouping_col} = {group_value} ({i+1}/{len(unique_values)})" + ) + + # Filter data for current group + filtered_sampled_reference_table_df = sampled_reference_table_df[ + sampled_reference_table_df["Group"] == group_value + ].copy() + filtered_topic_summary_df = topic_summary_df[ + topic_summary_df["Group"] == group_value + ].copy() + filtered_reference_table_df = reference_table_df[ + reference_table_df["Group"] == group_value + ].copy() + + if filtered_sampled_reference_table_df.empty: + print(f"No data for {grouping_col} = {group_value}. 
Skipping.") + continue + + # Create unique file name for this group's outputs + group_file_name = f"{reference_data_file_name}_{clean_column_name(str(group_value), max_length=15).replace(' ','_')}" + + # Call summarise_output_topics for the current group + try: + ( + seg_sampled_reference_table_df, + seg_topic_summary_df_revised, + seg_reference_table_df_revised, + seg_output_files, + seg_summarised_outputs, + seg_latest_summary_completed, + seg_out_metadata_str, + seg_summarised_output_markdown, + seg_log_output_files, + seg_output_files_2, + seg_acc_input_tokens, + seg_acc_output_tokens, + seg_acc_number_of_calls, + seg_time_taken, + seg_out_message, + seg_logged_content, + ) = summarise_output_topics( + sampled_reference_table_df=filtered_sampled_reference_table_df, + topic_summary_df=filtered_topic_summary_df, + reference_table_df=filtered_reference_table_df, + model_choice=model_choice, + in_api_key=in_api_key, + temperature=temperature, + reference_data_file_name=group_file_name, + summarised_outputs=list(), # Fresh for each call + latest_summary_completed=0, # Reset for each group + out_metadata_str="", # Fresh for each call + in_data_files=in_data_files, + in_excel_sheets=in_excel_sheets, + chosen_cols=chosen_cols, + log_output_files=list(), # Fresh for each call + summarise_format_radio=summarise_format_radio, + output_folder=output_folder, + context_textbox=context_textbox, + aws_access_key_textbox=aws_access_key_textbox, + aws_secret_key_textbox=aws_secret_key_textbox, + aws_region_textbox=aws_region_textbox, + model_name_map=model_name_map, + hf_api_key_textbox=hf_api_key_textbox, + azure_endpoint_textbox=azure_endpoint_textbox, + existing_logged_content=all_groups_logged_content, + additional_summary_instructions_provided=additional_summary_instructions_provided, + output_debug_files=output_debug_files, + group_value=group_value, + reasoning_suffix=reasoning_suffix, + local_model=local_model, + tokenizer=tokenizer, + assistant_model=assistant_model, + summarise_topic_descriptions_prompt=summarise_topic_descriptions_prompt, + summarise_topic_descriptions_system_prompt=summarise_topic_descriptions_system_prompt, + do_summaries=do_summaries, + api_url=api_url, + ) + + # Aggregate results + acc_sampled_reference_table_df = pd.concat( + [acc_sampled_reference_table_df, seg_sampled_reference_table_df] + ) + acc_topic_summary_df_revised = pd.concat( + [acc_topic_summary_df_revised, seg_topic_summary_df_revised] + ) + acc_reference_table_df_revised = pd.concat( + [acc_reference_table_df_revised, seg_reference_table_df_revised] + ) + + # For lists, extend + acc_output_files.extend( + f for f in seg_output_files if f not in acc_output_files + ) + acc_log_output_files.extend( + f for f in seg_log_output_files if f not in acc_log_output_files + ) + acc_summarised_outputs.extend(seg_summarised_outputs) + + acc_latest_summary_completed = seg_latest_summary_completed + acc_out_metadata_str += ( + ("\n---\n" if acc_out_metadata_str else "") + + f"Group {grouping_col}={group_value}:\n" + + seg_out_metadata_str + ) + acc_summarised_output_markdown = ( + seg_summarised_output_markdown # Keep the latest markdown + ) + acc_total_time_taken += float(seg_time_taken) + acc_logged_content.extend(seg_logged_content) + + # Accumulate token counts + acc_input_tokens += seg_acc_input_tokens + acc_output_tokens += seg_acc_output_tokens + acc_number_of_calls += seg_acc_number_of_calls + + print( + f"Group {grouping_col} = {group_value} summarised. 
Time: {seg_time_taken:.2f}s" + ) + + except Exception as e: + print(f"Error processing summary group {grouping_col} = {group_value}: {e}") + # Optionally, decide if you want to continue with other groups or stop + # For now, it will continue + continue + + # Ensure custom model_choice is registered in model_name_map + ensure_model_in_map(model_choice, model_name_map) + + # Create consolidated output files + overall_file_name = clean_column_name(reference_data_file_name, max_length=20) + model_choice_clean = model_name_map[model_choice]["short_name"] + model_choice_clean_short = clean_column_name( + model_choice_clean, max_length=20, front_characters=False + ) + + # Save consolidated outputs + if ( + not acc_topic_summary_df_revised.empty + and not acc_reference_table_df_revised.empty + ): + # Sort the dataframes + if "General topic" in acc_topic_summary_df_revised.columns: + acc_topic_summary_df_revised["Number of responses"] = ( + acc_topic_summary_df_revised["Number of responses"].astype(int) + ) + acc_topic_summary_df_revised.sort_values( + [ + "Group", + "Number of responses", + "General topic", + "Subtopic", + "Sentiment", + ], + ascending=[True, False, True, True, True], + inplace=True, + ) + elif "Main heading" in acc_topic_summary_df_revised.columns: + acc_topic_summary_df_revised["Number of responses"] = ( + acc_topic_summary_df_revised["Number of responses"].astype(int) + ) + acc_topic_summary_df_revised.sort_values( + [ + "Group", + "Number of responses", + "Main heading", + "Subheading", + "Topic number", + ], + ascending=[True, False, True, True, True], + inplace=True, + ) + + # Save consolidated files + consolidated_topic_summary_path = ( + output_folder + + overall_file_name + + "_all_final_summ_unique_topics_" + + model_choice_clean_short + + ".csv" + ) + consolidated_reference_table_path = ( + output_folder + + overall_file_name + + "_all_final_summ_reference_table_" + + model_choice_clean_short + + ".csv" + ) + + acc_topic_summary_df_revised.drop( + ["1", "2", "3"], axis=1, errors="ignore" + ).to_csv(consolidated_topic_summary_path, index=None, encoding="utf-8-sig") + acc_reference_table_df_revised.drop( + ["1", "2", "3"], axis=1, errors="ignore" + ).to_csv(consolidated_reference_table_path, index=None, encoding="utf-8-sig") + + acc_output_files.extend( + [consolidated_topic_summary_path, consolidated_reference_table_path] + ) + + # Create markdown output for display + topic_summary_df_revised_display = acc_topic_summary_df_revised.apply( + lambda col: col.map(lambda x: wrap_text(x, max_text_length=max_text_length)) + ) + acc_summarised_output_markdown = topic_summary_df_revised_display.to_markdown( + index=False + ) + + out_message = "\n".join(out_message) + out_message = ( + out_message + + " " + + f"Topic summarisation finished processing all groups. 
Total time: {acc_total_time_taken:.2f}s" + ) + print(out_message) + + # The return signature should match summarise_output_topics + return ( + acc_sampled_reference_table_df, + acc_topic_summary_df_revised, + acc_reference_table_df_revised, + acc_output_files, + acc_summarised_outputs, + acc_latest_summary_completed, + acc_out_metadata_str, + acc_summarised_output_markdown, + acc_log_output_files, + acc_output_files, # Duplicate for compatibility + acc_input_tokens, + acc_output_tokens, + acc_number_of_calls, + acc_total_time_taken, + out_message, + acc_logged_content, + ) + + +@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME) +def overall_summary( + topic_summary_df: pd.DataFrame, + model_choice: str, + in_api_key: str, + temperature: float, + reference_data_file_name: str, + output_folder: str = OUTPUT_FOLDER, + chosen_cols: List[str] = list(), + context_textbox: str = "", + aws_access_key_textbox: str = "", + aws_secret_key_textbox: str = "", + aws_region_textbox: str = "", + model_name_map: dict = model_name_map, + hf_api_key_textbox: str = "", + azure_endpoint_textbox: str = "", + existing_logged_content: list = list(), + api_url: str = None, + output_debug_files: str = output_debug_files, + log_output_files: list = list(), + reasoning_suffix: str = reasoning_suffix, + local_model: object = None, + tokenizer: object = None, + assistant_model: object = None, + summarise_everything_prompt: str = summarise_everything_prompt, + comprehensive_summary_format_prompt: str = comprehensive_summary_format_prompt, + comprehensive_summary_format_prompt_by_group: str = comprehensive_summary_format_prompt_by_group, + summarise_everything_system_prompt: str = summarise_everything_system_prompt, + do_summaries: str = "Yes", + progress=gr.Progress(track_tqdm=True), +) -> Tuple[ + List[str], + List[str], + int, + str, + List[str], + List[str], + int, + int, + int, + float, + List[dict], +]: + """ + Create an overall summary of all responses based on a topic summary table. + + Args: + topic_summary_df (pd.DataFrame): DataFrame containing topic summaries + model_choice (str): Name of the LLM model to use + in_api_key (str): API key for model access + temperature (float): Temperature parameter for model generation + reference_data_file_name (str): Name of reference data file + output_folder (str, optional): Folder to save outputs. Defaults to OUTPUT_FOLDER. + chosen_cols (List[str], optional): Columns to analyze. Defaults to empty list. + context_textbox (str, optional): Additional context. Defaults to empty string. + aws_access_key_textbox (str, optional): AWS access key. Defaults to empty string. + aws_secret_key_textbox (str, optional): AWS secret key. Defaults to empty string. + aws_region_textbox (str, optional): AWS region. Defaults to empty string. + model_name_map (dict, optional): Mapping of model names. Defaults to model_name_map. + hf_api_key_textbox (str, optional): Hugging Face API key. Defaults to empty string. + existing_logged_content (list, optional): List of existing logged content. Defaults to empty list. + output_debug_files (str, optional): Flag to indicate if debug files should be written. Defaults to "False". + log_output_files (list, optional): List of existing logged content. Defaults to empty list. + api_url (str, optional): API URL for inference-server models. Defaults to None. + reasoning_suffix (str, optional): Suffix for reasoning. Defaults to reasoning_suffix. + local_model (object, optional): Local model object. Defaults to None. + tokenizer (object, optional): Tokenizer object. 
Defaults to None. + assistant_model (object, optional): Assistant model object. Defaults to None. + summarise_everything_prompt (str, optional): Prompt for overall summary + comprehensive_summary_format_prompt (str, optional): Prompt for comprehensive summary format + comprehensive_summary_format_prompt_by_group (str, optional): Prompt for group summary format + summarise_everything_system_prompt (str, optional): System prompt for overall summary + do_summaries (str, optional): Whether to generate summaries. Defaults to "Yes". + progress (gr.Progress, optional): Progress tracker. Defaults to gr.Progress(track_tqdm=True). + + Returns: + Tuple containing: + List[str]: Output files + List[str]: Text summarized outputs + int: Latest summary completed + str: Output metadata + List[str]: Summarized outputs + List[str]: Summarized outputs for DataFrame + int: Number of input tokens + int: Number of output tokens + int: Number of API calls + float: Time taken + List[dict]: List of logged content + """ + + out_metadata = list() + latest_summary_completed = 0 + output_files = list() + txt_summarised_outputs = list() + summarised_outputs = list() + summarised_outputs_for_df = list() + input_tokens_num = 0 + output_tokens_num = 0 + number_of_calls_num = 0 + time_taken = 0 + out_message = list() + all_logged_content = list() + all_prompts_content = list() + all_summaries_content = list() + all_metadata_content = list() + all_groups_content = list() + all_batches_content = list() + all_model_choice_content = list() + all_validated_content = list() + task_type = "Overall summary" + all_task_type_content = list() + log_output_files = list() + all_logged_content = list() + all_file_names_content = list() + tic = time.perf_counter() + + if "Group" not in topic_summary_df.columns: + topic_summary_df["Group"] = "All" + + topic_summary_df = topic_summary_df.sort_values( + by=["Group", "Number of responses"], ascending=[True, False] + ) + + unique_groups = sorted(topic_summary_df["Group"].unique()) + + length_groups = len(unique_groups) + + if context_textbox and "The context of this analysis is" not in context_textbox: + context_textbox = "The context of this analysis is '" + context_textbox + "'." 
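+ # e.g. a context of "bus service changes" becomes "The context of this analysis is 'bus service changes'." (illustrative example only)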
+ + if length_groups > 1: + comprehensive_summary_format_prompt = ( + comprehensive_summary_format_prompt_by_group + ) + else: + comprehensive_summary_format_prompt = comprehensive_summary_format_prompt + + # Ensure custom model_choice is registered in model_name_map + ensure_model_in_map(model_choice, model_name_map) + + batch_file_path_details = create_batch_file_path_details(reference_data_file_name) + model_choice_clean = model_name_map[model_choice]["short_name"] + model_choice_clean_short = clean_column_name( + model_choice_clean, max_length=20, front_characters=False + ) + + tic = time.perf_counter() + + if ( + (model_choice == CHOSEN_LOCAL_MODEL_TYPE) + & (RUN_LOCAL_MODEL == "1") + & (not local_model) + ): + progress(0.1, f"Using model: {CHOSEN_LOCAL_MODEL_TYPE}") + local_model = get_model() + tokenizer = get_tokenizer() + assistant_model = get_assistant_model() + + summary_loop = tqdm( + unique_groups, desc="Creating overall summary for groups", unit="groups" + ) + + if do_summaries == "Yes": + model_source = model_name_map[model_choice]["source"] + bedrock_runtime = connect_to_bedrock_runtime( + model_name_map, + model_choice, + aws_access_key_textbox, + aws_secret_key_textbox, + aws_region_textbox, + ) + + for summary_group in summary_loop: + + print("Creating overall summary for group:", summary_group) + + summary_text = topic_summary_df.loc[ + topic_summary_df["Group"] == summary_group + ].to_markdown(index=False) + + formatted_summary_prompt = [ + summarise_everything_prompt.format( + topic_summary_table=summary_text, + summary_format=comprehensive_summary_format_prompt, + ) + ] + + formatted_summarise_everything_system_prompt = ( + summarise_everything_system_prompt.format( + column_name=chosen_cols, consultation_context=context_textbox + ) + ) + + if "Local" in model_source and reasoning_suffix: + formatted_summarise_everything_system_prompt = ( + formatted_summarise_everything_system_prompt + + "\n" + + reasoning_suffix + ) + + try: + response, conversation_history, metadata, response_text = ( + summarise_output_topics_query( + model_choice, + in_api_key, + temperature, + formatted_summary_prompt, + formatted_summarise_everything_system_prompt, + model_source, + bedrock_runtime, + local_model, + tokenizer=tokenizer, + assistant_model=assistant_model, + azure_endpoint=azure_endpoint_textbox, + api_url=api_url, + ) + ) + summarised_output_for_df = response_text + summarised_output = response + except Exception as e: + print( + "Cannot create overall summary for group:", + summary_group, + "due to:", + e, + ) + summarised_output = "" + summarised_output_for_df = "" + + summarised_outputs_for_df.append(summarised_output_for_df) + summarised_outputs.append(summarised_output) + txt_summarised_outputs.append( + f"""Group name: {summary_group}\n""" + summarised_output + ) + + out_metadata.extend(metadata) + out_metadata_str = ". 
".join(out_metadata) + + full_prompt = ( + formatted_summarise_everything_system_prompt + + "\n" + + formatted_summary_prompt[0] + ) + + ( + current_prompt_content_logged, + current_summary_content_logged, + current_conversation_content_logged, + current_metadata_content_logged, + ) = process_debug_output_iteration( + output_debug_files, + output_folder, + batch_file_path_details, + model_choice_clean_short, + full_prompt, + summarised_output, + conversation_history, + metadata, + log_output_files, + task_type=task_type, + ) + + all_prompts_content.append(current_prompt_content_logged) + all_summaries_content.append(current_summary_content_logged) + # all_conversation_content.append(current_conversation_content_logged) + all_metadata_content.append(current_metadata_content_logged) + all_groups_content.append(summary_group) + all_batches_content.append("1") + all_model_choice_content.append(model_choice_clean_short) + all_validated_content.append("No") + all_task_type_content.append(task_type) + all_file_names_content.append(reference_data_file_name) + latest_summary_completed += 1 + clean_column_name(summary_group) + + # Write overall outputs to csv + overall_summary_output_csv_path = ( + output_folder + + batch_file_path_details + + "_overall_summary_" + + model_choice_clean_short + + ".csv" + ) + summarised_outputs_df = pd.DataFrame( + data={"Group": unique_groups, "Summary": summarised_outputs_for_df} + ) + summarised_outputs_df.drop(["1", "2", "3"], axis=1, errors="ignore").to_csv( + overall_summary_output_csv_path, index=None, encoding="utf-8-sig" + ) + output_files.append(overall_summary_output_csv_path) + + summarised_outputs_df_for_display = pd.DataFrame( + data={"Group": unique_groups, "Summary": summarised_outputs} + ) + summarised_outputs_df_for_display["Summary"] = ( + summarised_outputs_df_for_display["Summary"] + .apply(lambda x: markdown.markdown(x) if isinstance(x, str) else x) + .str.replace(r"\n", "
", regex=False) + ) + html_output_table = summarised_outputs_df_for_display.to_html( + index=False, escape=False + ) + + output_files = list(set(output_files)) + + input_tokens_num, output_tokens_num, number_of_calls_num = ( + calculate_tokens_from_metadata( + out_metadata_str, model_choice, model_name_map + ) + ) + + # Check if beyond max time allowed for processing and break if necessary + toc = time.perf_counter() + time_taken = toc - tic + + out_message = "\n".join(out_message) + out_message = ( + out_message + + " " + + f"Overall summary finished processing. Total time: {time_taken:.2f}s" + ) + print(out_message) + + # Combine the logged content into a list of dictionaries + all_logged_content = [ + { + "prompt": prompt, + "response": summary, + "metadata": metadata, + "batch": batch, + "model_choice": model_choice, + "validated": validated, + "group": group, + "task_type": task_type, + "file_name": file_name, + } + for prompt, summary, metadata, batch, model_choice, validated, group, task_type, file_name in zip( + all_prompts_content, + all_summaries_content, + all_metadata_content, + all_batches_content, + all_model_choice_content, + all_validated_content, + all_groups_content, + all_task_type_content, + all_file_names_content, + ) + ] + + if isinstance(existing_logged_content, pd.DataFrame): + existing_logged_content = existing_logged_content.to_dict(orient="records") + + out_logged_content = existing_logged_content + all_logged_content + + return ( + output_files, + html_output_table, + summarised_outputs_df, + out_metadata_str, + input_tokens_num, + output_tokens_num, + number_of_calls_num, + time_taken, + out_message, + out_logged_content, + )