Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
9847dc4
1
Parent(s):
d938e9f
app.py
CHANGED
|
@@ -180,139 +180,10 @@ def prepare_data_single(input_image, input_image_ori):
|
|
| 180 |
return input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
|
| 181 |
|
| 182 |
|
| 183 |
-
@GPU
def run_evaluation_docunet(
    settings, logger, val_loader, diffusion: GaussianDiffusion, model,
    pretrained_dewarp_model, pretrained_line_seg_model=None, pretrained_seg_model=None
):
    """Dewarp every batch of ``val_loader`` with the diffusion sampler.

    For each batch the function
      1. resizes ``data['source_image']`` to 288 and runs
         ``pretrained_dewarp_model`` to obtain a reference map and a mask,
      2. optionally runs ``pretrained_seg_model`` / ``pretrained_line_seg_model``
         to build 64x64 conditioning maps,
      3. optionally extracts VGG-pyramid features,
      4. calls ``run_sample_lr_dewarping`` and converts the result into a
         sampling grid at the original resolution,
      5. optionally calls ``visualize_dewarping``.

    Args:
        settings: configuration object; reads ``settings.name`` and the
            ``settings.env`` flags ``eval_dataset_name``, ``time_variant``,
            ``use_init_flow``, ``use_gt_mask``, ``use_line_mask``,
            ``train_VGG``, ``train_mode``, ``use_sr_net`` and ``visualize``.
        logger: object exposing ``info``.
        val_loader: sized iterable of dicts with at least the keys
            ``'path'`` and ``'source_image'``.
        diffusion: diffusion process forwarded to the sampler.
        model: denoising model forwarded to the sampler.
        pretrained_dewarp_model: callable returning ``(ref_bm, mask_x)``.
        pretrained_line_seg_model: callable returning
            ``(textline_map, textline_mask)``; only called when
            ``use_gt_mask == False`` and ``use_line_mask`` is truthy.
        pretrained_seg_model: callable returning eight tensors; only called
            when ``use_gt_mask == False``.

    Returns:
        None. The only outputs are the created ``vis_hp/...`` directory and
        whatever ``visualize_dewarping`` does.

    Raises:
        ValueError: if ``settings.env.use_sr_net`` does not compare equal to
            ``False``.
    """
    os.makedirs(f'vis_hp/{settings.env.eval_dataset_name}/{settings.name}', exist_ok=True)
    device = dist_util.dev()
    batch_preprocessing = None
    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    pyramid = VGGPyramid(train=False).to(device)
    SIZE = None

    # Sampler arguments that never change between batches.
    radius = 4
    raw_corr = None
    feature_size = 64
    # The original call site annotated this argument as "None" but never
    # bound the name; bind it explicitly so the call cannot raise NameError.
    # NOTE(review): confirm no module-level ``source_64`` was intended.
    source_64 = None

    # for each document image
    for i, data in pbar:
        data_path = data['path']
        batch_size = data['source_image'].shape[0]
        source_288 = F.interpolate(
            data['source_image'], size=288, mode='bilinear', align_corners=True
        ).to(device)

        # ``== True`` / ``== False`` comparisons are kept as written so that
        # non-bool settings (e.g. None) take the same branch as before.
        if settings.env.time_variant == True:
            init_feat = torch.zeros((batch_size, 256, 64, 64), dtype=torch.float32, device=device)
        else:
            init_feat = None

        with torch.inference_mode():
            # author's note: ref_bm is [1, 2, 288, 288] in 0~288, mask_x in 0~1
            ref_bm, mask_x = pretrained_dewarp_model(source_288)
            # author's note: "[-1, 1]", [1, 2, 288, 288] — TODO confirm range
            ref_flow = ref_bm / 287.0
            if settings.env.use_init_flow:
                init_flow = F.interpolate(ref_flow, size=64, mode='bilinear', align_corners=True)
            else:
                init_flow = torch.zeros((batch_size, 2, 64, 64), dtype=torch.float32, device=device)

        # Shapes below are the original author's notes, not checked here.
        (
            data,
            H_ori,            # 512
            W_ori,            # 512
            source,           # [1, 3, 512, 512], 0-1
            target,           # None
            batch_ori,        # None
            batch_ori_inter,  # None
            source_256,       # [1, 3, 256, 256], 0-1
            target_256,       # None
            source_vis,       # [1, 3, H, W], on CPU, visualization only
            target_vis,       # None
            mask,             # [1, 512, 512], all white
            source_0,
        ) = prepare_data(settings, batch_preprocessing, SIZE, data)

        # Default both conditioning maps to None so every flag combination
        # reaches the sampler with bound names (previously ``textline_map``
        # was unbound when use_gt_mask was False and use_line_mask was falsy).
        seg_map_all = None
        textline_map = None
        with torch.no_grad():
            if settings.env.use_gt_mask == False:
                mskx, d0, hx6, hx5d, hx4d, hx3d, hx2d, hx1d = pretrained_seg_model(source_288)
                # Bring every decoder stage to 64x64 and stack along channels.
                # author's note: result is [b, 384, 64, 64]
                seg_map_all = torch.cat(
                    [
                        F.interpolate(stage, size=64, mode='bilinear', align_corners=False)
                        for stage in (hx6, hx5d, hx4d, hx3d, hx2d, hx1d)
                    ],
                    dim=1,
                )
                if settings.env.use_line_mask:
                    # author's note: [3, 64, 256, 256] -> [3, 64, 64, 64]
                    textline_map, textline_mask = pretrained_line_seg_model(mskx)
                    textline_map = F.interpolate(
                        textline_map, size=64, mode='bilinear', align_corners=False
                    )

        if settings.env.train_VGG:
            c20 = None
        else:
            if settings.env.train_mode == 'stage_1_dit_cat' or settings.env.train_mode == 'stage_1_dit_cross':
                extract_features = extract_raw_features_single2
            else:
                extract_features = extract_raw_features_single
            with th.no_grad():
                # author's note: averaged cross-correlation; down-sampled
                # shallowest VGG features (512*512 -> 64*64),
                # shape [24, 1, 64, 64, 64, 64]
                c20 = extract_features(pyramid, source, source_256, feature_size)

        logger.info("Starting sampling with VGG Features")

        # author's note: sample is [1, 2, 64, 64], offsets in [-1, 1],
        # the result of five DDIM steps
        sample = run_sample_lr_dewarping(
            settings,
            logger,
            diffusion,
            model,
            radius,        # 4
            source,        # [B, 3, 512, 512], 0~1
            feature_size,  # 64
            raw_corr,      # None
            init_flow,     # [B, 2, 64, 64], -1~1
            c20,           # [B, 64, 64, 64]
            source_64,     # None
            pyramid,
            mask_x,
            seg_map_all,
            textline_map,
            init_feat,
        )

        if settings.env.use_sr_net == False:
            # Upsample the offset field to the original resolution.
            sample = F.interpolate(sample, size=(H_ori, W_ori), mode='bilinear', align_corners=True)
            base = F.interpolate(
                coords_grid_tensor((512, 512)) / 511.,
                size=(H_ori, W_ori), mode='bilinear', align_corners=True,
            )
            # offsets + base grid, rescaled by *2 - 1, then shrunk by 0.987
            # (author's note: analogous to (2 * (bm / 286.8) - 1) * 0.99)
            sample = ((sample + base.to(sample.device)) * 2 - 1) * 0.987
            # The reference flow is not forwarded to the visualiser; the
            # original reset it to None right before an
            # ``if ref_flow is not None`` branch that therefore never ran.
            ref_flow = None
        else:
            raise ValueError("Invalid value")

        if settings.env.visualize:
            visualize_dewarping(settings, sample, data, i, source_vis, data_path, ref_flow)
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
def run_single_docunet(input_image_ori):
|
| 317 |
input_image_ori = np.array(input_image_ori, dtype=np.uint8) # [x, y, 3]
|
| 318 |
|
|
|
|
| 180 |
return input_image, H_ori, W_ori, source, target, batch_ori, batch_ori_inter, source_256, target_256, source_vis, target_vis, mask, source_0
|
| 181 |
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
+
@GPU
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
def run_single_docunet(input_image_ori):
|
| 188 |
input_image_ori = np.array(input_image_ori, dtype=np.uint8) # [x, y, 3]
|
| 189 |
|