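"""DreamOmni2 image-editing inference script.

Loads the FLUX.1-Kontext base model with the DreamOmni2 editing LoRA, uses a
Qwen2.5-VL model to rewrite the user instruction into an editing prompt, and
runs the pipeline on the input images.

Example invocation (script name illustrative; paths are the argparse defaults):

    python edit.py \
        --vlm_path ./models/vlm-model \
        --edit_lora_path ./models/edit_lora \
        --base_model_path black-forest-labs/FLUX.1-Kontext-dev \
        --input_img_path example_input/edit_tests/src.jpg example_input/edit_tests/ref.jpg \
        --input_instruction "Make the woman from the second image stand on the road in the first image." \
        --output_path example_input/edit_tests/edi_res.png
"""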
import torch

# Optional Ascend NPU support: if torch_npu is importable, apply the
# compatibility patches below; otherwise fall back to CUDA.
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu  # maps "cuda" calls to NPU
    import importlib
    import transformers.utils
    import transformers.models

    origin_utils = transformers.utils
    origin_models = transformers.models

    # Patch the flash-attention availability check, then reload the
    # transformers modules so the patch takes effect.
    import flash_attn
    flash_attn.hack_transformers_flash_attn_2_available_check()
    importlib.reload(transformers.utils)
    importlib.reload(transformers.models)

    # Work around bilinear interpolation dtype issues on NPU:
    # compute in bfloat16, then cast back to the input dtype.
    origin_func = torch.nn.functional.interpolate

    def new_func(input, size=None, scale_factor=None, mode='nearest',
                 align_corners=None, recompute_scale_factor=None, antialias=False):
        if mode == "bilinear":
            dtype = input.dtype
            res = origin_func(input.to(torch.bfloat16), size, scale_factor, mode,
                              align_corners, recompute_scale_factor, antialias)
            return res.to(dtype)
        else:
            return origin_func(input, size, scale_factor, mode,
                               align_corners, recompute_scale_factor, antialias)

    torch.nn.functional.interpolate = new_func

    from utils import patch_npu_record_stream
    from utils import patch_npu_diffusers_get_1d_rotary_pos_embed
    patch_npu_record_stream()
    patch_npu_diffusers_get_1d_rotary_pos_embed()
    USE_NPU = True
except Exception:  # torch_npu missing or patching failed: run on CUDA
    USE_NPU = False
from dreamomni2.pipeline_dreamomni2 import DreamOmni2Pipeline
from diffusers.utils import load_image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
# from qwen_vl_utils import process_vision_info
from utils.vprocess import process_vision_info, resizeinput
import os
import argparse
from tqdm import tqdm
import json
from PIL import Image
import re

device = "npu" if USE_NPU else "cuda"
def extract_gen_content(text):
    """Strip the fixed-length tag pair the VLM wraps around its generated
    prompt (6 leading and 7 trailing characters)."""
    return text[6:-7]
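# Note: the 6/7-character slice above matches a tag pair such as
# "<edit>...</edit>" (assumption: the exact tag is whatever this VLM
# checkpoint was trained to emit; only the slice lengths are given here).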
def parse_args():
    """Parses command-line arguments for model paths and I/O configuration."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vlm_path",
        type=str,
        default="./models/vlm-model",
        help="Path to the VLM model directory.",
    )
    parser.add_argument(
        "--edit_lora_path",
        type=str,
        default="./models/edit_lora",
        help="Path to the FLUX.1-Kontext editing LoRA weights directory.",
    )
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="black-forest-labs/FLUX.1-Kontext-dev",
        help="Path or Hub ID of the FLUX.1-Kontext base model.",
    )
    parser.add_argument(
        "--input_img_path",
        type=str,
        nargs='+',  # accept one or more input paths
        default=["example_input/edit_tests/src.jpg", "example_input/edit_tests/ref.jpg"],
        help="List of input image paths (e.g., src and ref images).",
    )
    parser.add_argument(
        "--input_instruction",
        type=str,
        default="Make the woman from the second image stand on the road in the first image.",
        help="Instruction for image editing.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="example_input/edit_tests/edi_res.png",
        help="Path to save the output image.",
    )
    return parser.parse_args()
ARGS = parse_args()
vlm_path = ARGS.vlm_path
edit_lora_path = ARGS.edit_lora_path
base_model = ARGS.base_model_path

pipe = DreamOmni2Pipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
pipe.to(device)
pipe.load_lora_weights(
    edit_lora_path,
    adapter_name="edit",
)
pipe.set_adapters(["edit"], adapter_weights=[1.0])
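# If GPU memory is tight, the standard diffusers offloading API can be used
# in place of pipe.to(device), e.g. pipe.enable_model_cpu_offload()
# (a sketch; not used in the original script).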
# Load the instruction-rewriting VLM. "cuda" also works on NPU here because
# transfer_to_npu redirects CUDA calls when torch_npu is active.
vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    vlm_path, torch_dtype="bfloat16", device_map="cuda"
)
processor = AutoProcessor.from_pretrained(vlm_path)
def infer_vlm(input_img_path, input_instruction, prefix):
    """Ask the VLM to rewrite the user instruction into an editing prompt."""
    content = []
    for path in input_img_path:
        content.append({"type": "image", "image": path})
    content.append({"type": "text", "text": input_instruction + prefix})
    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")  # remapped to NPU by transfer_to_npu when active

    # Inference: decode only the newly generated tokens.
    generated_ids = vlm_model.generate(**inputs, do_sample=False, max_new_tokens=4096)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]
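# Example call (uses the script's default inputs; an illustration, not run here):
#   prompt = infer_vlm(
#       ["example_input/edit_tests/src.jpg", "example_input/edit_tests/ref.jpg"],
#       "Make the woman from the second image stand on the road in the first image.",
#       " It is editing task.",
#   )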
def infer(source_imgs, prompt):
    """Run the editing pipeline; the output keeps the first image's resolution."""
    image = pipe(
        images=source_imgs,
        height=source_imgs[0].height,
        width=source_imgs[0].width,
        prompt=prompt,
        num_inference_steps=30,
        guidance_scale=3.5,
    ).images[0]
    return image
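# For reproducible outputs, a seeded generator can be passed through the
# pipeline call (standard diffusers argument; not used in the original):
#   pipe(..., generator=torch.Generator(device=device).manual_seed(42))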
input_img_path = ARGS.input_img_path
input_instruction = ARGS.input_instruction
prefix = " It is editing task."

# Resize inputs to pipeline-friendly dimensions before editing.
source_imgs = [resizeinput(load_image(path)) for path in input_img_path]

prompt = infer_vlm(input_img_path, input_instruction, prefix)
prompt = extract_gen_content(prompt)
image = infer(source_imgs, prompt)

output_path = ARGS.output_path
image.save(output_path)