| import random | |
| from BeamDiffusionModel.models.CoSeD.cross_attention import get_softmax | |
| from BeamDiffusionModel.models.diffusionModel.configs.config_loader import CONFIG | |
| from BeamDiffusionModel.tree.tree import BeamSearchTree | |
| from BeamDiffusionModel.utils.utils import gen_img | |
def set_softmax(nodes, softmax, n_latents, n_max_latents):
    """Hand each node the softmax value found at the same position.

    Every node receives ``node.set_softmax(value, n_latents, n_max_latents)``.
    Pairing is positional and stops at the shorter of ``nodes`` / ``softmax``
    (``zip`` semantics), so surplus entries on either side are ignored.
    Returns ``None``.
    """
    paired = zip(nodes, softmax)
    for target_node, value in paired:
        target_node.set_softmax(value, n_latents, n_max_latents)
def beam_inference(sd, steps, latents_idx, n_seeds=1, seeds=None, steps_back=2,
                   beam_width=4, window_size=2, use_rand=True):
    """Run a beam search over a sequence of captions, re-using ancestor latents.

    For the first caption, one image is generated per seed and attached to the
    tree root. For every later caption, each node kept from the previous step
    is expanded into children by ``_expand_node``:

    * one child from a fresh random seed, if ``use_rand`` is true;
    * one child per (ancestor, latent index) pair, generated from
      ``ancestor.get_latent(latent)``.

    ``get_softmax`` is computed for each parent's children against
    ``tree.get_previous_steps_features`` and stored on them via
    ``set_softmax``. Once the 0-based caption index ``i`` reaches
    ``window_size``, only the children lying on the ``beam_width`` best paths
    returned by ``tree.get_n_best_paths`` are expanded further.

    Args:
        sd: Diffusion model handle, passed through to ``gen_img``.
        steps: Sequence of captions, one per step.
        latents_idx: Latent indices to re-use from each ancestor.
        n_seeds: Minimum number of first-step seeds; random seeds in
            ``[0, 10**6]`` are drawn to top ``seeds`` up to this count.
        seeds: Optional explicit first-step seeds. The list is copied, so the
            caller's list is not modified.
        steps_back: Passed to ``BeamSearchTree``; ancestors are fetched with
            ``get_ancestors(steps_back - 1)``.
        beam_width: Number of best paths kept when pruning.
        window_size: Caption index from which pruning starts.
        use_rand: Whether each expanded node also gets a random-seed child.

    Returns:
        The result of ``tree.best_path_imgs()`` (presumably the images along
        the best path -- confirm against ``BeamSearchTree``).
    """
    # Work on a private copy. The former ``seeds=[]`` default was mutated by
    # the loop below, so random seeds leaked into (and were reused by) later
    # calls, and a caller-supplied list was extended in place.
    seeds = [] if seeds is None else list(seeds)
    while len(seeds) < n_seeds:
        seeds.append(random.randint(0, 10**6))

    captions = steps
    tree = BeamSearchTree(steps_back, beam_width, latents_idx, len(captions))
    nodes_to_explore = []
    for i, caption in enumerate(captions):
        step = i + 1
        if i == 0:
            # First caption: one root child per seed.
            for seed in seeds:
                latents, img = gen_img(sd, caption, seed=seed)
                nodes_to_explore.append(
                    tree.add_node(tree.root, caption, step, "Rand Seed", "Rand Seed",
                                  img, latents, None))
            continue

        next_nodes = []
        for parent_node in nodes_to_explore:
            next_nodes += _expand_node(sd, tree, parent_node, caption, step,
                                       latents_idx, steps_back, use_rand)
        if i >= window_size:
            print("-----------------------------------Cleaning some nodes-----------------------------------")
            next_nodes = _keep_nodes_on_best_paths(tree, next_nodes, beam_width, step)
        nodes_to_explore = next_nodes
    return tree.best_path_imgs()


def _expand_node(sd, tree, parent_node, caption, step, latents_idx, steps_back, use_rand):
    """Create the children of ``parent_node`` for one caption and attach their softmax.

    Returns the created children in creation order (empty if none were made).
    """
    children, step_embeddings, image_embeddings = [], [], []

    def add_child(origin_step, origin_latent, img, latents):
        # Features are read right after each node is added, matching the
        # original interleaving of generation and feature extraction.
        child = tree.add_node(parent_node, caption, step, origin_step, origin_latent,
                              img, latents, None)
        children.append(child)
        step_embedding, image_embedding = child.get_features()
        step_embeddings.append(step_embedding)
        image_embeddings.append(image_embedding)

    if use_rand:
        seed = random.randint(0, 10 ** 6)
        latents, img = gen_img(sd, caption, seed=seed)
        add_child("Rand Seed", "Rand Seed", img, latents)

    for ancestor in parent_node.get_ancestors(steps_back - 1):
        for latent in latents_idx:
            ancestor_latent = ancestor.get_latent(latent)
            latents, img = gen_img(sd, caption, latent=ancestor_latent)
            add_child(ancestor.step, latent, img, latents)

    if children:
        # All children share the same parent, so the last one is as good as
        # any for looking up the preceding steps' features.
        previous_steps_embeddings, previous_images_embeddings = \
            tree.get_previous_steps_features(children[-1])
        softmax = get_softmax(previous_steps_embeddings, previous_images_embeddings,
                              step_embeddings, image_embeddings)
        set_softmax(children, softmax, len(latents_idx),
                    CONFIG["stable_diffusion"]["diffusion_settings"]["steps"])
    return children


def _keep_nodes_on_best_paths(tree, nodes, beam_width, step):
    """Return the ``nodes`` that lie on one of the ``beam_width`` best paths.

    Order is preserved, and each node is kept at most once even if it appears
    in several paths, so it is never expanded twice.
    """
    best_paths = tree.get_n_best_paths(beam_width, step)
    return [node for node in nodes if any(node in path for path in best_paths)]