ImageEditPro / util.py
selfit-camera's picture
init
b9247d3
raw
history blame
19.2 kB
import os
import sys
import cv2
import json
import random
import time
import datetime
import requests
import func_timeout
import numpy as np
import gradio as gr
import boto3
import tempfile
from botocore.client import Config
from PIL import Image
# All credentials are packed into a single `OneKey` environment variable as
# '#'-separated fields.  The first seven fields are, in order:
#   TOKEN, APIKEY, UKAPIURL, LLMKEY, R2_ACCESS_KEY, R2_SECRET_KEY, R2_ENDPOINT
# A missing variable raises KeyError and a short value raises IndexError at
# import time.
OneKey = os.environ['OneKey'].strip().split("#")
(
    TOKEN,
    APIKEY,
    UKAPIURL,
    LLMKEY,
    R2_ACCESS_KEY,
    R2_SECRET_KEY,
    R2_ENDPOINT,
) = (OneKey[i] for i in range(7))

# Scratch directory for images that are re-encoded before being uploaded.
tmpFolder = "tmp"
os.makedirs(tmpFolder, exist_ok=True)
def upload_user_img(clientIp, timeId, img):
    """Re-encode an image as JPEG and upload it via the UKAPI upload endpoint.

    The image is read with OpenCV, written to ``tmpFolder`` as
    ``<clientIp without dots><timeId>.jpg``, and PUT to the URL returned by
    ``{UKAPIURL}/upload`` under the ``upload1`` key.

    Args:
        clientIp: Client IP string; dots are stripped to build the file name.
        timeId: Value appended (stringified) to the file name.
        img: Path of the image file to read.

    Returns:
        The upload URL when both HTTP calls answer 200, otherwise ``""``.

    Raises:
        ValueError: If OpenCV cannot read ``img``.
    """
    fileName = clientIp.replace(".", "") + str(timeId) + ".jpg"
    local_path = os.path.join(tmpFolder, fileName)

    # cv2.imread signals failure by returning None instead of raising.
    image = cv2.imread(img)
    if image is None:
        raise ValueError(f"Could not read image file: {img}")

    json_data = {
        "token": TOKEN,
        "input1": fileName,
        "input2": "",
        "protocol": "",
        "cloud": "ali"
    }
    res = ""
    try:
        cv2.imwrite(local_path, image)
        # The session is closed on exit so its connection pool is not leaked.
        with requests.session() as session:
            ret = session.post(
                f"{UKAPIURL}/upload",
                headers={'Content-Type': 'application/json'},
                json=json_data
            )
            if ret.status_code == 200:
                payload = ret.json()
                if 'upload1' in payload:
                    upload_url = payload['upload1']
                    headers = {'Content-Type': 'image/jpeg'}
                    with open(local_path, 'rb') as fh:
                        response = session.put(upload_url, data=fh.read(), headers=headers)
                    if response.status_code == 200:
                        res = upload_url
    finally:
        # Remove the temporary JPEG even when a request raises.
        if os.path.exists(local_path):
            os.remove(local_path)
    return res
