# Standard library imports
import copy
import json
import random
from io import BytesIO
from statistics import mean
from typing import Any, Mapping

# Third-party imports
import numpy as np
import open_clip
import requests
import torch
from PIL import Image
from sentence_transformers.util import (semantic_search,
                                        dot_score,
                                        normalize_embeddings)


def nn_project(curr_embeds, embedding_layer, print_hits=False):
    """Project each soft prompt embedding onto its nearest token embedding."""
    with torch.no_grad():
        bsz, seq_len, emb_dim = curr_embeds.shape

        # Flatten to (bsz * seq_len, emb_dim) and normalize so that the dot
        # product used below is equivalent to cosine similarity.
        curr_embeds = curr_embeds.reshape((-1, emb_dim))
        curr_embeds = normalize_embeddings(curr_embeds)

        embedding_matrix = embedding_layer.weight
        embedding_matrix = normalize_embeddings(embedding_matrix)

        # Nearest-neighbour search over the vocabulary embeddings.
        hits = semantic_search(curr_embeds, embedding_matrix,
                               query_chunk_size=curr_embeds.shape[0],
                               top_k=1,
                               score_function=dot_score)

        if print_hits:
            all_hits = []
            for hit in hits:
                all_hits.append(hit[0]["score"])
            print(f"mean hits:{mean(all_hits)}")

        nn_indices = torch.tensor([hit[0]["corpus_id"] for hit in hits], device=curr_embeds.device)
        nn_indices = nn_indices.reshape((bsz, seq_len))

        projected_embeds = embedding_layer(nn_indices)

    return projected_embeds, nn_indices


def decode_ids(input_ids, tokenizer, by_token=False):
    """Decode token ids back to text, optionally token-by-token joined with '|'."""
    input_ids = input_ids.detach().cpu().numpy()

    texts = []

    if by_token:
        for input_ids_i in input_ids:
            curr_text = []
            for tmp in input_ids_i:
                curr_text.append(tokenizer.decode([tmp]))

            texts.append('|'.join(curr_text))
    else:
        for input_ids_i in input_ids:
            texts.append(tokenizer.decode(input_ids_i))

    return texts


def get_target_feature(model, preprocess, tokenizer_funct, device, target_images=None, target_prompts=None):
    """Encode the optimization targets: CLIP image features if images are given, else text features."""
    if target_images is not None:
        with torch.no_grad():
            curr_images = [preprocess(i).unsqueeze(0) for i in target_images]
            curr_images = torch.concatenate(curr_images).to(device)
            all_target_features = model.encode_image(curr_images)
    else:
        texts = tokenizer_funct(target_prompts).to(device)
        all_target_features = model.encode_text(texts)

    return all_target_features


def encode_text_embedding(model, text_embedding, ids, avg_text=False):
    """Run CLIP's text transformer on pre-computed token embeddings instead of token ids."""
    cast_dtype = model.transformer.get_cast_dtype()

    x = text_embedding + model.positional_embedding.to(cast_dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = model.transformer(x, attn_mask=model.attn_mask)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = model.ln_final(x)

    if avg_text:
        # Average the token features (excluding the start/end tokens) instead
        # of taking only the feature at the end-of-text position.
        x = x[torch.arange(x.shape[0]), :ids.argmax(dim=-1)]
        x = x[:, 1:-1]
        x = x.mean(dim=1) @ model.text_projection
    else:
        # Take the feature at the end-of-text token (the highest id in each sequence).
        x = x[torch.arange(x.shape[0]), ids.argmax(dim=-1)] @ model.text_projection

    return x


def forward_text_embedding(model, embeddings, ids, image_features, avg_text=False, return_feature=False):
    """Compute image-text similarity logits from prompt embeddings and image features."""
    text_features = encode_text_embedding(model, embeddings, ids, avg_text=avg_text)

    if return_feature:
        return text_features

    # Normalize both sides so the dot products below are cosine similarities.
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    logits_per_image = image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    return logits_per_image, logits_per_text


def initialize_prompt(tokenizer, token_embedding, args, device, original_prompt):
    """Build the learnable prompt embeddings (seeded from the original prompt) and the fixed dummy template."""
    prompt_len = args["prompt_len"]

    # Seed the learnable prompt with the tokens of the original prompt,
    # truncated or zero-padded to exactly prompt_len tokens.
    tokens = tokenizer.encode(original_prompt)
    if len(tokens) > prompt_len:
        tokens = tokens[:prompt_len]
    if len(tokens) < prompt_len:
        tokens += [0] * (prompt_len - len(tokens))

    prompt_ids = torch.tensor([tokens] * args["prompt_bs"]).to(device)

    prompt_embeds = token_embedding(prompt_ids).detach()
    prompt_embeds.requires_grad = True

    # Build a 77-token dummy context where the learnable positions are marked with -1.
    template_text = "{}"
    padded_template_text = template_text.format(" ".join(["<start_of_text>"] * prompt_len))
    dummy_ids = tokenizer.encode(padded_template_text)

    # 49406 / 49407 are CLIP's <start_of_text> / <end_of_text> ids.
    dummy_ids = [i if i != 49406 else -1 for i in dummy_ids]
    dummy_ids = [49406] + dummy_ids + [49407]
    dummy_ids += [0] * (77 - len(dummy_ids))
    dummy_ids = torch.tensor([dummy_ids] * args["prompt_bs"]).to(device)

    # Embeddings for the fixed (non-learnable) positions; the -1 markers are
    # temporarily mapped to id 0 just to look up placeholder embeddings.
    tmp_dummy_ids = copy.deepcopy(dummy_ids)
    tmp_dummy_ids[tmp_dummy_ids == -1] = 0
    dummy_embeds = token_embedding(tmp_dummy_ids).detach()
    dummy_embeds.requires_grad = False

    return prompt_embeds, dummy_embeds, dummy_ids


def optimize_prompt_loop(model, tokenizer, token_embedding, all_target_features, args, device, original_prompt):
    """Optimize a hard prompt with projected gradient descent against the target CLIP features."""
    opt_iters = args["iter"]
    lr = args["lr"]
    weight_decay = args["weight_decay"]
    print_step = args["print_step"]
    batch_size = args["batch_size"]
    print_new_best = True

    # Initialize the learnable prompt and the fixed template it is pasted into.
    prompt_embeds, dummy_embeds, dummy_ids = initialize_prompt(tokenizer, token_embedding, args, device, original_prompt)
    p_bs, p_len, p_dim = prompt_embeds.shape

    input_optimizer = torch.optim.AdamW([prompt_embeds], lr=lr, weight_decay=weight_decay)

    best_sim = -1000 * args["loss_weight"]
    best_text = ""

    for step in range(opt_iters):
        # Optionally sub-sample the target features for this step.
        if batch_size is None:
            target_features = all_target_features
        else:
            curr_indx = torch.randperm(len(all_target_features))
            target_features = all_target_features[curr_indx][0:batch_size]

        universal_target_features = all_target_features

        # Project the soft prompt onto its nearest token embeddings (the hard prompt).
        projected_embeds, nn_indices = nn_project(prompt_embeds, token_embedding, print_hits=False)

        # Score the projected (discrete) prompt against all targets.
        with torch.no_grad():
            padded_embeds = dummy_embeds.detach().clone()
            padded_embeds[dummy_ids == -1] = projected_embeds.reshape(-1, p_dim)
            logits_per_image, _ = forward_text_embedding(model, padded_embeds, dummy_ids, universal_target_features)
            scores_per_prompt = logits_per_image.mean(dim=0)
            universal_cosim_score = scores_per_prompt.max().item()
            best_indx = scores_per_prompt.argmax().item()

        # Forward/backward through the projected embeddings, but apply the
        # resulting gradient to the continuous prompt embeddings.
        tmp_embeds = prompt_embeds.detach().clone()
        tmp_embeds.data = projected_embeds.data
        tmp_embeds.requires_grad = True

        padded_embeds = dummy_embeds.detach().clone()
        padded_embeds[dummy_ids == -1] = tmp_embeds.reshape(-1, p_dim)

        logits_per_image, _ = forward_text_embedding(model, padded_embeds, dummy_ids, target_features)
        cosim_scores = logits_per_image
        loss = 1 - cosim_scores.mean()
        loss = loss * args["loss_weight"]

        prompt_embeds.grad, = torch.autograd.grad(loss, [tmp_embeds])

        input_optimizer.step()
        input_optimizer.zero_grad()

        curr_lr = input_optimizer.param_groups[0]["lr"]
        cosim_scores = cosim_scores.mean().item()

        decoded_text = decode_ids(nn_indices, tokenizer)[best_indx]
        if print_step is not None and (step % print_step == 0 or step == opt_iters - 1):
            per_step_message = f"step: {step}, lr: {curr_lr}"
            print(per_step_message)

        # Track the best discrete prompt found so far.
        if best_sim * args["loss_weight"] < universal_cosim_score * args["loss_weight"]:
            best_sim = universal_cosim_score
            best_text = decoded_text
            if print_new_best:
                print(f"step: {step}, new best cosine sim: {best_sim}, new best prompt: {best_text}")

    if print_step is not None:
        print(f"best cosine sim: {best_sim}, best prompt: {best_text}")

    return best_text


def optimize_prompt(model, preprocess, args, device, target_images=None, target_prompts=None):
    """Entry point: recover a hard text prompt whose CLIP text features match the target image features."""
    token_embedding = model.token_embedding
    tokenizer = open_clip.tokenizer._tokenizer
    tokenizer_funct = open_clip.get_tokenizer(args["clip_model"])

    # Targets come from the images; target_prompts is only used to seed the optimization.
    all_target_features = get_target_feature(model, preprocess, tokenizer_funct, device, target_images=target_images)
    learned_prompt = optimize_prompt_loop(model, tokenizer, token_embedding, all_target_features, args, device, target_prompts)

    return learned_prompt
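# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the original module): the
# image path, pretrained tag, and hyperparameter values below are assumptions
# chosen only to show how the functions above fit together.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Assumed CLIP backbone; any open_clip model name / pretrained tag works.
    args = {
        "clip_model": "ViT-H-14",
        "prompt_len": 16,       # number of learnable tokens
        "prompt_bs": 1,         # number of prompt candidates optimized in parallel
        "iter": 1000,           # optimization steps
        "lr": 0.1,
        "weight_decay": 0.1,
        "print_step": 100,
        "batch_size": None,     # use all target images every step
        "loss_weight": 1.0,
    }

    model, _, preprocess = open_clip.create_model_and_transforms(
        args["clip_model"], pretrained="laion2b_s32b_b79k", device=device
    )

    # Hypothetical target image and seed prompt.
    target_images = [Image.open("target.jpg").convert("RGB")]
    seed_prompt = "a photo"

    learned = optimize_prompt(model, preprocess, args, device,
                              target_images=target_images,
                              target_prompts=seed_prompt)
    print(learned)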