Upload evaluate_laplace.py
Browse files
SPC-UQ/Image_Classification/evaluate_laplace.py
CHANGED
|
@@ -273,6 +273,7 @@ if __name__ == "__main__":
|
|
| 273 |
adv_y_prob = torch.cat(adv_y_prob, dim=0)
|
| 274 |
adv_y_true = torch.cat(adv_y_true, dim=0).numpy()
|
| 275 |
_, adv_predictions = torch.max(adv_y_prob, 1)
|
|
|
|
| 276 |
adv_accuracy = (adv_y_true == adv_predictions).mean()
|
| 277 |
adv_confidences = adv_y_prob.max(dim=1)[0].numpy()
|
| 278 |
bin_labels = np.concatenate([
|
|
|
|
| 273 |
adv_y_prob = torch.cat(adv_y_prob, dim=0)
|
| 274 |
adv_y_true = torch.cat(adv_y_true, dim=0).numpy()
|
| 275 |
_, adv_predictions = torch.max(adv_y_prob, 1)
|
| 276 |
+
adv_predictions=adv_predictions.detach().cpu().numpy()
|
| 277 |
adv_accuracy = (adv_y_true == adv_predictions).mean()
|
| 278 |
adv_confidences = adv_y_prob.max(dim=1)[0].numpy()
|
| 279 |
bin_labels = np.concatenate([
|