class R2Api:
    """Thin client for uploading files to an S3-compatible (R2) bucket.

    Uploads go through a pre-signed PUT URL generated with boto3 and are sent
    with a ``requests`` session.
    """

    def __init__(self, session=None):
        """Create the boto3 client and the HTTP session.

        Args:
            session: Optional ``requests.Session`` to reuse; a new one is
                created when omitted.
        """
        super().__init__()
        self.R2_BUCKET = "trump-ai-voice"
        self.domain = "https://www.trumpaivoice.net/"
        self.R2_ACCESS_KEY = R2_ACCESS_KEY
        self.R2_SECRET_KEY = R2_SECRET_KEY
        self.R2_ENDPOINT = R2_ENDPOINT
        self.client = boto3.client(
            "s3",
            endpoint_url=self.R2_ENDPOINT,
            aws_access_key_id=self.R2_ACCESS_KEY,
            aws_secret_access_key=self.R2_SECRET_KEY,
            config=Config(signature_version="s3v4")
        )
        self.session = requests.Session() if session is None else session

    def upload_file(self, local_path, cloud_path):
        """Upload ``local_path`` and return its public URL.

        Args:
            local_path: Path of the file to upload.
            cloud_path: Accepted for interface compatibility but ignored; the
                object key is always
                ``QwenImageEdit/Uploads/<today>/<basename of local_path>``.

        Returns:
            ``self.domain`` followed by the object key.

        Raises:
            Exception: If the PUT fails (network error, timeout or an HTTP
                error status) on all three attempts.
        """
        t1 = time.time()
        head_dict = {
            'jpg': 'image/jpeg',
            'jpeg': 'image/jpeg',
            'png': 'image/png',
            'gif': 'image/gif',
            'bmp': 'image/bmp',
            'webp': 'image/webp',
            'ico': 'image/x-icon'
        }
        file_name = os.path.basename(local_path)
        ftype = file_name.split(".")[-1].lower()
        ctype = head_dict.get(ftype, 'application/octet-stream')
        headers = {"Content-Type": ctype}
        cloud_path = f"QwenImageEdit/Uploads/{str(datetime.date.today())}/{file_name}"
        url = self.client.generate_presigned_url(
            "put_object",
            Params={"Bucket": self.R2_BUCKET, "Key": cloud_path, "ContentType": ctype},
            ExpiresIn=604800
        )

        # Read once; the same bytes are reused for every retry.
        with open(local_path, 'rb') as f:
            payload = f.read()

        max_tries = 3
        for attempt in range(1, max_tries + 1):
            try:
                response = self.session.put(url, data=payload, headers=headers, timeout=8)
                # A 4xx/5xx answer means the object was NOT stored; treat it
                # like any other failed attempt instead of reporting success.
                response.raise_for_status()
                break
            except requests.exceptions.RequestException as err:
                if attempt == max_tries:
                    raise Exception('Failed to upload file to R2 after 3 retries!') from err
        print("upload_file time is ====>", time.time() - t1)
        return f"{self.domain}{cloud_path}"
def upload_user_img_r2(clientIp, timeId, img):
    """Re-encode an image as JPEG and upload it to R2.

    Args:
        clientIp: Client IP string; dots are stripped to build the file name.
        timeId: Value appended (stringified) to the file name.
        img: Path of the image file to read.

    Returns:
        The URL returned by ``R2Api.upload_file``.

    Raises:
        ValueError: If OpenCV cannot read ``img``.
        Exception: Whatever ``R2Api.upload_file`` raises on upload failure.
    """
    fileName = clientIp.replace(".", "") + str(timeId) + ".jpg"
    local_path = os.path.join(tmpFolder, fileName)

    # cv2.imread signals failure by returning None instead of raising.
    image = cv2.imread(img)
    if image is None:
        raise ValueError(f"Could not read image file: {img}")

    try:
        cv2.imwrite(local_path, image)
        return R2Api().upload_file(local_path, fileName)
    finally:
        # Remove the temporary JPEG even when the upload raises.
        if os.path.exists(local_path):
            os.remove(local_path)
@func_timeout.func_set_timeout(10)
def get_country_info(ip):
    """Look up the country for an IP address via the qifu geo API.

    Returns the ``data.country`` value of the response, or "Unknown" when the
    request fails, the API does not report ``code == "Success"``, or the
    country field is empty.  The call is limited to 10 seconds by the
    ``func_timeout`` decorator.
    """
    try:
        response = requests.get(
            f"https://qifu-api.baidubce.com/ip/geo/v1/district?ip={ip}"
        )
        # Turn HTTP error statuses (404, 500, ...) into exceptions.
        response.raise_for_status()
        payload = response.json()

        # Anything other than an explicit success code is an API-level error.
        if payload.get("code") != "Success":
            print(f"API request failed: {payload.get('msg', 'Unknown error')}")
            return "Unknown"

        # The country name lives under data -> country.
        return payload.get("data", {}).get("country") or "Unknown"
    except requests.exceptions.RequestException as net_err:
        print(f"Network request failed: {net_err}")
        return "Unknown"
    except Exception as err:
        print(f"Failed to get IP location: {err}")
        return "Unknown"
def get_country_info_safe(ip):
    """Wrap ``get_country_info`` so that timeouts and errors yield "Unknown"."""
    fallback = "Unknown"
    try:
        country = get_country_info(ip)
    except func_timeout.FunctionTimedOut:
        # Raised by the func_timeout decorator when the lookup runs too long.
        print(f"IP location request timeout: {ip}")
        return fallback
    except Exception as err:
        print(f"Failed to get IP location: {err}")
        return fallback
    return country
