Commit
·
f5c9414
1
Parent(s):
13a8aa5
Update base/models/unet.py
Browse files- base/models/unet.py +15 -15
base/models/unet.py
CHANGED
|
@@ -569,21 +569,21 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|
| 569 |
|
| 570 |
model = cls.from_config(config)
|
| 571 |
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 572 |
-
if not os.path.isfile(model_file):
|
| 573 |
-
|
| 574 |
-
state_dict = torch.load(model_file, map_location="cpu")
|
| 575 |
-
for k, v in model.state_dict().items():
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
model.load_state_dict(state_dict)
|
| 587 |
|
| 588 |
return model
|
| 589 |
|
|
|
|
| 569 |
|
| 570 |
model = cls.from_config(config)
|
| 571 |
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
| 572 |
+
# if not os.path.isfile(model_file):
|
| 573 |
+
# raise RuntimeError(f"{model_file} does not exist")
|
| 574 |
+
# state_dict = torch.load(model_file, map_location="cpu")
|
| 575 |
+
# for k, v in model.state_dict().items():
|
| 576 |
+
# # print(k)
|
| 577 |
+
# if '_temp' in k:
|
| 578 |
+
# state_dict.update({k: v})
|
| 579 |
+
# if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
| 580 |
+
# k = k.replace('attn_fcross', 'attn1')
|
| 581 |
+
# state_dict.update({k: state_dict[k]})
|
| 582 |
+
# if 'norm_fcross' in k:
|
| 583 |
+
# k = k.replace('norm_fcross', 'norm1')
|
| 584 |
+
# state_dict.update({k: state_dict[k]})
|
| 585 |
+
|
| 586 |
+
# model.load_state_dict(state_dict)
|
| 587 |
|
| 588 |
return model
|
| 589 |
|