hanquansanren committed on
Commit
9847dc4
·
1 Parent(s): d938e9f
Files changed (1) hide show
  1. app.py +1 -130
app.py CHANGED
@@ -180,139 +180,10 @@ def prepare_data_single(input_image, input_image_ori):
180
  return input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
181
 
182
 
183
- @GPU
184
- def run_evaluation_docunet(
185
- settings, logger, val_loader, diffusion: GaussianDiffusion, model,
186
- pretrained_dewarp_model,pretrained_line_seg_model=None,pretrained_seg_model=None
187
- ):
188
- os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
189
- batch_preprocessing = None
190
- pbar = tqdm(enumerate(val_loader), total=len(val_loader))
191
- pyramid = VGGPyramid(train=False).to(dist_util.dev())
192
- SIZE = None
193
-
194
- # for each document image
195
-
196
- for i, data in pbar:
197
- radius = 4
198
- raw_corr = None
199
- data_path = data['path']
200
- source_288 = F.interpolate(data['source_image'], size=(288), mode='bilinear', align_corners=True).to(dist_util.dev())
201
-
202
- if settings.env.time_variant == True:
203
- init_feat = torch.zeros((data['source_image'].shape[0], 256, 64, 64), dtype=torch.float32).to(dist_util.dev())
204
- else:
205
- init_feat = None
206
-
207
-
208
- with torch.inference_mode():
209
- ref_bm, mask_x = pretrained_dewarp_model(source_288) # [1,2,288,288] 0~288 0~1
210
- ref_flow = ref_bm/287.0 # [-1, 1] # [1,2,288,288]
211
- if settings.env.use_init_flow:
212
- init_flow = F.interpolate(ref_flow, size=(64), mode='bilinear', align_corners=True) # [24, 2, 64, 64]
213
- else:
214
- init_flow = torch.zeros((data['source_image'].shape[0], 2, 64, 64), dtype=torch.float32).to(dist_util.dev())
215
-
216
-
217
- (
218
- data,
219
- H_ori, # 512
220
- W_ori, # 512
221
- source, # [1, 3, 512, 512] 0-1
222
- target, # None
223
- batch_ori, # None
224
- batch_ori_inter, # None
225
- source_256,# [1, 3, 256, 256] 0-1
226
- target_256, # None
227
- source_vis, # [1, 3, H, W] cpu仅用于可视化
228
- target_vis, # None
229
- mask, # [1, 512, 512] 全白
230
- source_0
231
- ) = prepare_data(settings, batch_preprocessing, SIZE, data)
232
-
233
-
234
-
235
- with torch.no_grad():
236
- if settings.env.use_gt_mask == False:
237
- # ref_bm, mask_x = self.pretrained_dewarp_model(source_288) # [1,2,288,288] bm 0~288 mskx0-256
238
- mskx, d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = pretrained_seg_model(source_288)
239
- hx6 = F.interpolate(hx6, size=64, mode='bilinear', align_corners=False)
240
- hx5d = F.interpolate(hx5d, size=64, mode='bilinear', align_corners=False)
241
- hx4d = F.interpolate(hx4d, size=64, mode='bilinear', align_corners=False)
242
- hx3d = F.interpolate(hx3d, size=64, mode='bilinear', align_corners=False)
243
- hx2d = F.interpolate(hx2d, size=64, mode='bilinear', align_corners=False)
244
- hx1d = F.interpolate(hx1d, size=64, mode='bilinear', align_corners=False)
245
-
246
- seg_map_all = torch.cat((hx6, hx5d, hx4d, hx3d, hx2d, hx1d), dim=1) # [b, 384, 64, 64]
247
- # tv_save_image(mskx,"vis_hp/debug_vis/mskx.png")
248
- if settings.env.use_line_mask:
249
- textline_map, textline_mask = pretrained_line_seg_model(mskx) # [3, 64, 256, 256]
250
- textline_map = F.interpolate(textline_map, size=64, mode='bilinear', align_corners=False) # [3, 64, 64, 64]
251
- else:
252
- seg_map_all = None
253
- textline_map = None
254
 
