|
|
from utils.distributed import is_main_process, get_rank, get_world_size |
|
|
import io |
|
|
import json |
|
|
import re |
|
|
import numpy as np |
|
|
from os.path import join |
|
|
from tqdm import trange |
|
|
from PIL import Image |
|
|
from PIL import ImageFile |
|
|
from torchvision.transforms import PILToTensor |
|
|
# Let Pillow decode images whose data stream ends early instead of raising,
# so a single truncated file does not abort loading.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Remove Pillow's pixel-count limit so very large images can be opened.
# NOTE(review): this also disables the decompression-bomb check; only safe
# when the image sources are trusted.
Image.MAX_IMAGE_PIXELS = None
|
|
|
|
|
|
|
|
def load_image_from_path(image_path, client):
    """Load an image as an RGB tensor with a leading dimension of size 1.

    Args:
        image_path: Location of the image. Paths beginning with ``'s3'`` or
            ``'p2'`` are fetched through ``client``; anything else is handed
            to Pillow as a local file path.
        client: Object exposing ``Get(path)``. It is only touched for
            ``'s3'``/``'p2'`` paths. Presumably ``Get`` returns a bytes-like
            object holding the encoded image -- confirm against the storage
            client actually in use.

    Returns:
        The output of ``PILToTensor()`` applied to the RGB image, with
        ``unsqueeze(0)`` adding a leading dimension of size 1.
    """
    if image_path.startswith(('s3', 'p2')):
        value = client.Get(image_path)
        # Wrap the raw payload in a file-like buffer so Pillow can decode it.
        img_bytes = np.frombuffer(value, dtype=np.uint8)
        source = io.BytesIO(img_bytes)
    else:
        source = image_path

    # Image.open is lazy; the context manager guarantees the underlying file
    # (or buffer) is released once the pixel data has been converted.
    with Image.open(source) as raw_image:
        image = raw_image.convert('RGB')

    return PILToTensor()(image).unsqueeze(0)
|
|
|
|
|
def pre_text(text, max_l=None, pre_text=True):
    """Normalize a caption string and optionally truncate it by word count.

    When ``pre_text`` is true the text is processed in this order:

    1. Lower-cased, with the characters ``, . ' ! ? " ( ) * # : ; ~`` removed.
    2. ``-`` and ``/`` are replaced by a space and ``<person>`` by ``person``.
    3. Runs of two or more whitespace characters collapse to one space
       (a single tab or newline inside the text is left untouched).
    4. Trailing newlines, then leading/trailing spaces, are stripped.
    5. If ``max_l`` is truthy, only the first ``max_l`` space-separated
       words are kept.

    Args:
        text: The string to clean.
        max_l: Maximum number of words to keep; ``None`` or ``0`` disables
            truncation.
        pre_text: When false, ``text`` is returned unchanged (``max_l`` is
            ignored as well). Note that this parameter shadows the function
            name inside the body.

    Returns:
        The cleaned (and possibly truncated) string, or the original ``text``
        when ``pre_text`` is false.
    """
    # Guard clause: cleaning disabled, hand the input back untouched.
    if not pre_text:
        return text

    text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
    text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

    # The replacements above can leave adjacent spaces; squeeze them now.
    text = re.sub(r"\s{2,}", ' ', text)
    text = text.rstrip('\n').strip(' ')

    if max_l:
        words = text.split(' ')
        if len(words) > max_l:
            text = ' '.join(words[:max_l])

    return text
|
|
|
|
|
|