Christina Theodoris
commited on
Commit
·
eeba323
1
Parent(s):
f75f5ac
update examples for predict_eval and handle roc for 2 cell classes
Browse files- examples/cell_classification.ipynb +1 -2
- geneformer/classifier.py +52 -34
- geneformer/evaluation_utils.py +1 -1
examples/cell_classification.ipynb
CHANGED
|
@@ -266,8 +266,7 @@
|
|
| 266 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
| 267 |
" output_directory=output_dir,\n",
|
| 268 |
" output_prefix=output_prefix,\n",
|
| 269 |
-
" split_id_dict=train_valid_id_split_dict
|
| 270 |
-
" predict=True)"
|
| 271 |
]
|
| 272 |
},
|
| 273 |
{
|
|
|
|
| 266 |
" id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
|
| 267 |
" output_directory=output_dir,\n",
|
| 268 |
" output_prefix=output_prefix,\n",
|
| 269 |
+
" split_id_dict=train_valid_id_split_dict)"
|
|
|
|
| 270 |
]
|
| 271 |
},
|
| 272 |
{
|
geneformer/classifier.py
CHANGED
|
@@ -30,7 +30,7 @@ Geneformer classifier.
|
|
| 30 |
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
|
| 31 |
... output_directory="path/to/output_directory",
|
| 32 |
... output_prefix="output_prefix",
|
| 33 |
-
...
|
| 34 |
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
|
| 35 |
... output_directory="path/to/output_directory",
|
| 36 |
... output_prefix="output_prefix",
|
|
@@ -308,7 +308,7 @@ class Classifier:
|
|
| 308 |
output_directory,
|
| 309 |
output_prefix,
|
| 310 |
split_id_dict=None,
|
| 311 |
-
test_size=
|
| 312 |
attr_to_split=None,
|
| 313 |
attr_to_balance=None,
|
| 314 |
max_trials=100,
|
|
@@ -417,27 +417,48 @@ class Classifier:
|
|
| 417 |
data_dict["test"].save_to_disk(test_data_output_path)
|
| 418 |
elif (test_size is not None) and (self.classifier == "cell"):
|
| 419 |
if 1 > test_size > 0:
|
| 420 |
-
|
| 421 |
-
data
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
).with_suffix(".dataset")
|
| 439 |
-
|
| 440 |
-
|
| 441 |
else:
|
| 442 |
data_output_path = (
|
| 443 |
Path(output_directory) / f"{output_prefix}_labeled"
|
|
@@ -1012,7 +1033,7 @@ class Classifier:
|
|
| 1012 |
model = pu.load_model(model_type, num_classes, model_directory, "eval")
|
| 1013 |
|
| 1014 |
# evaluate the model
|
| 1015 |
-
|
| 1016 |
model,
|
| 1017 |
num_classes,
|
| 1018 |
id_class_dict,
|
|
@@ -1023,24 +1044,21 @@ class Classifier:
|
|
| 1023 |
)
|
| 1024 |
|
| 1025 |
all_conf_mat_df = pd.DataFrame(
|
| 1026 |
-
|
| 1027 |
columns=id_class_dict.values(),
|
| 1028 |
index=id_class_dict.values(),
|
| 1029 |
)
|
| 1030 |
all_metrics = {
|
| 1031 |
"conf_matrix": all_conf_mat_df,
|
| 1032 |
-
"macro_f1":
|
| 1033 |
-
"acc":
|
| 1034 |
}
|
| 1035 |
all_roc_metrics = None # roc metrics not reported for multiclass
|
|
|
|
| 1036 |
if num_classes == 2:
|
| 1037 |
mean_fpr = np.linspace(0, 1, 100)
|
| 1038 |
-
|
| 1039 |
-
all_roc_auc =
|
| 1040 |
-
all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
|
| 1041 |
-
mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
|
| 1042 |
-
all_tpr, all_roc_auc, all_tpr_wt
|
| 1043 |
-
)
|
| 1044 |
all_roc_metrics = {
|
| 1045 |
"mean_tpr": mean_tpr,
|
| 1046 |
"mean_fpr": mean_fpr,
|
|
@@ -1137,7 +1155,7 @@ class Classifier:
|
|
| 1137 |
|
| 1138 |
predictions_file : path
|
| 1139 |
| Path of model predictions output to plot
|
| 1140 |
-
| (saved output from self.validate if
|
| 1141 |
| (or saved output from self.evaluate_saved_model)
|
| 1142 |
id_class_dict_file : Path
|
| 1143 |
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
|
@@ -1173,7 +1191,7 @@ class Classifier:
|
|
| 1173 |
predictions_logits = np.array(predictions["predictions"])
|
| 1174 |
true_ids = predictions["label_ids"]
|
| 1175 |
else:
|
| 1176 |
-
# format is output from self.validate if
|
| 1177 |
predictions_logits = predictions.predictions
|
| 1178 |
true_ids = predictions.label_ids
|
| 1179 |
|
|
|
|
| 30 |
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
|
| 31 |
... output_directory="path/to/output_directory",
|
| 32 |
... output_prefix="output_prefix",
|
| 33 |
+
... predict_eval=True)
|
| 34 |
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
|
| 35 |
... output_directory="path/to/output_directory",
|
| 36 |
... output_prefix="output_prefix",
|
|
|
|
| 308 |
output_directory,
|
| 309 |
output_prefix,
|
| 310 |
split_id_dict=None,
|
| 311 |
+
test_size=None,
|
| 312 |
attr_to_split=None,
|
| 313 |
attr_to_balance=None,
|
| 314 |
max_trials=100,
|
|
|
|
| 417 |
data_dict["test"].save_to_disk(test_data_output_path)
|
| 418 |
elif (test_size is not None) and (self.classifier == "cell"):
|
| 419 |
if 1 > test_size > 0:
|
| 420 |
+
if attr_to_split is None:
|
| 421 |
+
data_dict = data.train_test_split(
|
| 422 |
+
test_size=test_size,
|
| 423 |
+
stratify_by_column=self.stratify_splits_col,
|
| 424 |
+
seed=42,
|
| 425 |
+
)
|
| 426 |
+
train_data_output_path = (
|
| 427 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
| 428 |
+
).with_suffix(".dataset")
|
| 429 |
+
test_data_output_path = (
|
| 430 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
| 431 |
+
).with_suffix(".dataset")
|
| 432 |
+
data_dict["train"].save_to_disk(train_data_output_path)
|
| 433 |
+
data_dict["test"].save_to_disk(test_data_output_path)
|
| 434 |
+
else:
|
| 435 |
+
data_dict, balance_df = cu.balance_attr_splits(
|
| 436 |
+
data,
|
| 437 |
+
attr_to_split,
|
| 438 |
+
attr_to_balance,
|
| 439 |
+
test_size,
|
| 440 |
+
max_trials,
|
| 441 |
+
pval_threshold,
|
| 442 |
+
self.cell_state_dict["state_key"],
|
| 443 |
+
self.nproc,
|
| 444 |
+
)
|
| 445 |
+
balance_df.to_csv(
|
| 446 |
+
f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
|
| 447 |
+
)
|
| 448 |
+
train_data_output_path = (
|
| 449 |
+
Path(output_directory) / f"{output_prefix}_labeled_train"
|
| 450 |
+
).with_suffix(".dataset")
|
| 451 |
+
test_data_output_path = (
|
| 452 |
+
Path(output_directory) / f"{output_prefix}_labeled_test"
|
| 453 |
+
).with_suffix(".dataset")
|
| 454 |
+
data_dict["train"].save_to_disk(train_data_output_path)
|
| 455 |
+
data_dict["test"].save_to_disk(test_data_output_path)
|
| 456 |
+
else:
|
| 457 |
+
data_output_path = (
|
| 458 |
+
Path(output_directory) / f"{output_prefix}_labeled"
|
| 459 |
).with_suffix(".dataset")
|
| 460 |
+
data.save_to_disk(data_output_path)
|
| 461 |
+
print(data_output_path)
|
| 462 |
else:
|
| 463 |
data_output_path = (
|
| 464 |
Path(output_directory) / f"{output_prefix}_labeled"
|
|
|
|
| 1033 |
model = pu.load_model(model_type, num_classes, model_directory, "eval")
|
| 1034 |
|
| 1035 |
# evaluate the model
|
| 1036 |
+
result = self.evaluate_model(
|
| 1037 |
model,
|
| 1038 |
num_classes,
|
| 1039 |
id_class_dict,
|
|
|
|
| 1044 |
)
|
| 1045 |
|
| 1046 |
all_conf_mat_df = pd.DataFrame(
|
| 1047 |
+
result["conf_mat"],
|
| 1048 |
columns=id_class_dict.values(),
|
| 1049 |
index=id_class_dict.values(),
|
| 1050 |
)
|
| 1051 |
all_metrics = {
|
| 1052 |
"conf_matrix": all_conf_mat_df,
|
| 1053 |
+
"macro_f1": result["macro_f1"],
|
| 1054 |
+
"acc": result["acc"],
|
| 1055 |
}
|
| 1056 |
all_roc_metrics = None # roc metrics not reported for multiclass
|
| 1057 |
+
|
| 1058 |
if num_classes == 2:
|
| 1059 |
mean_fpr = np.linspace(0, 1, 100)
|
| 1060 |
+
mean_tpr = result["roc_metrics"]["interp_tpr"]
|
| 1061 |
+
all_roc_auc = result["roc_metrics"]["auc"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1062 |
all_roc_metrics = {
|
| 1063 |
"mean_tpr": mean_tpr,
|
| 1064 |
"mean_fpr": mean_fpr,
|
|
|
|
| 1155 |
|
| 1156 |
predictions_file : path
|
| 1157 |
| Path of model predictions output to plot
|
| 1158 |
+
| (saved output from self.validate if predict_eval=True)
|
| 1159 |
| (or saved output from self.evaluate_saved_model)
|
| 1160 |
id_class_dict_file : Path
|
| 1161 |
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
|
|
|
| 1191 |
predictions_logits = np.array(predictions["predictions"])
|
| 1192 |
true_ids = predictions["label_ids"]
|
| 1193 |
else:
|
| 1194 |
+
# format is output from self.validate if predict_eval=True
|
| 1195 |
predictions_logits = predictions.predictions
|
| 1196 |
true_ids = predictions.label_ids
|
| 1197 |
|
geneformer/evaluation_utils.py
CHANGED
|
@@ -201,10 +201,10 @@ def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix
|
|
| 201 |
plt.ylabel("True Positive Rate")
|
| 202 |
plt.title(title)
|
| 203 |
plt.legend(loc="lower right")
|
| 204 |
-
plt.show()
|
| 205 |
|
| 206 |
output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
|
| 207 |
plt.savefig(output_file, bbox_inches="tight")
|
|
|
|
| 208 |
|
| 209 |
|
| 210 |
# plot confusion matrix
|
|
|
|
| 201 |
plt.ylabel("True Positive Rate")
|
| 202 |
plt.title(title)
|
| 203 |
plt.legend(loc="lower right")
|
|
|
|
| 204 |
|
| 205 |
output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
|
| 206 |
plt.savefig(output_file, bbox_inches="tight")
|
| 207 |
+
plt.show()
|
| 208 |
|
| 209 |
|
| 210 |
# plot confusion matrix
|