Spaces:
Running
Running
Fix-feature-mapping-for-multi-labels
#133
by
ZeroCommand
- opened
- text_classification.py +7 -6
- text_classification_ui_helpers.py +13 -8
text_classification.py
CHANGED
|
@@ -22,23 +22,24 @@ class HuggingFaceInferenceAPIResponse:
|
|
| 22 |
def get_labels_and_features_from_dataset(ds):
|
| 23 |
try:
|
| 24 |
dataset_features = ds.features
|
| 25 |
-
label_keys = [i for i in dataset_features.keys() if i.startswith(
|
|
|
|
|
|
|
| 26 |
if len(label_keys) == 0: # no labels found
|
| 27 |
# return everything for post processing
|
| 28 |
-
return list(dataset_features.keys()), list(dataset_features.keys())
|
| 29 |
if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
|
| 30 |
-
if hasattr(dataset_features[label_keys[0]],
|
| 31 |
label_feat = dataset_features[label_keys[0]].feature
|
| 32 |
labels = label_feat.names
|
| 33 |
else:
|
| 34 |
labels = dataset_features[label_keys[0]].names
|
| 35 |
-
|
| 36 |
-
return labels, features
|
| 37 |
except Exception as e:
|
| 38 |
logging.warning(
|
| 39 |
f"Get Labels/Features Failed for dataset: {e}"
|
| 40 |
)
|
| 41 |
-
return None, None
|
| 42 |
|
| 43 |
def check_model_task(model_id):
|
| 44 |
# check if model is valid on huggingface
|
|
|
|
| 22 |
def get_labels_and_features_from_dataset(ds):
|
| 23 |
try:
|
| 24 |
dataset_features = ds.features
|
| 25 |
+
label_keys = [i for i in dataset_features.keys() if i.startswith("label")]
|
| 26 |
+
features = [f for f in dataset_features.keys() if not f.startswith("label")]
|
| 27 |
+
|
| 28 |
if len(label_keys) == 0: # no labels found
|
| 29 |
# return everything for post processing
|
| 30 |
+
return list(dataset_features.keys()), list(dataset_features.keys()), None
|
| 31 |
if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
|
| 32 |
+
if hasattr(dataset_features[label_keys[0]], "feature"):
|
| 33 |
label_feat = dataset_features[label_keys[0]].feature
|
| 34 |
labels = label_feat.names
|
| 35 |
else:
|
| 36 |
labels = dataset_features[label_keys[0]].names
|
| 37 |
+
return labels, features, label_keys
|
|
|
|
| 38 |
except Exception as e:
|
| 39 |
logging.warning(
|
| 40 |
f"Get Labels/Features Failed for dataset: {e}"
|
| 41 |
)
|
| 42 |
+
return None, None, None
|
| 43 |
|
| 44 |
def check_model_task(model_id):
|
| 45 |
# check if model is valid on huggingface
|
text_classification_ui_helpers.py
CHANGED
|
@@ -138,7 +138,7 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels,
|
|
| 138 |
ds_labels = list(shared_labels)
|
| 139 |
if len(ds_labels) > MAX_LABELS:
|
| 140 |
ds_labels = ds_labels[:MAX_LABELS]
|
| 141 |
-
gr.Warning(f"
|
| 142 |
|
| 143 |
# sort labels to make sure the order is consistent
|
| 144 |
# prediction gives the order based on probability
|
|
@@ -198,7 +198,7 @@ def precheck_model_ds_enable_example_btn(
|
|
| 198 |
try:
|
| 199 |
ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
|
| 200 |
df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
|
| 201 |
-
ds_labels, ds_features = get_labels_and_features_from_dataset(ds[dataset_split])
|
| 202 |
|
| 203 |
if model_task is None or model_task != "text-classification":
|
| 204 |
gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
|
|
@@ -300,7 +300,7 @@ def align_columns_and_show_prediction(
|
|
| 300 |
model_labels = list(prediction_response.keys())
|
| 301 |
|
| 302 |
ds = datasets.load_dataset(dataset_id, dataset_config, split=dataset_split, trust_remote_code=True)
|
| 303 |
-
ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
|
| 304 |
|
| 305 |
# when dataset does not have labels or features
|
| 306 |
if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
|
|
@@ -390,13 +390,15 @@ def enable_run_btn(uid, run_inference, inference_token, model_id, dataset_id, da
|
|
| 390 |
return gr.update(interactive=False)
|
| 391 |
return gr.update(interactive=True)
|
| 392 |
|
| 393 |
-
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
|
| 394 |
label_mapping = {}
|
| 395 |
if len(all_mappings["labels"].keys()) != len(ds_labels):
|
| 396 |
-
logger.warn("Label mapping corrupted:
|
|
|
|
| 397 |
|
| 398 |
if len(all_mappings["features"].keys()) != len(ds_features):
|
| 399 |
-
logger.warn("Feature mapping corrupted:
|
|
|
|
| 400 |
|
| 401 |
for i, label in zip(range(len(ds_labels)), ds_labels):
|
| 402 |
# align the saved labels with dataset labels order
|
|
@@ -405,7 +407,10 @@ def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
|
|
| 405 |
if "features" not in all_mappings.keys():
|
| 406 |
logger.warning("features not in all_mappings")
|
| 407 |
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
|
|
|
| 408 |
feature_mapping = all_mappings["features"]
|
|
|
|
|
|
|
| 409 |
return label_mapping, feature_mapping
|
| 410 |
|
| 411 |
def show_hf_token_info(token):
|
|
@@ -421,8 +426,8 @@ def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
|
|
| 421 |
|
| 422 |
# get ds labels and features again for alignment
|
| 423 |
ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
|
| 424 |
-
ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
|
| 425 |
-
label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features)
|
| 426 |
|
| 427 |
eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
|
| 428 |
save_job_to_pipe(
|
|
|
|
| 138 |
ds_labels = list(shared_labels)
|
| 139 |
if len(ds_labels) > MAX_LABELS:
|
| 140 |
ds_labels = ds_labels[:MAX_LABELS]
|
| 141 |
+
gr.Warning(f"Too many labels to display for this spcae. We do not support more than {MAX_LABELS} in this space. You can use cli tool at https://github.com/Giskard-AI/cicd.")
|
| 142 |
|
| 143 |
# sort labels to make sure the order is consistent
|
| 144 |
# prediction gives the order based on probability
|
|
|
|
| 198 |
try:
|
| 199 |
ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
|
| 200 |
df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
|
| 201 |
+
ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds[dataset_split])
|
| 202 |
|
| 203 |
if model_task is None or model_task != "text-classification":
|
| 204 |
gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
|
|
|
|
| 300 |
model_labels = list(prediction_response.keys())
|
| 301 |
|
| 302 |
ds = datasets.load_dataset(dataset_id, dataset_config, split=dataset_split, trust_remote_code=True)
|
| 303 |
+
ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)
|
| 304 |
|
| 305 |
# when dataset does not have labels or features
|
| 306 |
if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
|
|
|
|
| 390 |
return gr.update(interactive=False)
|
| 391 |
return gr.update(interactive=True)
|
| 392 |
|
| 393 |
+
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
|
| 394 |
label_mapping = {}
|
| 395 |
if len(all_mappings["labels"].keys()) != len(ds_labels):
|
| 396 |
+
logger.warn(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
|
| 397 |
+
\nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
|
| 398 |
|
| 399 |
if len(all_mappings["features"].keys()) != len(ds_features):
|
| 400 |
+
logger.warn(f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
|
| 401 |
+
\nall_mappings: {all_mappings}\nds_features: {ds_features}""")
|
| 402 |
|
| 403 |
for i, label in zip(range(len(ds_labels)), ds_labels):
|
| 404 |
# align the saved labels with dataset labels order
|
|
|
|
| 407 |
if "features" not in all_mappings.keys():
|
| 408 |
logger.warning("features not in all_mappings")
|
| 409 |
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
| 410 |
+
|
| 411 |
feature_mapping = all_mappings["features"]
|
| 412 |
+
if len(label_keys) > 0:
|
| 413 |
+
feature_mapping.update({"label": label_keys[0]})
|
| 414 |
return label_mapping, feature_mapping
|
| 415 |
|
| 416 |
def show_hf_token_info(token):
|
|
|
|
| 426 |
|
| 427 |
# get ds labels and features again for alignment
|
| 428 |
ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
|
| 429 |
+
ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
|
| 430 |
+
label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys)
|
| 431 |
|
| 432 |
eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
|
| 433 |
save_job_to_pipe(
|