# ghibli-avatar / generate_image.py
# Committed by ziheWang: "Update generate_image.py" (bc6e519, verified)
import re

import torch
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load the Chinese -> English translation model and its tokenizer once at import time.
_TRANSLATOR_ID = "Helsinki-NLP/opus-mt-zh-en"
translator_tokenizer = AutoTokenizer.from_pretrained(_TRANSLATOR_ID)
translator_model = AutoModelForSeq2SeqLM.from_pretrained(_TRANSLATOR_ID)
# 翻译函数
def translate_to_english(chinese_text):
inputs = translator_tokenizer.encode(chinese_text, return_tensors="pt", max_length=512, truncation=True)
outputs = translator_model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
english_text = translator_tokenizer.decode(outputs[0], skip_special_tokens=True)
return english_text
# Load the Ghibli-style Stable Diffusion pipeline: fp16 on a CUDA GPU when one
# is available, otherwise fp32 on the CPU.
model_id = "nitrosocke/Ghibli-Diffusion"
if torch.cuda.is_available():
    _device, _dtype = "cuda", torch.float16
else:
    _device, _dtype = "cpu", torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=_dtype)
pipe = pipe.to(_device)
# 主函数:生成头像
def generate_ghibli_avatar(prompt, reference_image=None):
english_prompt = translate_to_english(prompt)
final_prompt = f"A portrait of a {english_prompt}, ghibli style, anime, studio ghibli, headshot"
image = pipe(prompt=final_prompt).images[0]
return image