def create_mask_from_layers(base_image, layers):
    """Build a black/white mask from ImageEditor layers.

    A pixel is white (255) in the result when any layer has drawn content at
    that position:

    * layers with an alpha plane (2 channels = LA, 4 channels = RGBA):
      alpha > 0;
    * other 3-D layers (e.g. RGB): any channel > 0;
    * 2-D (grayscale) layers: value > 0.

    Layers whose size differs from ``base_image`` are resized to it with
    LANCZOS, so edges of resized layers may contain intermediate gray values.

    Args:
        base_image (PIL.Image): Original image; defines the mask size.
        layers (list): Layer images (``None`` entries are skipped).

    Returns:
        PIL.Image: Mode 'L' mask with the same size as ``base_image``.
    """
    from PIL import Image

    # Start from an all-black mask with the same size as the original image.
    mask = Image.new('L', base_image.size, 0)
    if not layers:
        return mask

    # Accumulate in numpy to avoid a PIL <-> numpy round-trip per layer.
    combined = np.array(mask)

    for layer in layers:
        if layer is None:
            continue
        layer_array = np.array(layer)

        if layer_array.ndim == 3:
            if layer_array.shape[2] in (2, 4):
                # LA / RGBA: the last channel is alpha.  Using it (rather than
                # "any channel > 0") keeps an opaque alpha plane from marking
                # the whole image as drawn.
                drawn = layer_array[:, :, -1] > 0
            else:
                # No alpha plane: anything that is not pure black is drawn.
                drawn = np.any(layer_array > 0, axis=2)
        elif layer_array.ndim == 2:
            drawn = layer_array > 0
        else:
            # Unsupported array shape; nothing to merge.
            continue

        layer_mask = Image.fromarray(np.where(drawn, 255, 0).astype(np.uint8), mode='L')
        if layer_mask.size != base_image.size:
            layer_mask = layer_mask.resize(base_image.size, Image.LANCZOS)

        # Union of masks: keep the brightest value seen at each pixel.
        combined = np.maximum(combined, np.array(layer_mask))

    return Image.fromarray(combined, mode='L')
def upload_mask_image_r2(client_ip, time_id, mask_image):
    """Save a mask as PNG in the scratch folder and upload it to R2.

    Args:
        client_ip (str): Client IP; dots are stripped for the file name.
        time_id (int): Timestamp embedded in the file name.
        mask_image (PIL.Image): Mask image to upload.

    Returns:
        str | None: The uploaded URL, or None if saving or uploading failed.
    """
    ip_digits = client_ip.replace('.', '')
    file_name = f"{ip_digits}{time_id}_mask.png"
    local_path = os.path.join(tmpFolder, file_name)
    try:
        mask_image.save(local_path, 'PNG')
        return R2Api().upload_file(local_path, file_name)
    except Exception as err:
        print(f"Failed to upload mask image: {err}")
        return None
    finally:
        # The local copy is only needed for the duration of the upload.
        if os.path.exists(local_path):
            os.remove(local_path)
def submit_image_edit_task(user_image_url, prompt, task_type="80", mask_image_url="", timeout=30):
    """Submit an image editing task to ``{UKAPIURL}/public_image_edit``.

    Args:
        user_image_url: URL of the uploaded source image.
        prompt: Editing instructions.
        task_type: Task type code sent to the API (callers in this file use
            "80" by default and "81" together with a mask).
        mask_image_url: URL of the uploaded mask image, or "" for none.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout ``requests`` can block indefinitely.

    Returns:
        ``(task_id, None)`` on success, ``(None, error_message)`` otherwise.
        Exceptions (including timeouts) are reported through the message.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}'
    }
    # NOTE(review): secret_key is hard-coded here; consider moving it to the
    # environment with the other credentials.
    data = {
        "user_image": user_image_url,
        "mask_image": mask_image_url,
        "task_type": task_type,
        "prompt": prompt,
        "secret_key": "219ngu",
        "is_private": "0"
    }
    try:
        response = requests.post(
            f'{UKAPIURL}/public_image_edit',
            headers=headers,
            json=data,
            timeout=timeout
        )
        if response.status_code != 200:
            return None, f"HTTP Error: {response.status_code}"

        result = response.json()
        if result.get('code') != 0:
            return None, f"API Error: {result.get('message', 'Unknown error')}"
        return result['data']['task_id'], None
    except Exception as e:
        return None, f"Request Exception: {str(e)}"
def check_task_status(task_id, timeout=15):
    """Query the status of an edit task at ``{UKAPIURL}/status_image_edit``.

    Args:
        task_id: Identifier returned when the task was submitted.
        timeout: Seconds to wait for the HTTP request.  Callers poll this
            function in a bounded loop, which only stays bounded if each call
            is guaranteed to return.

    Returns:
        ``(status, output_url, task_data)`` on success, where ``output_url``
        is the task's ``output1`` field (possibly None).  On any failure
        returns ``('error', None, message)``.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}'
    }
    data = {
        "task_id": task_id
    }
    try:
        response = requests.post(
            f'{UKAPIURL}/status_image_edit',
            headers=headers,
            json=data,
            timeout=timeout
        )
        if response.status_code != 200:
            return 'error', None, f"HTTP Error: {response.status_code}"

        result = response.json()
        if result.get('code') != 0:
            return 'error', None, result.get('message', 'Unknown error')

        task_data = result['data']
        return task_data['status'], task_data.get('output1'), task_data
    except Exception as e:
        return 'error', None, f"Request Exception: {str(e)}"
