zzz0527 committed on
Commit de29411 · verified · 1 Parent(s): a94b2ae

Upload spc.py

SPC-UQ/MNIST_Classification/trainers/spc.py CHANGED
@@ -20,6 +20,9 @@ class SPC:
         self.model = model.to(device)
         self.criterion = nn.CrossEntropyLoss()
         self.optimizer = self._init_optimizer(optimizer_type, learning_rate)
+        self.optimizer_cls = self._init_optimizer_cls(optimizer_type, learning_rate)
+        self.optimizer_mar = self._init_optimizer_mar(optimizer_type, learning_rate)
+

         self.epoch = 0
         self.max_acc = -1
@@ -35,6 +38,22 @@ class SPC:
         else:
             raise ValueError(f"Unsupported optimizer type: {optimizer_type}")

+    def _init_optimizer_cls(self, optimizer_type, lr):
+        if optimizer_type.upper() == 'ADAM':
+            return optim.Adam(list(self.model.hidden.parameters()) + list(self.model.hidden_pred.parameters()) + list(self.model.output_pred.parameters()), lr=lr)
+        elif optimizer_type.upper() == 'SGD':
+            return optim.SGD(list(self.model.hidden.parameters()) + list(self.model.hidden_pred.parameters()) + list(self.model.output_pred.parameters()), lr=0.01, momentum=0.9, weight_decay=5e-4)
+        else:
+            raise ValueError(f"Unsupported optimizer type: {optimizer_type}")
+
+    def _init_optimizer_mar(self, optimizer_type, lr):
+        if optimizer_type.upper() == 'ADAM':
+            return optim.Adam(list(self.model.hidden_mar.parameters()) + list(self.model.output_mar.parameters()), lr=lr)
+        elif optimizer_type.upper() == 'SGD':
+            return optim.SGD(list(self.model.hidden_mar.parameters()) + list(self.model.output_mar.parameters()), lr=0.01, momentum=0.9, weight_decay=5e-4)
+        else:
+            raise ValueError(f"Unsupported optimizer type: {optimizer_type}")
+
     def cls_loss(self, target, logits):
         loss_cls = self.criterion(logits, target)
         return loss_cls
@@ -50,7 +69,7 @@ class SPC:
         loss_mar = F.mse_loss(mar, mar_target)
         return loss_mar

-    def run_train_step(self, data, target, num_classes):
+    def joint_train_step(self, data, target, num_classes):
         """Single training step."""
         self.model.train()
         logits, mar = self.model(data)
@@ -62,6 +81,28 @@ class SPC:
         self.optimizer.step()
         return loss.item(), mar

+    def cls_train_step(self, data, target):
+        """Single training step."""
+        self.model.train()
+        logits, mar = self.model(data)
+        loss_cls = self.cls_loss(target, logits)
+        loss = loss_cls
+        self.optimizer_cls.zero_grad()
+        loss.backward()
+        self.optimizer_cls.step()
+        return loss.item(), mar
+
+    def mar_train_step(self, data, target, num_classes):
+        """Single training step."""
+        self.model.train()
+        logits, mar = self.model(data)
+        loss_mar = self.mar_loss(target, num_classes, logits, mar)
+        loss = loss_mar
+        self.optimizer_mar.zero_grad()
+        loss.backward()
+        self.optimizer_mar.step()
+        return loss.item(), mar
+
     def evaluate(self, test_loader, ood_loader, num_classes, threshold=None):
         """
         Evaluate classification accuracy and uncertainty quantification:
@@ -73,6 +114,7 @@ class SPC:
         y_all, logits_all, mar_all = [], [], []

         with torch.no_grad():
+            start_time = time.time()
             for data, target in test_loader:
                 data, target = data.to(device), target.to(device)
                 logits, mar = self.model(data)
@@ -88,6 +130,7 @@ class SPC:
         expected_uncertainty = 2 * logits * (1 - logits)
         uncertainty = torch.sum(torch.abs(mar - expected_uncertainty), dim=1)
         confidences, predictions = torch.max(logits, dim=1)
+        test_time=(time.time() - start_time) * 1000

         threshold = np.quantile(uncertainty.cpu().numpy(), 0.5)
         acc = (predictions == y_all).float().mean().item()
@@ -119,10 +162,10 @@ class SPC:
         all_uncertainty = torch.cat([uncertainty, uncertainty_ood])
         auroc = metrics.roc_auc_score(bin_labels.cpu().numpy(), all_uncertainty.cpu().numpy())

-        return acc, acc_confident, acc_uncertain, auroc, 1000 * (len(test_loader.dataset) / len(test_loader))
+        return acc, acc_confident, acc_uncertain, auroc, test_time / len(test_loader)

     def train(self, train_dataset, test_dataset, ood_dataset, num_classes,
-              batch_size=128, num_epochs=40, verbose=True, freq=1):
+              batch_size=128, num_epochs=40, verbose=True, freq=1, joint_training=0):
         """
         Train classifier with SPC-based uncertainty, and evaluate per epoch.
         """
@@ -134,38 +177,84 @@ class SPC:
         acc_curve, acc_conf_curve, acc_unc_curve = [], [], []
         train_times, test_times = [], []

-        for self.epoch in range(1, num_epochs + 1):
-            total_loss, total_unc = 0.0, 0.0
-            start_time = time.time()
-
-            for data, target in train_loader:
-                data, target = data.to(device), target.to(device)
-                loss, mar = self.run_train_step(data, target, num_classes)
-                total_loss += loss
-
-                if verbose:
-                    with torch.no_grad():
-                        logits, mar = self.model(data)
-                        prob = F.softmax(logits, dim=1)
-                        mar_target = 2 * prob * (1 - prob)
-                        uncertainty = torch.sum(torch.abs(mar - mar_target), dim=1)
-                        total_unc += uncertainty.mean().item()
-
-            train_times.append((time.time() - start_time) * 1000)
-            avg_loss = total_loss / len(train_loader)
-            avg_uncertainty = total_unc / len(train_loader)
-
-            loss_curve.append(avg_loss)
-            uncertainty_curve.append(avg_uncertainty)
-
-            if self.epoch % freq == 0:
-                acc, acc_conf, acc_unc, auroc, test_time = self.evaluate(test_loader, ood_loader, num_classes)
-                acc_curve.append(acc)
-                acc_conf_curve.append(acc_conf)
-                acc_unc_curve.append(acc_unc)
-                test_times.append(test_time)
-
-                if acc > self.max_acc:
+        if joint_training:
+            for self.epoch in range(1, num_epochs + 1):
+                total_loss, total_unc = 0.0, 0.0
+                start_time = time.time()
+
+                for data, target in train_loader:
+                    data, target = data.to(device), target.to(device)
+                    loss, mar = self.joint_train_step(data, target, num_classes)
+                    total_loss += loss
+
+                    if verbose:
+                        with torch.no_grad():
+                            logits, mar = self.model(data)
+                            prob = F.softmax(logits, dim=1)
+                            mar_target = 2 * prob * (1 - prob)
+                            uncertainty = torch.sum(torch.abs(mar - mar_target), dim=1)
+                            total_unc += uncertainty.mean().item()
+
+                train_times.append((time.time() - start_time) * 1000)
+                avg_loss = total_loss / len(train_loader)
+                avg_uncertainty = total_unc / len(train_loader)
+
+                loss_curve.append(avg_loss)
+                uncertainty_curve.append(avg_uncertainty)
+
+                if self.epoch % freq == 0:
+                    acc, acc_conf, acc_unc, auroc, test_time = self.evaluate(test_loader, ood_loader, num_classes)
+                    acc_curve.append(acc)
+                    acc_conf_curve.append(acc_conf)
+                    acc_unc_curve.append(acc_unc)
+                    test_times.append(test_time)
+
+                    if acc > self.max_acc:
+                        self.max_acc = acc
+                        self.max_acc_confident = acc_conf
+                        self.max_acc_uncertain = acc_unc
+                        self.max_auroc = auroc
+
+        else:
+            for self.epoch in range(1, num_epochs + 1):
+                total_loss = 0.0
+
+                for data, target in train_loader:
+                    data, target = data.to(device), target.to(device)
+                    loss, mar = self.cls_train_step(data, target)
+                    total_loss += loss
+
+                avg_loss = total_loss / len(train_loader)
+                loss_curve.append(avg_loss)
+
+            for self.epoch in range(1, num_epochs + 1):
+                total_unc = 0.0
+                start_time = time.time()
+
+                for data, target in train_loader:
+                    data, target = data.to(device), target.to(device)
+                    loss, mar = self.mar_train_step(data, target, num_classes)
+
+                    if verbose:
+                        with torch.no_grad():
+                            logits, mar = self.model(data)
+                            prob = F.softmax(logits, dim=1)
+                            mar_target = 2 * prob * (1 - prob)
+                            uncertainty = torch.sum(torch.abs(mar - mar_target), dim=1)
+                            total_unc += uncertainty.mean().item()
+
+                train_times.append((time.time() - start_time) * 1000)
+                avg_uncertainty = total_unc / len(train_loader)
+
+                uncertainty_curve.append(avg_uncertainty)
+
+                if self.epoch % freq == 0:
+                    acc, acc_conf, acc_unc, auroc, test_time = self.evaluate(test_loader, ood_loader, num_classes)
+                    acc_curve.append(acc)
+                    acc_conf_curve.append(acc_conf)
+                    acc_unc_curve.append(acc_unc)
+                    test_times.append(test_time)
+
                     self.max_acc = acc
                     self.max_acc_confident = acc_conf
                     self.max_acc_uncertain = acc_unc
@@ -193,7 +282,7 @@ class SPC:
         plt.grid(True); plt.legend(); plt.show()

         plt.figure(figsize=(10, 6))
-        plt.plot(acc, label='Test Accuracy')
+        plt.plot(acc, label='Overall Accuracy')
         plt.plot(acc_cer, label='Confident Accuracy')
         plt.plot(acc_unc, label='Uncertain Accuracy')
         plt.title('Test Accuracy Curves')
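
For reference, the uncertainty score that both the training loops and evaluate() rely on is the per-class deviation of the MAR head from 2·p·(1−p), summed over classes; the median of that score splits predictions into confident and uncertain subsets, and the same score drives the OOD AUROC. The sketch below is a minimal, self-contained restatement of that computation and is not part of the commit; it assumes probabilities come from a softmax as in the training steps (evaluate() applies the same formula directly to its already-normalized logits), and the toy tensors are illustrative only.

import torch
import torch.nn.functional as F

def spc_uncertainty(logits, mar):
    """Sum over classes of |mar - 2*p*(1-p)|, the score used in evaluate()."""
    prob = F.softmax(logits, dim=1)      # predicted class probabilities (assumed normalization)
    expected = 2 * prob * (1 - prob)     # target the MAR head is regressed toward during training
    return torch.sum(torch.abs(mar - expected), dim=1)

# Toy illustration (batch of 4, 10 classes); values are random, not model outputs.
logits = torch.randn(4, 10)
mar = torch.rand(4, 10)
unc = spc_uncertainty(logits, mar)

# As in evaluate(): split at the median uncertainty into confident / uncertain predictions.
threshold = unc.median()
confident_mask = unc <= threshold
print(unc.tolist(), confident_mask.tolist())

Note that under the two-stage path added in this commit (joint_training falsy), the MAR head is only fitted in the second epoch loop, so this score becomes informative after that stage; the joint path optimizes the classification and MAR losses together from the first epoch.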