# for contextual attention
import torch


def extract_image_patches(images, ksizes, strides, rates, padding="same"):
    """
    Extract patches from images and put them in the C output dimension.
    :param images: A 4-D tensor with shape [batch, channels, in_rows, in_cols].
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window
        for each dimension of images.
    :param strides: [stride_rows, stride_cols].
    :param rates: [dilation_rows, dilation_cols].
    :param padding: "same" zero-pads the input so that every input position
        yields a patch; "valid" extracts patches from the input as-is.
    :return: A 3-D tensor of shape [N, C*k*k, L], where L is the total
        number of extracted blocks.
    """
    assert len(images.size()) == 4
    if padding == "same":
        images = same_padding(images, ksizes, strides, rates)
    elif padding == "valid":
        pass
    else:
        raise NotImplementedError(
            'Unsupported padding type: {}. '
            'Only "same" or "valid" are supported.'.format(padding)
        )
    unfold = torch.nn.Unfold(
        kernel_size=ksizes, dilation=rates, padding=0, stride=strides
    )
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks
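
# Usage sketch (assumed sizes, not part of the original file): a quick
# shape check with a 6x6 RGB image, a 3x3 window, stride 1, no dilation.
demo = torch.randn(1, 3, 6, 6)
demo_patches = extract_image_patches(
    demo, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding="valid"
)
# "valid": (6 - 3) + 1 = 4 positions per dimension and 3 * 3 * 3 = 27
# values per patch, so the result is [1, 27, 16]. With padding="same" the
# input would first be padded so every position yields a patch ([1, 27, 36]).
print(demo_patches.shape)  # torch.Size([1, 27, 16])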


def same_padding(images, ksizes, strides, rates):
    """Zero-pad a [batch, channel, rows, cols] tensor so that a strided,
    dilated sliding window covers every input position, mirroring
    TensorFlow's SAME padding."""
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
    # Split the total padding between the two sides; any odd pixel goes to
    # the bottom/right, as in TensorFlow's SAME padding.
    padding_top = padding_rows // 2
    padding_left = padding_cols // 2
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images
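
# Usage sketch (assumed sizes): SAME padding guarantees ceil(size / stride)
# window positions per dimension. For a 5x5 input with a 3x3 kernel and
# stride 2, two rows/cols of zeros are added (giving 7x7), so the window
# fits (7 - 3) // 2 + 1 = 3 times instead of 2.
demo_pad = same_padding(
    torch.randn(1, 1, 5, 5), ksizes=[3, 3], strides=[2, 2], rates=[1, 1]
)
print(demo_pad.shape)  # torch.Size([1, 1, 7, 7])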


def reduce_mean(x, axis=None, keepdim=False):
    """Mean over the given axes (all axes if None), TensorFlow-style."""
    if axis is None:  # `is None`, not `not axis`: axis=0 must still work
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    # Reduce the highest dimensions first so the remaining indices stay valid.
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x


def reduce_std(x, axis=None, keepdim=False):
    """Standard deviation over the given axes (all axes if None).
    Note: reducing one axis at a time is not the same as a joint std
    over all the given axes."""
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x


def reduce_sum(x, axis=None, keepdim=False):
    """Sum over the given axes (all axes if None), TensorFlow-style."""
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
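
# Usage sketch (assumed values) for the reduce_* helpers.
demo_x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
print(reduce_sum(demo_x).item())                        # 276.0 (all axes)
print(reduce_mean(demo_x, axis=[0, 2]).shape)           # torch.Size([3])
print(reduce_mean(demo_x, axis=1, keepdim=True).shape)  # torch.Size([2, 1, 4])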