Spaces:
Build error
Build error
edits
Browse files
app.py
CHANGED
|
@@ -29,10 +29,18 @@ state = torch.load('fire.pth', map_location='cpu')
|
|
| 29 |
state['net_params']['pretrained'] = None # no need for imagenet pretrained model
|
| 30 |
net_sfm = fire_network.init_network(**state['net_params']).to(device)
|
| 31 |
net_sfm.load_state_dict(state['state_dict'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
state2 = torch.load('fire_imagenet.pth', map_location='cpu')
|
|
|
|
|
|
|
|
|
|
| 34 |
net_imagenet = fire_network.init_network(**state['net_params']).to(device)
|
| 35 |
-
net_imagenet.load_state_dict(state2['state_dict'])
|
| 36 |
|
| 37 |
# ---------------------------------------
|
| 38 |
transform = transforms.Compose([
|
|
|
|
| 29 |
state['net_params']['pretrained'] = None # no need for imagenet pretrained model
|
| 30 |
net_sfm = fire_network.init_network(**state['net_params']).to(device)
|
| 31 |
net_sfm.load_state_dict(state['state_dict'])
|
| 32 |
+
dim_red_params_dict = {}
|
| 33 |
+
for name, param in net_sfm.named_parameters():
|
| 34 |
+
if 'dim_reduction' in name:
|
| 35 |
+
dim_red_params_dict[name] = param
|
| 36 |
+
|
| 37 |
|
| 38 |
state2 = torch.load('fire_imagenet.pth', map_location='cpu')
|
| 39 |
+
state2['net_params'] = state['net_params']
|
| 40 |
+
state2['state_dict'] += dim_red_params_dict
|
| 41 |
+
# state2['net_params'] =
|
| 42 |
net_imagenet = fire_network.init_network(**state['net_params']).to(device)
|
| 43 |
+
net_imagenet.load_state_dict(state2['state_dict']) #, strict=False)
|
| 44 |
|
| 45 |
# ---------------------------------------
|
| 46 |
transform = transforms.Compose([
|