255
 
256
- if settings.env.train_VGG:
257
- c20 = None
258
- feature_size = 64
259
- else:
260
- feature_size = 64
261
- if settings.env.train_mode == 'stage_1_dit_cat' or settings.env.train_mode =='stage_1_dit_cross':
262
- with th.no_grad():
263
- c20 = extract_raw_features_single2(pyramid, source, source_256, feature_size) # [24, 1, 64, 64, 64, 64]
264
- # 平均互相关,VGG最浅层特征的下采样(512*512->64*64)
265
- else:
266
- with th.no_grad():
267
- c20 = extract_raw_features_single(pyramid, source, source_256, feature_size) # [24, 1, 64, 64, 64, 64]
268
- # 平均互相关,VGG最浅层特征的下采样(512*512->64*64)
269
 
270
- source_64 = None # F.interpolate(source, size=(feature_size), mode='bilinear', align_corners=True)
271
- logger.info(f"Starting sampling with VGG Features")
272
-
273
- sample = run_sample_lr_dewarping(
274
- settings,
275
- logger,
276
- diffusion,
277
- model,
278
- radius, # 4
279
- source, # [B, 3, 512, 512] 0~1
280
- feature_size, # 64
281
- raw_corr, # None
282
- init_flow, # [B, 2, 64, 64] -1~1
283
- c20, # # [B, 64, 64, 64]
284
- source_64, # None
285
- pyramid,
286
- mask_x, #mask_x, # F.interpolate(mskx, size=(512), mode='bilinear', align_corners=True)[:,:1,:,:] , # mask_x
287
- seg_map_all,
288
- textline_map,
289
- init_feat
290
- ) # sample: [1, 2, 64, 64] 偏移量 [-1,1]范围 五步DDIM的结果
291
-
292
-
293
- if settings.env.use_sr_net == False:
294
- sample = F.interpolate(sample, size=(H_ori, W_ori), mode='bilinear', align_corners=True) # [-1,+1] 偏移场
295
- # sample[:, 0, :, :] = sample[:, 0, :, :] * W_ori
296
- # sample[:, 1, :, :] = sample[:, 1, :, :] * H_ori
297
- base = F.interpolate(coords_grid_tensor((512,512))/511., size=(H_ori, W_ori), mode='bilinear', align_corners=True)
298
- # sample = ( ((sample + base.to(sample.device)) )*2 - 1 )
299
- sample = ( ((sample + base.to(sample.device))*1 )*2 - 1 )*0.987 # (2 * (bm / 286.8) - 1) * 0.99
300
- ref_flow = None
301
- if ref_flow is not None:
302
- ref_flow = F.interpolate(ref_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True) # [-1,+1] 偏移场
303
- # ref_flow[:, 0, :, :] = ref_flow[:, 0, :, :] * W_ori
304
- # ref_flow[:, 1, :, :] = ref_flow[:, 1, :, :] * H_ori
305
- ref_flow = (ref_flow + base.to(ref_flow.device))*2 -1
306
- # init_flow = F.interpolate(init_flow, size=(H_ori, W_ori), mode='bilinear', align_corners=True)
307
- else:
308
- raise ValueError("Invalid value")
309
-
310
-
311
- if settings.env.visualize:
312
- output = visualize_dewarping(settings, sample, data, i, source_vis, data_path, ref_flow)
313
-
314
-
315
-
316
  def run_single_docunet(input_image_ori):
317
  input_image_ori = np.array(input_image_ori, dtype=np.uint8) # [x, y, 3]
318
 
 
180
  return input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
181
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ @GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  def run_single_docunet(input_image_ori):
188
  input_image_ori = np.array(input_image_ori, dtype=np.uint8) # [x, y, 3]
189