Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update optimization_utils.py
Browse files- optimization_utils.py +9 -0
    	
        optimization_utils.py
    CHANGED
    
    | @@ -96,3 +96,12 @@ def capture_component_call( | |
| 96 | 
             
                    except CapturedCallException as e:
         | 
| 97 | 
             
                        captured_call.args = e.args
         | 
| 98 | 
             
                        captured_call.kwargs = e.kwargs
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 96 | 
             
                    except CapturedCallException as e:
         | 
| 97 | 
             
                        captured_call.args = e.args
         | 
| 98 | 
             
                        captured_call.kwargs = e.kwargs
         | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
            def drain_module_parameters(module: torch.nn.Module):
         | 
| 102 | 
            +
                state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
         | 
| 103 | 
            +
                state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
         | 
| 104 | 
            +
                module.load_state_dict(state_dict, assign=True)
         | 
| 105 | 
            +
                for name, param in state_dict.items():
         | 
| 106 | 
            +
                    meta = state_dict_meta[name]
         | 
| 107 | 
            +
                    param.data = torch.Tensor([]).to(**meta)
         | 
 
			

