zzz0527 commited on
Commit
2e742e8
·
verified ·
1 Parent(s): 036f936

Upload evaluate_laplace.py

Browse files
SPC-UQ/Image_Classification/evaluate_laplace.py CHANGED
@@ -273,6 +273,7 @@ if __name__ == "__main__":
273
  adv_y_prob = torch.cat(adv_y_prob, dim=0)
274
  adv_y_true = torch.cat(adv_y_true, dim=0).numpy()
275
  _, adv_predictions = torch.max(adv_y_prob, 1)
 
276
  adv_accuracy = (adv_y_true == adv_predictions).mean()
277
  adv_confidences = adv_y_prob.max(dim=1)[0].numpy()
278
  bin_labels = np.concatenate([
 
273
  adv_y_prob = torch.cat(adv_y_prob, dim=0)
274
  adv_y_true = torch.cat(adv_y_true, dim=0).numpy()
275
  _, adv_predictions = torch.max(adv_y_prob, 1)
276
+ adv_predictions=adv_predictions.detach().cpu().numpy()
277
  adv_accuracy = (adv_y_true == adv_predictions).mean()
278
  adv_confidences = adv_y_prob.max(dim=1)[0].numpy()
279
  bin_labels = np.concatenate([