Spaces:
Running
on
Zero
Running
on
Zero
Update optimization_utils.py
Browse files- optimization_utils.py +9 -0
optimization_utils.py
CHANGED
|
@@ -96,3 +96,12 @@ def capture_component_call(
|
|
| 96 |
except CapturedCallException as e:
|
| 97 |
captured_call.args = e.args
|
| 98 |
captured_call.kwargs = e.kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
except CapturedCallException as e:
|
| 97 |
captured_call.args = e.args
|
| 98 |
captured_call.kwargs = e.kwargs
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def drain_module_parameters(module: torch.nn.Module):
|
| 102 |
+
state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
|
| 103 |
+
state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
|
| 104 |
+
module.load_state_dict(state_dict, assign=True)
|
| 105 |
+
for name, param in state_dict.items():
|
| 106 |
+
meta = state_dict_meta[name]
|
| 107 |
+
param.data = torch.Tensor([]).to(**meta)
|