Commit 
							
							·
						
						f5c9414
	
1
								Parent(s):
							
							13a8aa5
								
Update base/models/unet.py
Browse files- base/models/unet.py +15 -15
 
    	
        base/models/unet.py
    CHANGED
    
    | 
         @@ -569,21 +569,21 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): 
     | 
|
| 569 | 
         | 
| 570 | 
         
             
                    model = cls.from_config(config)
         
     | 
| 571 | 
         
             
                    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         
     | 
| 572 | 
         
            -
                    if not os.path.isfile(model_file):
         
     | 
| 573 | 
         
            -
             
     | 
| 574 | 
         
            -
                    state_dict = torch.load(model_file, map_location="cpu")
         
     | 
| 575 | 
         
            -
                    for k, v in model.state_dict().items():
         
     | 
| 576 | 
         
            -
             
     | 
| 577 | 
         
            -
             
     | 
| 578 | 
         
            -
             
     | 
| 579 | 
         
            -
             
     | 
| 580 | 
         
            -
             
     | 
| 581 | 
         
            -
             
     | 
| 582 | 
         
            -
             
     | 
| 583 | 
         
            -
             
     | 
| 584 | 
         
            -
             
     | 
| 585 | 
         
            -
             
     | 
| 586 | 
         
            -
                    model.load_state_dict(state_dict)
         
     | 
| 587 | 
         | 
| 588 | 
         
             
                    return model
         
     | 
| 589 | 
         | 
| 
         | 
|
| 569 | 
         | 
| 570 | 
         
             
                    model = cls.from_config(config)
         
     | 
| 571 | 
         
             
                    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         
     | 
| 572 | 
         
            +
                    # if not os.path.isfile(model_file):
         
     | 
| 573 | 
         
            +
                    #     raise RuntimeError(f"{model_file} does not exist")
         
     | 
| 574 | 
         
            +
                    # state_dict = torch.load(model_file, map_location="cpu")
         
     | 
| 575 | 
         
            +
                    # for k, v in model.state_dict().items():
         
     | 
| 576 | 
         
            +
                    #     # print(k)
         
     | 
| 577 | 
         
            +
                    #     if '_temp' in k:
         
     | 
| 578 | 
         
            +
                    #         state_dict.update({k: v})
         
     | 
| 579 | 
         
            +
                    #     if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
         
     | 
| 580 | 
         
            +
                    #         k = k.replace('attn_fcross', 'attn1')
         
     | 
| 581 | 
         
            +
                    #         state_dict.update({k: state_dict[k]})
         
     | 
| 582 | 
         
            +
                    #     if 'norm_fcross' in k:
         
     | 
| 583 | 
         
            +
                    #         k = k.replace('norm_fcross', 'norm1')
         
     | 
| 584 | 
         
            +
                    #         state_dict.update({k: state_dict[k]})
         
     | 
| 585 | 
         
            +
             
     | 
| 586 | 
         
            +
                    # model.load_state_dict(state_dict)
         
     | 
| 587 | 
         | 
| 588 | 
         
             
                    return model
         
     | 
| 589 | 
         |