Spaces:
Running
Running
Commit
·
6457b7a
1
Parent(s):
1204664
fix labels not found; handle input too long
Browse files
- app_text_classification.py +1 -0
- text_classification.py +13 -1
- text_classification_ui_helpers.py +4 -4
app_text_classification.py
CHANGED
|
@@ -201,6 +201,7 @@ def get_demo():
|
|
| 201 |
gr.on(
|
| 202 |
triggers=[
|
| 203 |
model_id_input.change,
|
|
|
|
| 204 |
dataset_id_input.change,
|
| 205 |
dataset_config_input.change,
|
| 206 |
dataset_split_input.change,
|
|
|
|
| 201 |
gr.on(
|
| 202 |
triggers=[
|
| 203 |
model_id_input.change,
|
| 204 |
+
model_id_input.input,
|
| 205 |
dataset_id_input.change,
|
| 206 |
dataset_config_input.change,
|
| 207 |
dataset_split_input.change,
|
text_classification.py
CHANGED
|
@@ -28,10 +28,14 @@ def get_labels_and_features_from_dataset(ds):
|
|
| 28 |
if len(label_keys) == 0: # no labels found
|
| 29 |
# return everything for post processing
|
| 30 |
return list(dataset_features.keys()), list(dataset_features.keys()), None
|
|
|
|
|
|
|
| 31 |
if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
|
| 32 |
if hasattr(dataset_features[label_keys[0]], "feature"):
|
| 33 |
label_feat = dataset_features[label_keys[0]].feature
|
| 34 |
labels = label_feat.names
|
|
|
|
|
|
|
| 35 |
else:
|
| 36 |
labels = dataset_features[label_keys[0]].names
|
| 37 |
return labels, features, label_keys
|
|
@@ -83,11 +87,19 @@ def hf_inference_api(model_id, hf_token, payload):
|
|
| 83 |
url = f"{hf_inference_api_endpoint}/models/{model_id}"
|
| 84 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
| 85 |
response = requests.post(url, headers=headers, json=payload)
|
|
|
|
| 86 |
if not hasattr(response, "status_code") or response.status_code != 200:
|
| 87 |
logger.warning(f"Request to inference API returns {response}")
|
|
|
|
| 88 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
return response.json()
|
| 90 |
-
except Exception:
|
| 91 |
return {"error": response.content}
|
| 92 |
|
| 93 |
def preload_hf_inference_api(model_id):
|
|
|
|
| 28 |
if len(label_keys) == 0: # no labels found
|
| 29 |
# return everything for post processing
|
| 30 |
return list(dataset_features.keys()), list(dataset_features.keys()), None
|
| 31 |
+
|
| 32 |
+
labels = None
|
| 33 |
if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
|
| 34 |
if hasattr(dataset_features[label_keys[0]], "feature"):
|
| 35 |
label_feat = dataset_features[label_keys[0]].feature
|
| 36 |
labels = label_feat.names
|
| 37 |
+
else:
|
| 38 |
+
labels = ds.unique(label_keys[0])
|
| 39 |
else:
|
| 40 |
labels = dataset_features[label_keys[0]].names
|
| 41 |
return labels, features, label_keys
|
|
|
|
| 87 |
url = f"{hf_inference_api_endpoint}/models/{model_id}"
|
| 88 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
| 89 |
response = requests.post(url, headers=headers, json=payload)
|
| 90 |
+
|
| 91 |
if not hasattr(response, "status_code") or response.status_code != 200:
|
| 92 |
logger.warning(f"Request to inference API returns {response}")
|
| 93 |
+
|
| 94 |
try:
|
| 95 |
+
output = response.json()
|
| 96 |
+
if "error" in output and "Input is too long" in output["error"]:
|
| 97 |
+
payload.update({"parameters": {"truncation": True, "max_length": 512}})
|
| 98 |
+
response = requests.post(url, headers=headers, json=payload)
|
| 99 |
+
if not hasattr(response, "status_code") or response.status_code != 200:
|
| 100 |
+
logger.warning(f"Request to inference API returns {response}")
|
| 101 |
return response.json()
|
| 102 |
+
except Exception:
|
| 103 |
return {"error": response.content}
|
| 104 |
|
| 105 |
def preload_hf_inference_api(model_id):
|
text_classification_ui_helpers.py
CHANGED
|
@@ -341,8 +341,8 @@ def align_columns_and_show_prediction(
|
|
| 341 |
):
|
| 342 |
return (
|
| 343 |
gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
|
| 344 |
-
gr.update(visible=False),
|
| 345 |
-
gr.update(visible=False),
|
| 346 |
gr.update(visible=True, open=True),
|
| 347 |
gr.update(interactive=(run_inference and inference_token != "")),
|
| 348 |
"",
|
|
@@ -351,7 +351,7 @@ def align_columns_and_show_prediction(
|
|
| 351 |
|
| 352 |
return (
|
| 353 |
gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
|
| 354 |
-
gr.update(value=prediction_input, lines=len(prediction_input)//225 + 1, visible=True),
|
| 355 |
gr.update(value=prediction_response, visible=True),
|
| 356 |
gr.update(visible=True, open=False),
|
| 357 |
gr.update(interactive=(run_inference and inference_token != "")),
|
|
@@ -428,7 +428,7 @@ def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
|
|
| 428 |
ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
|
| 429 |
ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
|
| 430 |
label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys)
|
| 431 |
-
|
| 432 |
eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
|
| 433 |
save_job_to_pipe(
|
| 434 |
uid,
|
|
|
|
| 341 |
):
|
| 342 |
return (
|
| 343 |
gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
|
| 344 |
+
gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
|
| 345 |
+
gr.update(value=prediction_response, visible=True),
|
| 346 |
gr.update(visible=True, open=True),
|
| 347 |
gr.update(interactive=(run_inference and inference_token != "")),
|
| 348 |
"",
|
|
|
|
| 351 |
|
| 352 |
return (
|
| 353 |
gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
|
| 354 |
+
gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
|
| 355 |
gr.update(value=prediction_response, visible=True),
|
| 356 |
gr.update(visible=True, open=False),
|
| 357 |
gr.update(interactive=(run_inference and inference_token != "")),
|
|
|
|
| 428 |
ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
|
| 429 |
ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
|
| 430 |
label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys)
|
| 431 |
+
|
| 432 |
eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
|
| 433 |
save_job_to_pipe(
|
| 434 |
uid,
|