def process_image_edit(img_input, prompt, progress_callback=None,
                       max_attempts=60, poll_interval=1):
    """Upload an image, submit a whole-image edit task and poll for the result.

    Args:
        img_input: Either a file path (str) or an object with a ``save``
            method (treated as a PIL Image and written to a temporary JPEG).
        prompt: Editing instructions.
        progress_callback: Optional callable that receives progress strings.
        max_attempts: Number of status polls before reporting a timeout.
        poll_interval: Seconds slept between polls.  With the defaults the
            polling budget is about 60 seconds plus request time.

    Returns:
        ``(output_url, message)`` on success, ``(None, error_message)``
        otherwise.  Exceptions are caught and reported through the message.
    """
    temp_dir = None
    temp_img_path = None
    try:
        # The client IP is a fixed placeholder; it only feeds the upload file name.
        client_ip = "127.0.0.1"
        time_id = int(time.time())

        if hasattr(img_input, 'save'):
            # PIL Image: persist it as a temporary JPEG so it can be uploaded.
            temp_dir = tempfile.mkdtemp()
            temp_img_path = os.path.join(temp_dir, f"temp_img_{time_id}.jpg")
            if img_input.mode != 'RGB':
                img_input = img_input.convert('RGB')
            img_input.save(temp_img_path, 'JPEG', quality=95)
            img_path = temp_img_path
            print(f"💾 PIL Image saved as temporary file: {temp_img_path}")
        else:
            # Anything else is assumed to be a file path.
            img_path = img_input

        if progress_callback:
            progress_callback("uploading image...")

        uploaded_url = upload_user_img_r2(client_ip, time_id, img_path)
        if not uploaded_url:
            return None, "image upload failed"

        # Drop any query string so only the plain object URL is submitted.
        if "?" in uploaded_url:
            uploaded_url = uploaded_url.split("?")[0]

        if progress_callback:
            progress_callback("submitting edit task...")

        task_id, error = submit_image_edit_task(uploaded_url, prompt)
        if error:
            return None, error

        if progress_callback:
            progress_callback(f"task submitted, ID: {task_id}, processing...")

        # Poll until the task finishes or the attempt budget is exhausted.
        for _ in range(max_attempts):
            status, output_url, task_data = check_task_status(task_id)
            if status == 'completed':
                if output_url:
                    return output_url, "image edit completed"
                return None, "Task completed but no result image returned"
            if status == 'error' or status == 'failed':
                return None, f"task processing failed: {task_data}"

            if progress_callback:
                if status in ['queued', 'processing', 'running', 'created', 'working']:
                    progress_callback(f"task processing... (status: {status})")
                else:
                    progress_callback(f"unknown status: {status}")
            time.sleep(poll_interval)

        return None, "task processing timeout"
    except Exception as e:
        return None, f"error occurred during processing: {str(e)}"
    finally:
        # Clean up whenever a temp directory was created, even if saving the
        # image failed before the file itself existed.
        if temp_dir is not None:
            try:
                if temp_img_path and os.path.exists(temp_img_path):
                    os.remove(temp_img_path)
                if os.path.exists(temp_dir):
                    os.rmdir(temp_dir)
                print(f"🗑️ Cleaned up temporary file: {temp_img_path}")
            except Exception as cleanup_error:
                print(f"⚠️ Failed to clean up temporary file: {cleanup_error}")
