from collections import Counter
from itertools import count, groupby, islice
from operator import itemgetter
from typing import Any, Generic, Iterable, Iterator, TypeVar

import gradio as gr
import pandas as pd
import requests
from datasets import Features
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from requests.adapters import HTTPAdapter, Retry

from analyze import PresidioEntity, analyzer, get_column_description, get_columns_with_strings, mask, presidio_scan_entities

MAX_ROWS = 100
T = TypeVar("T")

session = requests.Session()
# Retry transient gateway errors (502/503/504) with exponential backoff. The adapter is
# mounted on "https://" because every datasets-server request below uses HTTPS.
retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
session.mount("https://", HTTPAdapter(max_retries=retries))

DEFAULT_PRESIDIO_ENTITIES = sorted([
    "PERSON",
    "CREDIT_CARD",
    "US_SSN",
    "US_DRIVER_LICENSE",
    "PHONE_NUMBER",
    "US_PASSPORT",
    "EMAIL_ADDRESS",
    "IP_ADDRESS",
    "US_BANK_NUMBER",
    "IBAN_CODE",
    "EMAIL",
])
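# Note: "EMAIL" is not one of Presidio's built-in entity types (the built-in one is
# "EMAIL_ADDRESS"), so it is presumably registered by a custom recognizer in `analyze`.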


def stream_rows(dataset: str, config: str, split: str) -> Iterable[dict[str, Any]]:
    """Stream the rows of a dataset split from the datasets-server API, 100 at a time."""
    batch_size = 100
    for i in count():
        rows_resp = session.get(
            f"https://datasets-server.huggingface.co/rows?dataset={dataset}&config={config}&split={split}&offset={i * batch_size}&length={batch_size}",
            timeout=10,
        ).json()
        if "error" in rows_resp:
            raise RuntimeError(rows_resp["error"])
        if not rows_resp["rows"]:
            break
        for row_item in rows_resp["rows"]:
            yield row_item["row"]
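
# The /rows endpoint returns a JSON payload shaped roughly like this sketch (simplified;
# the real response also carries feature metadata and row counts):
#   {"rows": [{"row_idx": 0, "row": {"text": "...", ...}, "truncated_cells": []}, ...]}
# so `row_item["row"]` above is the plain column-name -> value mapping for one example.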


class track_iter(Generic[T]):
    """Wrap an iterable and record how many items have been consumed so far."""

    def __init__(self, it: Iterable[T]):
        self.it = it
        self.next_idx = 0

    def __iter__(self) -> Iterator[T]:
        for item in self.it:
            self.next_idx += 1
            yield item
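
# Shape of a PresidioEntity, inferred from how it is consumed below ("row_idx", "type"
# and "text" are the only fields accessed here; anything else is an assumption):
#   {"row_idx": 3, "type": "PERSON", "text": "John Doe", ...}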


def presidio_report(presidio_entities: list[PresidioEntity], next_row_idx: int, num_rows: int) -> dict[str, float]:
    title = f"Scan finished: {len(presidio_entities)} entities found" if num_rows == next_row_idx else "Scan in progress..."
    # Give the title a count equal to the number of scanned rows so that, divided by
    # num_rows below, it doubles as a progress value for the gr.Label output.
    counter = Counter([title] * next_row_idx)
    for row_idx, presidio_entities_per_row in groupby(presidio_entities, itemgetter("row_idx")):
        # Count each entity type at most once per row to report per-row prevalence.
        counter.update(set("% of rows with " + presidio_entity["type"] for presidio_entity in presidio_entities_per_row))
    return {entity_type: row_count / num_rows for entity_type, row_count in counter.most_common()}
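
# For illustration, a finished scan of 100 rows where 12 rows contain a PERSON entity
# might produce (made-up numbers):
#   {"Scan finished: 15 entities found": 1.0, "% of rows with PERSON": 0.12}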


def analyze_dataset(dataset: str, enabled_presidio_entities: list[str] = DEFAULT_PRESIDIO_ENTITIES, show_texts_without_masks: bool = False) -> Iterator[tuple[dict[str, float] | str, pd.DataFrame]]:
    info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
    if "error" in info_resp:
        yield "❌ " + info_resp["error"], pd.DataFrame()
        return
    # Prefer the "default" config and the "train" split, falling back to the first available ones.
    config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
    features = Features.from_dict(info_resp["dataset_info"][config]["features"])
    split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(iter(info_resp["dataset_info"][config]["splits"]))
    num_rows = min(info_resp["dataset_info"][config]["splits"][split]["num_examples"], MAX_ROWS)
    scanned_columns = get_columns_with_strings(features)
    columns_descriptions = [
        get_column_description(column_name, features[column_name]) for column_name in scanned_columns
    ]
    rows = track_iter(islice(stream_rows(dataset, config, split), MAX_ROWS))
    presidio_entities = []
    for presidio_entity in presidio_scan_entities(
        rows, scanned_columns=scanned_columns, columns_descriptions=columns_descriptions
    ):
        if not show_texts_without_masks:
            presidio_entity["text"] = mask(presidio_entity["text"])
        if presidio_entity["type"] in enabled_presidio_entities:
            presidio_entities.append(presidio_entity)
        # Yield intermediate results so the UI updates while the scan is running.
        yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=num_rows), pd.DataFrame(presidio_entities)
    # One final yield so the "Scan finished" report is emitted once the rows are exhausted.
    yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=num_rows), pd.DataFrame(presidio_entities)
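
# Minimal sketch of driving the scan without the UI (assumes the dataset id exists and
# is supported by datasets-server; `report` and `df` keep their last yielded values):
#   for report, df in analyze_dataset("lhoestq/fake_name_and_ssn"):
#       pass
#   print(report)
#   print(df.head())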


# Custom CSS to fix a bug with gr.DataFrame, see https://github.com/radames/gradio-custom-components/issues/1
with gr.Blocks(css=".table {border-collapse: separate}") as demo:
    gr.Markdown("# Scan datasets using Presidio")
    gr.Markdown("The Space takes a Hugging Face dataset name as input and returns the list of entities detected by Presidio in its first samples.")
    inputs = [
        HuggingfaceHubSearch(
            label="Hub Dataset ID",
            placeholder="Search for a dataset id on the Hugging Face Hub",
            search_type="dataset",
        ),
        gr.CheckboxGroup(
            label="Presidio entities",
            choices=sorted(analyzer.get_supported_entities()),
            value=DEFAULT_PRESIDIO_ENTITIES,
            interactive=True,
        ),
        gr.Checkbox(label="Show texts without masks", value=False),
    ]
    button = gr.Button("Run Presidio Scan")
    outputs = [
        gr.Label(show_label=False),
        gr.DataFrame(),
    ]
    button.click(analyze_dataset, inputs, outputs)
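    # analyze_dataset is a generator, so Gradio streams each yielded (report, dataframe)
    # pair to the outputs, updating the label and table live while the scan runs.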
    gr.Examples(
        [
            ["microsoft/orca-math-word-problems-200k"],
            ["tatsu-lab/alpaca"],
            ["Anthropic/hh-rlhf"],
            ["OpenAssistant/oasst1"],
            ["sidhq/email-thread-summary"],
            ["lhoestq/fake_name_and_ssn"],
        ],
        inputs,
        outputs,
        fn=analyze_dataset,
        run_on_click=True,
        cache_examples=False,
    )

demo.launch()