Spaces:
Runtime error
Runtime error
fix memory bug
Browse files- inference/pipeline.py +10 -5
inference/pipeline.py
CHANGED
|
@@ -83,11 +83,16 @@ class RealCustomInferencePipeline:
|
|
| 83 |
vision_model_config = unet_config.pop("vision_model_config", None)
|
| 84 |
self.vision_model_config = vision_model_config.pop("vision_model_config", None)
|
| 85 |
|
| 86 |
-
self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
|
| 87 |
-
|
| 88 |
-
self.unet_model.eval().to(self.device).to(self.torch_dtype)
|
| 89 |
-
self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
|
| 90 |
-
self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
print("loading unet model finished.")
|
| 92 |
|
| 93 |
def _reload_unet_checkpoint(self, unet_checkpoint, realcustom_checkpoint):
|
|
|
|
| 83 |
vision_model_config = unet_config.pop("vision_model_config", None)
|
| 84 |
self.vision_model_config = vision_model_config.pop("vision_model_config", None)
|
| 85 |
|
| 86 |
+
# self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
|
| 87 |
+
|
| 88 |
+
# self.unet_model.eval().to(self.device).to(self.torch_dtype)
|
| 89 |
+
# self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
|
| 90 |
+
# self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
|
| 91 |
+
with torch.device("meta"):
|
| 92 |
+
self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
|
| 93 |
+
self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False, assign=True)
|
| 94 |
+
self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False, assign=True)
|
| 95 |
+
self.unet_model.eval()
|
| 96 |
print("loading unet model finished.")
|
| 97 |
|
| 98 |
def _reload_unet_checkpoint(self, unet_checkpoint, realcustom_checkpoint):
|