def process_local_image_edit(base_image, layers, prompt, progress_callback=None,
                             max_attempts=60, poll_interval=1):
    """Run a masked (local) image edit: build mask, upload, submit, poll.

    Args:
        base_image (PIL.Image): Original image.
        layers (list): ImageEditor layer data marking the region to edit.
        prompt (str): Editing instructions.
        progress_callback: Optional callable that receives progress strings.
        max_attempts: Number of status polls before reporting a timeout.
        poll_interval: Seconds slept between polls.  With the defaults the
            polling budget is about 60 seconds plus request time.

    Returns:
        ``(output_url, message)`` on success, ``(None, error_message)``
        otherwise.  Exceptions are caught and reported through the message.
    """
    temp_dir = None
    temp_img_path = None
    try:
        # The client IP is a fixed placeholder; it only feeds the upload file names.
        client_ip = "127.0.0.1"
        time_id = int(time.time())

        if progress_callback:
            progress_callback("正在创建mask图片...")

        # Build the mask from the drawn layers.
        mask_image = create_mask_from_layers(base_image, layers)

        # An all-black mask means the user has not drawn anything.
        mask_array = np.array(mask_image)
        if np.max(mask_array) == 0:
            return None, "请在图片上绘制需要编辑的区域"
        print(f"📝 创建mask图片成功,绘制区域像素数: {np.sum(mask_array > 0)}")

        if progress_callback:
            progress_callback("正在上传原始图片...")

        # Persist the original image as a temporary JPEG for upload.
        temp_dir = tempfile.mkdtemp()
        temp_img_path = os.path.join(temp_dir, f"temp_img_{time_id}.jpg")
        if base_image.mode != 'RGB':
            base_image = base_image.convert('RGB')
        base_image.save(temp_img_path, 'JPEG', quality=95)

        uploaded_url = upload_user_img_r2(client_ip, time_id, temp_img_path)
        if not uploaded_url:
            return None, "原始图片上传失败"

        # Drop any query string so only the plain object URL is submitted.
        if "?" in uploaded_url:
            uploaded_url = uploaded_url.split("?")[0]

        if progress_callback:
            progress_callback("正在上传mask图片...")

        mask_url = upload_mask_image_r2(client_ip, time_id, mask_image)
        if not mask_url:
            return None, "mask图片上传失败"

        if "?" in mask_url:
            mask_url = mask_url.split("?")[0]

        print(f"📤 图片上传成功:")
        print(f" 原始图片: {uploaded_url}")
        print(f" Mask图片: {mask_url}")

        if progress_callback:
            progress_callback("正在提交局部编辑任务...")

        # Local (masked) edits use task_type "81".
        task_id, error = submit_image_edit_task(uploaded_url, prompt, task_type="81", mask_image_url=mask_url)
        if error:
            return None, error

        if progress_callback:
            progress_callback(f"任务已提交,ID: {task_id},正在处理...")
        print(f"🚀 局部编辑任务已提交,任务ID: {task_id}")

        # Poll until the task finishes or the attempt budget is exhausted.
        for _ in range(max_attempts):
            status, output_url, task_data = check_task_status(task_id)
            if status == 'completed':
                if output_url:
                    print(f"✅ 局部编辑任务完成,结果: {output_url}")
                    return output_url, "局部图片编辑完成"
                return None, "任务完成但未返回结果图片"
            if status == 'error' or status == 'failed':
                return None, f"任务处理失败: {task_data}"

            if progress_callback:
                if status in ['queued', 'processing', 'running', 'created', 'working']:
                    progress_callback(f"正在处理中... (状态: {status})")
                else:
                    progress_callback(f"未知状态: {status}")
            time.sleep(poll_interval)

        return None, "任务处理超时"
    except Exception as e:
        print(f"❌ 局部编辑处理异常: {str(e)}")
        return None, f"处理过程中发生错误: {str(e)}"
    finally:
        # Clean up whenever a temp directory was created, even if saving the
        # image failed before the file itself existed.
        if temp_dir is not None:
            try:
                if temp_img_path and os.path.exists(temp_img_path):
                    os.remove(temp_img_path)
                if os.path.exists(temp_dir):
                    os.rmdir(temp_dir)
                print(f"🗑️ Cleaned up temporary file: {temp_img_path}")
            except Exception as cleanup_error:
                print(f"⚠️ Failed to clean up temporary file: {cleanup_error}")
# Running this module directly does nothing; it is a helper library whose
# functions are called from elsewhere.
if __name__ == "__main__":
    pass