﻿param (
    [switch]$Help, # Print usage help and exit
    [switch]$BuildMode, # Enable Fooocus Installer build mode
    [switch]$DisablePyPIMirror, # Use the official PyPI index instead of a mirror
    [switch]$DisableUpdate, # Skip the Fooocus Installer update check
    [switch]$DisableProxy, # Do not configure HTTP(S)_PROXY automatically
    [string]$UseCustomProxy, # Explicit proxy address, e.g. "http://127.0.0.1:10809"
    [switch]$DisableHuggingFaceMirror, # Do not set HF_ENDPOINT to a mirror
    [string]$UseCustomHuggingFaceMirror, # Explicit HuggingFace mirror, e.g. "https://hf-mirror.com"
    [switch]$DisableUV, # Install Python packages with pip instead of uv
    [string]$LaunchArg, # Extra Fooocus launch arguments as one string
    [switch]$EnableShortcut, # Create/refresh the desktop and Start Menu shortcut
    [switch]$DisableCUDAMalloc, # Do not set PYTORCH_CUDA_ALLOC_CONF
    [switch]$DisableEnvCheck, # Skip the Fooocus runtime environment check
    [switch]$DisableAutoApplyUpdate # Ask before applying a newer Fooocus Installer
)
# Fooocus Installer version (an integer such as 190, displayed as v1.9.0 by Get-Fooocus-Installer-Version)
# and the minimum interval, in seconds, between two update checks
$FOOOCUS_INSTALLER_VERSION = 190
$UPDATE_TIME_SPAN = 3600
# PyPI index sources: the *_ADDR values are mirrors, the *_ADDR_ORI values the official upstream addresses
$PIP_INDEX_ADDR = "https://mirrors.cloud.tencent.com/pypi/simple"
$PIP_INDEX_ADDR_ORI = "https://pypi.python.org/simple"
$PIP_EXTRA_INDEX_ADDR = "https://mirrors.cernet.edu.cn/pypi/web/simple"
$PIP_EXTRA_INDEX_ADDR_ORI = ""
$PIP_FIND_ADDR = "https://mirrors.aliyun.com/pytorch-wheels/torch_stable.html"
$PIP_FIND_ADDR_ORI = "https://download.pytorch.org/whl/torch_stable.html"
# Mirrors are used unless disable_pypi_mirror.txt exists next to this script or -DisablePyPIMirror is passed
$USE_PIP_MIRROR = if ((!(Test-Path "$PSScriptRoot/disable_pypi_mirror.txt")) -and (!($DisablePyPIMirror))) { $true } else { $false }
$PIP_INDEX_MIRROR = if ($USE_PIP_MIRROR) { $PIP_INDEX_ADDR } else { $PIP_INDEX_ADDR_ORI }
$PIP_EXTRA_INDEX_MIRROR = if ($USE_PIP_MIRROR) { $PIP_EXTRA_INDEX_ADDR } else { $PIP_EXTRA_INDEX_ADDR_ORI }
$PIP_FIND_MIRROR = if ($USE_PIP_MIRROR) { $PIP_FIND_ADDR } else { $PIP_FIND_ADDR_ORI }
# PyTorch wheel indexes for specific compute backends (official download.pytorch.org and the NJU mirror).
# Apart from $PIP_EXTRA_INDEX_MIRROR_PYTORCH, these are not referenced in this part of the file.
$PIP_FIND_MIRROR_CU121 = "https://download.pytorch.org/whl/cu121/torch_stable.html"
$PIP_EXTRA_INDEX_MIRROR_PYTORCH = "https://download.pytorch.org/whl"
$PIP_EXTRA_INDEX_MIRROR_XPU = "https://download.pytorch.org/whl/xpu"
$PIP_EXTRA_INDEX_MIRROR_CU118 = "https://download.pytorch.org/whl/cu118"
$PIP_EXTRA_INDEX_MIRROR_CU121 = "https://download.pytorch.org/whl/cu121"
$PIP_EXTRA_INDEX_MIRROR_CU124 = "https://download.pytorch.org/whl/cu124"
$PIP_EXTRA_INDEX_MIRROR_CU126 = "https://download.pytorch.org/whl/cu126"
$PIP_EXTRA_INDEX_MIRROR_CU128 = "https://download.pytorch.org/whl/cu128"
$PIP_EXTRA_INDEX_MIRROR_CU118_NJU = "https://mirror.nju.edu.cn/pytorch/whl/cu118"
$PIP_EXTRA_INDEX_MIRROR_CU124_NJU = "https://mirror.nju.edu.cn/pytorch/whl/cu124"
$PIP_EXTRA_INDEX_MIRROR_CU126_NJU = "https://mirror.nju.edu.cn/pytorch/whl/cu126"
$PIP_EXTRA_INDEX_MIRROR_CU128_NJU = "https://mirror.nju.edu.cn/pytorch/whl/cu128"
# GitHub mirror / proxy URL prefixes
$GITHUB_MIRROR_LIST = @(
    "https://ghfast.top/https://github.com",
    "https://mirror.ghproxy.com/https://github.com",
    "https://ghproxy.net/https://github.com",
    "https://gh.api.99988866.xyz/https://github.com",
    "https://gh-proxy.com/https://github.com",
    "https://ghps.cc/https://github.com",
    "https://gh.idayer.com/https://github.com",
    "https://ghproxy.1888866.xyz/github.com",
    "https://slink.ltd/https://github.com",
    "https://github.boki.moe/github.com",
    "https://github.moeyy.xyz/https://github.com",
    "https://gh-proxy.net/https://github.com",
    "https://gh-proxy.ygxz.in/https://github.com",
    "https://wget.la/https://github.com",
    "https://kkgithub.com",
    "https://gitclone.com/github.com"
)
# Minimum uv version; Check-uv-Version upgrades uv with pip when the installed one is older
$UV_MINIMUM_VER = "0.8"
# Minimum Aria2 version
$ARIA2_MINIMUM_VER = "1.37.0"
# PATH: prepend the bundled Python / Git directories so they take precedence over system installs.
# The copies inside the Fooocus/ folder come first, then the ones in the installer root.
$PYTHON_PATH = "$PSScriptRoot/python"
$PYTHON_EXTRA_PATH = "$PSScriptRoot/Fooocus/python"
$PYTHON_SCRIPTS_PATH = "$PSScriptRoot/python/Scripts"
$PYTHON_SCRIPTS_EXTRA_PATH = "$PSScriptRoot/Fooocus/python/Scripts"
$GIT_PATH = "$PSScriptRoot/git/bin"
$GIT_EXTRA_PATH = "$PSScriptRoot/Fooocus/git/bin"
$Env:PATH = "$PYTHON_EXTRA_PATH$([System.IO.Path]::PathSeparator)$PYTHON_SCRIPTS_EXTRA_PATH$([System.IO.Path]::PathSeparator)$GIT_EXTRA_PATH$([System.IO.Path]::PathSeparator)$PYTHON_PATH$([System.IO.Path]::PathSeparator)$PYTHON_SCRIPTS_PATH$([System.IO.Path]::PathSeparator)$GIT_PATH$([System.IO.Path]::PathSeparator)$Env:PATH"
# Environment variables
# Package index settings for pip and uv. The official PyTorch index is appended to the extra
# indexes unless it already is the extra index.
$Env:PIP_INDEX_URL = "$PIP_INDEX_MIRROR"
$Env:PIP_EXTRA_INDEX_URL = if ($PIP_EXTRA_INDEX_MIRROR -ne $PIP_EXTRA_INDEX_MIRROR_PYTORCH) { "$PIP_EXTRA_INDEX_MIRROR $PIP_EXTRA_INDEX_MIRROR_PYTORCH".Trim() } else { $PIP_EXTRA_INDEX_MIRROR }
$Env:PIP_FIND_LINKS = "$PIP_FIND_MIRROR"
$Env:UV_DEFAULT_INDEX = "$PIP_INDEX_MIRROR"
$Env:UV_INDEX = if ($PIP_EXTRA_INDEX_MIRROR -ne $PIP_EXTRA_INDEX_MIRROR_PYTORCH) { "$PIP_EXTRA_INDEX_MIRROR $PIP_EXTRA_INDEX_MIRROR_PYTORCH".Trim() } else { $PIP_EXTRA_INDEX_MIRROR }
$Env:UV_FIND_LINKS = "$PIP_FIND_MIRROR"
$Env:UV_LINK_MODE = "copy"
$Env:UV_HTTP_TIMEOUT = 30
$Env:UV_CONCURRENT_DOWNLOADS = 50
$Env:UV_INDEX_STRATEGY = "unsafe-best-match"
# Point the uv / pip config files at the Windows null device ("nul"),
# presumably so user or global config files cannot override the settings above
$Env:UV_CONFIG_FILE = "nul"
$Env:PIP_CONFIG_FILE = "nul"
$Env:PIP_DISABLE_PIP_VERSION_CHECK = 1
$Env:PIP_NO_WARN_SCRIPT_LOCATION = 0
$Env:PIP_TIMEOUT = 30
$Env:PIP_RETRIES = 5
$Env:PIP_PREFER_BINARY = 1
$Env:PIP_YES = 1
# Python interpreter behaviour (UTF-8 mode / I/O encoding, unbuffered output, no user site-packages, faulthandler)
$Env:PYTHONUTF8 = 1
$Env:PYTHONIOENCODING = "utf-8"
$Env:PYTHONUNBUFFERED = 1
$Env:PYTHONNOUSERSITE = 1
$Env:PYTHONFAULTHANDLER = 1
# Library-specific switches (analytics / banner suppression, CUDA / SYCL / TensorFlow tuning)
$Env:GRADIO_ANALYTICS_ENABLED = "False"
$Env:HF_HUB_DISABLE_SYMLINKS_WARNING = 1
$Env:BITSANDBYTES_NOWELCOME = 1
$Env:ClDeviceGlobalMemSizeAvailablePercent = 100
$Env:CUDA_MODULE_LOADING = "LAZY"
$Env:TORCH_CUDNN_V8_API_ENABLED = 1
$Env:USE_LIBUV = 0
$Env:SYCL_CACHE_PERSISTENT = 1
$Env:TF_CPP_MIN_LOG_LEVEL = 3
$Env:SAFETENSORS_FAST_GPU = 1
# Keep every tool's cache inside $PSScriptRoot/cache
$Env:CACHE_HOME = "$PSScriptRoot/cache"
$Env:HF_HOME = "$PSScriptRoot/cache/huggingface"
$Env:MATPLOTLIBRC = "$PSScriptRoot/cache"
$Env:MODELSCOPE_CACHE = "$PSScriptRoot/cache/modelscope/hub"
$Env:MS_CACHE_HOME = "$PSScriptRoot/cache/modelscope/hub"
$Env:SYCL_CACHE_DIR = "$PSScriptRoot/cache/libsycl_cache"
$Env:TORCH_HOME = "$PSScriptRoot/cache/torch"
$Env:U2NET_HOME = "$PSScriptRoot/cache/u2net"
$Env:XDG_CACHE_HOME = "$PSScriptRoot/cache"
$Env:PIP_CACHE_DIR = "$PSScriptRoot/cache/pip"
$Env:PYTHONPYCACHEPREFIX = "$PSScriptRoot/cache/pycache"
$Env:TORCHINDUCTOR_CACHE_DIR = "$PSScriptRoot/cache/torchinductor"
$Env:TRITON_CACHE_DIR = "$PSScriptRoot/cache/triton"
$Env:UV_CACHE_DIR = "$PSScriptRoot/cache/uv"
# Default interpreter for uv; Set-uv switches it to Fooocus/python/python.exe when that exists
$Env:UV_PYTHON = "$PSScriptRoot/python/python.exe"



# 帮助信息
function Get-Fooocus-Installer-Cmdlet-Help {
    # Print the usage text and terminate the script, but only when -Help was passed.
    # Without -Help this is a no-op.
    if (!$Help) {
        return
    }

    # The here-string drops the newlines right after @" and right before "@;
    # Trim() additionally removes any other surrounding whitespace.
    $help_text = @"
使用:
    .\launch.ps1 [-Help] [-BuildMode] [-DisablePyPIMirror] [-DisableUpdate] [-DisableProxy] [-UseCustomProxy <代理服务器地址>] [-DisableHuggingFaceMirror] [-UseCustomHuggingFaceMirror <HuggingFace 镜像源地址>] [-DisableUV] [-LaunchArg <Fooocus 启动参数>] [-EnableShortcut] [-DisableCUDAMalloc] [-DisableEnvCheck] [-DisableAutoApplyUpdate]

参数:
    -Help
        获取 Fooocus Installer 的帮助信息

    -BuildMode
        启用 Fooocus Installer 构建模式

    -DisablePyPIMirror
        禁用 PyPI 镜像源, 使用 PyPI 官方源下载 Python 软件包

    -DisableUpdate
        禁用 Fooocus Installer 更新检查

    -DisableProxy
        禁用 Fooocus Installer 自动设置代理服务器

    -UseCustomProxy <代理服务器地址>
        使用自定义的代理服务器地址, 例如代理服务器地址为 http://127.0.0.1:10809, 则使用 -UseCustomProxy "http://127.0.0.1:10809" 设置代理服务器地址

    -DisableHuggingFaceMirror
        禁用 HuggingFace 镜像源, 不使用 HuggingFace 镜像源下载文件

    -UseCustomHuggingFaceMirror <HuggingFace 镜像源地址>
        使用自定义 HuggingFace 镜像源地址, 例如代理服务器地址为 https://hf-mirror.com, 则使用 -UseCustomHuggingFaceMirror "https://hf-mirror.com" 设置 HuggingFace 镜像源地址

    -DisableUV
        禁用 Fooocus Installer 使用 uv 安装 Python 软件包, 使用 Pip 安装 Python 软件包

    -LaunchArg <Fooocus 启动参数>
        设置 Fooocus 自定义启动参数, 如启用 --disable-offload-from-vram 和 --disable-analytics, 则使用 -LaunchArg "--disable-offload-from-vram --disable-analytics" 进行启用

    -EnableShortcut
        创建 Fooocus 启动快捷方式

    -DisableCUDAMalloc
        禁用 Fooocus Installer 通过 PYTORCH_CUDA_ALLOC_CONF 环境变量设置 CUDA 内存分配器

    -DisableEnvCheck
        禁用 Fooocus Installer 检查 Fooocus 运行环境中存在的问题, 禁用后可能会导致 Fooocus 环境中存在的问题无法被发现并修复

    -DisableAutoApplyUpdate
        禁用 Fooocus Installer 自动应用新版本更新


更多的帮助信息请阅读 Fooocus Installer 使用文档: https://github.com/licyk/sd-webui-all-in-one/blob/main/fooocus_installer.md
"@
    $help_text = $help_text.Trim()

    Write-Host $help_text
    exit 0
}


# 消息输出
function Print-Msg ($msg) {
    # Write one log line shaped like "[yyyy-MM-dd HH:mm:ss][Fooocus Installer]:: <msg>".
    # Each prefix segment has its own console color; the message uses the default color.
    $timestamp = Get-Date -Format "yyyy-MM-dd HH:mm:ss"
    $prefix_segments = @(
        @("[$timestamp]", "Yellow"),
        @("[Fooocus Installer]", "Cyan"),
        @(":: ", "Blue")
    )
    foreach ($segment in $prefix_segments) {
        Write-Host $segment[0] -ForegroundColor $segment[1] -NoNewline
    }
    Write-Host "$msg"
}


# 显示 Fooocus Installer 版本
function Get-Fooocus-Installer-Version {
    # Print $FOOOCUS_INSTALLER_VERSION as vMAJOR.MINOR.MICRO (e.g. 190 -> v1.9.0).
    # The last digit is micro, the second-to-last is minor, and everything before them is major.
    # Integer arithmetic is used instead of slicing the digit characters: interpolating a
    # char-array slice joins its elements with $OFS (a space), which garbles multi-digit
    # majors (1000 -> "1 0") and versions below 100 (90 -> "9 0").
    $ver = [int]$FOOOCUS_INSTALLER_VERSION
    $major = [math]::Floor($ver / 100)
    $minor = [math]::Floor($ver / 10) % 10
    $micro = $ver % 10
    Print-Msg "Fooocus Installer 版本: v${major}.${minor}.${micro}"
}


# PyPI 镜像源状态
function PyPI-Mirror-Status {
    # Report whether the PyPI mirror or the official index is in use, based on
    # $USE_PIP_MIRROR computed at script start-up.
    $status_msg = if ($USE_PIP_MIRROR) {
        "使用 PyPI 镜像源"
    } else {
        "检测到 disable_pypi_mirror.txt 配置文件 / -DisablePyPIMirror 命令行参数, 已将 PyPI 源切换至官方源"
    }
    Print-Msg $status_msg
}


# 修复 PyTorch 的 libomp 问题
function Fix-PyTorch {
    # Work around PyTorch builds whose fbgemm.dll needs libomp140.x86_64.dll but do not ship it.
    # If loading fbgemm.dll fails, copy torch/lib/libiomp5md.dll to libomp140.x86_64.dll.
    # The embedded Python does nothing when torch is not installed or the DLLs it inspects are
    # absent, instead of printing a traceback.
    # NOTE: $content is a PowerShell double-quoted string, so the Python below must not use `$`,
    # double quotes or backticks.
    $content = "
import importlib.util
import shutil
import os
import ctypes
import logging


def fix_pytorch_libomp():
    torch_spec = importlib.util.find_spec('torch')
    # torch not installed (or not a package): nothing to fix
    if torch_spec is None or not torch_spec.submodule_search_locations:
        return

    for folder in torch_spec.submodule_search_locations:
        lib_folder = os.path.join(folder, 'lib')
        test_file = os.path.join(lib_folder, 'fbgemm.dll')
        dest = os.path.join(lib_folder, 'libomp140.x86_64.dll')
        if os.path.exists(dest):
            break
        # Builds without fbgemm.dll (e.g. non-Windows) cannot have this problem
        if not os.path.isfile(test_file):
            continue

        with open(test_file, 'rb') as f:
            if b'libomp140.x86_64.dll' not in f.read():
                break
        try:
            ctypes.cdll.LoadLibrary(test_file)
        except FileNotFoundError:
            logging.warning('检测到 PyTorch 版本存在 libomp 问题, 进行修复')
            src = os.path.join(lib_folder, 'libiomp5md.dll')
            if os.path.isfile(src):
                shutil.copyfile(src, dest)
            else:
                logging.warning('未找到 libiomp5md.dll, 无法修复 PyTorch 的 libomp 问题')


fix_pytorch_libomp()
".Trim()

    Print-Msg "检测 PyTorch 的 libomp 问题中"
    python -c "$content"
    Print-Msg "PyTorch 检查完成"
}


# Fooocus Installer 更新检测
function Check-Fooocus-Installer-Update {
    # Check whether a newer Fooocus Installer is published. If so (and the user agrees when
    # auto-apply is disabled), run it in update mode, then restart this script with the same
    # command-line arguments.
    # The check is skipped when disable_update.txt exists or -DisableUpdate is passed. It runs at
    # most once per $UPDATE_TIME_SPAN seconds; the last check time is stored in update_time.txt.

    # 可用的下载源 (download sources, tried in order)
    $urls = @(
        "https://github.com/licyk/sd-webui-all-in-one/raw/main/fooocus_installer.ps1",
        "https://gitee.com/licyk/sd-webui-all-in-one/raw/main/fooocus_installer.ps1",
        "https://github.com/licyk/sd-webui-all-in-one/releases/download/fooocus_installer/fooocus_installer.ps1",
        "https://gitee.com/licyk/sd-webui-all-in-one/releases/download/fooocus_installer/fooocus_installer.ps1",
        "https://gitlab.com/licyk/sd-webui-all-in-one/-/raw/main/fooocus_installer.ps1"
    )
    $i = 0

    New-Item -ItemType Directory -Path "$Env:CACHE_HOME" -Force > $null

    if ((Test-Path "$PSScriptRoot/disable_update.txt") -or ($DisableUpdate)) {
        Print-Msg "检测到 disable_update.txt 更新配置文件 / -DisableUpdate 命令行参数, 已禁用 Fooocus Installer 的自动检查更新功能"
        return
    }

    # 获取更新时间间隔 (time since the last check; a missing or unreadable file counts as "never")
    try {
        $last_update_time = Get-Content "$PSScriptRoot/update_time.txt" 2> $null
        $last_update_time = Get-Date $last_update_time -Format "yyyy-MM-dd HH:mm:ss"
    }
    catch {
        $last_update_time = Get-Date 0 -Format "yyyy-MM-dd HH:mm:ss"
    }
    finally {
        $update_time = Get-Date -Format "yyyy-MM-dd HH:mm:ss"
        $time_span = New-TimeSpan -Start $last_update_time -End $update_time
    }

    if ($time_span.TotalSeconds -gt $UPDATE_TIME_SPAN) {
        Set-Content -Encoding UTF8 -Path "$PSScriptRoot/update_time.txt" -Value $(Get-Date -Format "yyyy-MM-dd HH:mm:ss") # 记录更新时间 (record check time)
    } else {
        return
    }

    $latest_version = $null
    ForEach ($url in $urls) {
        Print-Msg "检查 Fooocus Installer 更新中"
        try {
            # -UseBasicParsing avoids the IE engine dependency of Windows PowerShell 5.1 (no-op on 7+)
            Invoke-WebRequest -Uri $url -OutFile "$Env:CACHE_HOME/fooocus_installer.ps1" -UseBasicParsing
            # Read the version from the assignment line only. Taking the first line that merely
            # mentions the name is fragile, and indexing a single match with [0] would return
            # its first character rather than the line.
            $version_line = Get-Content "$Env:CACHE_HOME/fooocus_installer.ps1" |
                Select-String -Pattern '^\s*\$FOOOCUS_INSTALLER_VERSION\s*=\s*(\d+)' |
                Select-Object -First 1
            if (!$version_line) {
                throw "FOOOCUS_INSTALLER_VERSION not found in downloaded installer"
            }
            $latest_version = [int]$version_line.Matches[0].Groups[1].Value
            break
        }
        catch {
            $i += 1
            if ($i -lt $urls.Length) {
                Print-Msg "重试检查 Fooocus Installer 更新中"
            } else {
                Print-Msg "检查 Fooocus Installer 更新失败"
                return
            }
        }
    }

    if ($latest_version -le $FOOOCUS_INSTALLER_VERSION) {
        Print-Msg "Fooocus Installer 已是最新版本"
        return
    }

    if (($DisableAutoApplyUpdate) -or (Test-Path "$PSScriptRoot/disable_auto_apply_update.txt")) {
        Print-Msg "检测到 Fooocus Installer 有新版本可用, 是否进行更新 (yes/no) ?"
        Print-Msg "提示: 输入 yes 确认或 no 取消 (默认为 no)"
        $arg = (Read-Host "========================================>").Trim()
        # -eq is case-insensitive, so "YES" / "Y" are accepted as well
        if (!($arg -eq "yes" -or $arg -eq "y")) {
            Print-Msg "跳过 Fooocus Installer 更新"
            return
        }
    } else {
        Print-Msg "检测到 Fooocus Installer 有新版本可用"
    }

    Print-Msg "调用 Fooocus Installer 进行更新中"
    . "$Env:CACHE_HOME/fooocus_installer.ps1" -InstallPath "$PSScriptRoot" -UseUpdateMode
    # Reuse everything after "<script>.ps1" on the original command line
    $raw_params = $script:MyInvocation.Line -replace "^.*\.ps1[\s]*", ""
    Print-Msg "更新结束, 重新启动 Fooocus Installer 管理脚本中, 使用的命令行参数: $raw_params"
    Invoke-Expression "& `"$PSCommandPath`" $raw_params"
    exit 0
}


# 代理配置
function Set-Proxy {
    # Set HTTP_PROXY / HTTPS_PROXY for this process and its children.
    # Priority: -UseCustomProxy > proxy.txt > the Windows system proxy (HKCU Internet Settings).
    # Loopback addresses are always excluded through NO_PROXY. Nothing is set when
    # disable_proxy.txt exists or -DisableProxy is passed.
    $Env:NO_PROXY = "localhost,127.0.0.1,::1"
    # 检测是否禁用自动设置代理 (auto proxy disabled?)
    if ((Test-Path "$PSScriptRoot/disable_proxy.txt") -or ($DisableProxy)) {
        Print-Msg "检测到本地存在 disable_proxy.txt 代理配置文件 / -DisableProxy 命令行参数, 禁用自动设置代理"
        return
    }

    $internet_setting = Get-ItemProperty -Path "HKCU:\Software\Microsoft\Windows\CurrentVersion\Internet Settings"
    if ((Test-Path "$PSScriptRoot/proxy.txt") -or ($UseCustomProxy)) { # 本地存在代理配置
        if ($UseCustomProxy) {
            $proxy_value = $UseCustomProxy
        } else {
            $proxy_value = Get-Content "$PSScriptRoot/proxy.txt"
        }
        $Env:HTTP_PROXY = $proxy_value
        $Env:HTTPS_PROXY = $proxy_value
        Print-Msg "检测到本地存在 proxy.txt 代理配置文件 / -UseCustomProxy 命令行参数, 已读取代理配置文件并设置代理"
    } elseif ($internet_setting.ProxyEnable -eq 1) { # 系统已设置代理
        $proxy_addr = $($internet_setting.ProxyServer)
        # ProxyServer is either "host:port" or per-protocol entries such as
        # "http=host:port;https=host:port;socks=host:port". Capture up to the next ';' or the end
        # of the string, so an entry without a trailing ';' is still recognised.
        if (($proxy_addr -match "http=([^;]+)") -or ($proxy_addr -match "https=([^;]+)")) {
            $proxy_value = $matches[1]
            # 去除 http / https 前缀 (normalise to a single http:// scheme)
            $proxy_value = $proxy_value.ToString().Replace("http://", "").Replace("https://", "")
            $proxy_value = "http://${proxy_value}"
        } elseif ($proxy_addr -match "socks=([^;]+)") {
            $proxy_value = $matches[1]
            # Strip any http(s):// prefix, then use the socks:// scheme
            $proxy_value = $proxy_value.ToString().Replace("http://", "").Replace("https://", "")
            $proxy_value = "socks://${proxy_value}"
        } else {
            $proxy_value = "http://${proxy_addr}"
        }
        $Env:HTTP_PROXY = $proxy_value
        $Env:HTTPS_PROXY = $proxy_value
        Print-Msg "检测到系统设置了代理, 已读取系统中的代理配置并设置代理"
    }
}


# HuggingFace 镜像源
function Set-HuggingFace-Mirror {
    # Choose the HuggingFace endpoint (HF_ENDPOINT) for downloads.
    # Priority: -UseCustomHuggingFaceMirror > hf_mirror.txt > https://hf-mirror.com.
    # HF_ENDPOINT is left untouched when disable_hf_mirror.txt exists or -DisableHuggingFaceMirror is passed.
    $hf_mirror_disabled = (Test-Path "$PSScriptRoot/disable_hf_mirror.txt") -or ($DisableHuggingFaceMirror)
    if ($hf_mirror_disabled) {
        Print-Msg "检测到本地存在 disable_hf_mirror.txt 镜像源配置文件 / -DisableHuggingFaceMirror 命令行参数, 禁用自动设置 HuggingFace 镜像源"
        return
    }

    $custom_mirror_file = "$PSScriptRoot/hf_mirror.txt"
    if ($UseCustomHuggingFaceMirror) {
        $Env:HF_ENDPOINT = $UseCustomHuggingFaceMirror
        Print-Msg "检测到本地存在 hf_mirror.txt 配置文件 / -UseCustomHuggingFaceMirror 命令行参数, 已读取该配置并设置 HuggingFace 镜像源"
    } elseif (Test-Path $custom_mirror_file) {
        $Env:HF_ENDPOINT = Get-Content $custom_mirror_file
        Print-Msg "检测到本地存在 hf_mirror.txt 配置文件 / -UseCustomHuggingFaceMirror 命令行参数, 已读取该配置并设置 HuggingFace 镜像源"
    } else {
        $Env:HF_ENDPOINT = "https://hf-mirror.com"
        Print-Msg "使用默认 HuggingFace 镜像源"
    }
}


# 检查 uv 是否需要更新
function Check-uv-Version {
    # Upgrade uv with pip when it is missing or older than $UV_MINIMUM_VER.
    # The embedded Python prints True when an upgrade is needed and False otherwise.
    # NOTE: $content is a PowerShell double-quoted string; $UV_MINIMUM_VER is interpolated on
    # purpose, and the Python must not otherwise use `$`, double quotes or backticks.
    $content = "
import re
from importlib.metadata import version



def compare_versions(version1, version2) -> int:
    '''Compare two version strings numerically.

    Returns 1 if version1 > version2, -1 if it is lower, and 0 if they are equal or
    either one cannot be parsed.
    '''
    def to_numbers(ver):
        parts = re.sub(r'[a-zA-Z]+', '', ver).replace('-', '.').replace('+', '.').split('.')
        # Stripping letters can leave empty segments (e.g. 0.8.0.post -> 0.8.0.); skip them
        return [int(part) for part in parts if part]

    try:
        nums1 = to_numbers(version1)
        nums2 = to_numbers(version2)
    except (TypeError, ValueError):
        return 0

    for i in range(max(len(nums1), len(nums2))):
        num1 = nums1[i] if i < len(nums1) else 0
        num2 = nums2[i] if i < len(nums2) else 0

        if num1 == num2:
            continue
        elif num1 > num2:
            return 1
        else:
            return -1

    return 0



def is_uv_need_update() -> bool:
    try:
        uv_ver = version('uv')
    except Exception:
        # uv is not installed (or its metadata is unreadable): install it
        return True

    return compare_versions(uv_ver, uv_minimum_ver) == -1



uv_minimum_ver = '$UV_MINIMUM_VER'
print(is_uv_need_update())
".Trim()

    Print-Msg "检测 uv 是否需要更新"
    $status = $(python -c "$content")
    if ($status -eq "True") {
        Print-Msg "更新 uv 中"
        python -m pip install -U "uv>=$UV_MINIMUM_VER"
        if ($?) {
            Print-Msg "uv 更新成功"
        } else {
            Print-Msg "uv 更新失败, 可能会造成 uv 部分功能异常"
        }
    } else {
        Print-Msg "uv 无需更新"
    }
}


# 设置 uv 的使用状态
function Set-uv {
    # Select the Python package manager and record the choice in $Global:USE_UV.
    # uv is the default (after making sure it is recent enough). pip is used when
    # disable_uv.txt exists or -DisableUV is passed.
    # uv should target the Python bundled inside Fooocus/ when that copy exists.
    $bundled_python = "$PSScriptRoot/Fooocus/python/python.exe"
    if (Test-Path $bundled_python) {
        $Env:UV_PYTHON = $bundled_python
    }

    $uv_disabled = (Test-Path "$PSScriptRoot/disable_uv.txt") -or ($DisableUV)
    $Global:USE_UV = !$uv_disabled
    if ($uv_disabled) {
        Print-Msg "检测到 disable_uv.txt 配置文件 / -DisableUV 命令行参数, 已禁用 uv, 使用 Pip 作为 Python 包管理器"
        return
    }

    Print-Msg "默认启用 uv 作为 Python 包管理器, 加快 Python 软件包的安装速度"
    Print-Msg "当 uv 安装 Python 软件包失败时, 将自动切换成 Pip 重试 Python 软件包的安装"
    Check-uv-Version
}


# Fooocus 启动参数
function Get-Fooocus-Launch-Args {
    # Build the Fooocus launch arguments from -LaunchArg (preferred) or launch_args.txt.
    # Quoted segments ("a b" / 'a b') stay together as one argument, with their surrounding
    # quotes removed. PowerShell unrolls the returned collection: a single argument comes back
    # as a string, and no arguments come back as nothing.
    $arguments = New-Object System.Collections.ArrayList
    if ((Test-Path "$PSScriptRoot/launch_args.txt") -or ($LaunchArg)) {
        if ($LaunchArg) {
            $launch_args = $LaunchArg
        } else {
            # Join all lines into one string. An empty file then yields "" instead of $null
            # (which would break .Trim()), and a multi-line file is one space-separated string.
            $launch_args = @(Get-Content "$PSScriptRoot/launch_args.txt") -join " "
        }
        $launch_args = $launch_args.Trim()
        if ($launch_args) {
            if ($launch_args.Split().Length -le 1) {
                $arguments = $launch_args.Split()
            } else {
                $arguments = [regex]::Matches($launch_args, '("[^"]*"|''[^'']*''|\S+)') | ForEach-Object {
                    $_.Value -replace '^["'']|["'']$', ''
                }
            }
        }
        Print-Msg "检测到本地存在 launch_args.txt 启动参数配置文件 / -LaunchArg 命令行参数, 已读取该启动参数配置文件并应用启动参数"
        Print-Msg "使用的启动参数: $arguments"
    }
    return $arguments
}


# 设置 Fooocus 的快捷启动方式
function Create-Fooocus-Shortcut {
    # When enable_shortcut.txt exists or -EnableShortcut is passed, create or refresh a desktop
    # shortcut that runs launch.ps1 with -ExecutionPolicy Bypass, and copy it to the Start Menu.
    # The shortcut is named after the Fooocus fork detected from the git "origin" URL.
    if ((!(Test-Path "$PSScriptRoot/enable_shortcut.txt")) -and (!($EnableShortcut))) {
        return
    }

    # 设置快捷方式名称 (pick the shortcut name; "Fooocus" unless a known fork is detected)
    $filename = "Fooocus"
    if ((Get-Command git -ErrorAction SilentlyContinue) -and (Test-Path "$PSScriptRoot/Fooocus/.git")) {
        $git_remote = ([string]$(git -C "$PSScriptRoot/Fooocus" remote get-url origin)).Trim().TrimEnd("/")
        # Split on both '/' and ':' so HTTPS (https://host/owner/repo) and SSH (git@host:owner/repo)
        # remotes both yield owner/repo
        $array = $git_remote -split "[/:]"
        $branch = "$($array[-2])/$($array[-1])" -replace '\.git$', ''
        switch ($branch) {
            "lllyasviel/Fooocus" { $filename = "Fooocus" }
            "MoonRide303/Fooocus-MRE" { $filename = "Fooocus-MRE" }
            "runew0lf/RuinedFooocus" { $filename = "RuinedFooocus" }
        }
    }

    $url = "https://modelscope.cn/models/licyks/invokeai-core-model/resolve/master/pypatchmatch/gradio_icon.ico"
    $shortcut_icon = "$PSScriptRoot/gradio_icon.ico"

    Print-Msg "检测到 enable_shortcut.txt 配置文件 / -EnableShortcut 命令行参数, 开始检查 Fooocus 快捷启动方式中"
    if (!(Test-Path "$shortcut_icon")) {
        Print-Msg "获取 Fooocus 图标中"
        try {
            Invoke-WebRequest -Uri $url -OutFile "$shortcut_icon" -UseBasicParsing -ErrorAction Stop
        }
        catch {
            # Do not keep a truncated icon: it would be treated as valid on the next run
            Remove-Item -Path "$shortcut_icon" -Force -ErrorAction SilentlyContinue
            Print-Msg "获取 Fooocus 图标失败, 无法创建 Fooocus 快捷启动方式"
            return
        }
    }

    Print-Msg "更新 Fooocus 快捷启动方式"
    $shell = New-Object -ComObject WScript.Shell
    $desktop = [System.Environment]::GetFolderPath("Desktop")
    $shortcut_path = "$desktop\$filename.lnk"
    $shortcut = $shell.CreateShortcut($shortcut_path)
    # $PSHome contains powershell.exe under Windows PowerShell but pwsh.exe under PowerShell 7+
    $powershell_exe = "$PSHome\powershell.exe"
    if (!(Test-Path "$powershell_exe")) {
        $powershell_exe = "$PSHome\pwsh.exe"
    }
    $shortcut.TargetPath = $powershell_exe
    $launch_script_path = $(Get-Item "$PSScriptRoot/launch.ps1").FullName
    $shortcut.Arguments = "-ExecutionPolicy Bypass -File `"$launch_script_path`""
    $shortcut.IconLocation = $shortcut_icon

    # 保存到桌面 (save on the desktop)
    $shortcut.Save()
    $start_menu_path = "$Env:APPDATA/Microsoft/Windows/Start Menu/Programs"
    # 保存到开始菜单 (copy to the Start Menu)
    Copy-Item -Path "$shortcut_path" -Destination "$start_menu_path" -Force
    # 固定到任务栏 (pin to the taskbar; currently disabled)
    # $taskbar_path = "$Env:APPDATA\Microsoft\Internet Explorer\Quick Launch\User Pinned\TaskBar"
    # Copy-Item -Path "$shortcut_path" -Destination "$taskbar_path" -Force
    # $shell = New-Object -ComObject Shell.Application
    # $shell.Namespace([System.IO.Path]::GetFullPath($taskbar_path)).ParseName((Get-Item $shortcut_path).Name).InvokeVerb('taskbarpin')
}


# 设置 CUDA 内存分配器
function Set-PyTorch-CUDA-Memory-Alloc {
    # Choose PYTORCH_CUDA_ALLOC_CONF before PyTorch is imported.
    # The embedded Python prints cuda_malloc (NVIDIA GPU not on the blacklist and a CUDA build of
    # torch >= 2), pytorch_malloc (NVIDIA GPU otherwise), or None (no NVIDIA GPU, torch < 2,
    # or detection failed).
    # Skipped when disable_set_pytorch_cuda_memory_alloc.txt exists or -DisableCUDAMalloc is passed.
    # NOTE: $content is a PowerShell double-quoted string, so the Python below must not use `$`,
    # double quotes or backticks.
    if ((!(Test-Path "$PSScriptRoot/disable_set_pytorch_cuda_memory_alloc.txt")) -and (!($DisableCUDAMalloc))) {
        Print-Msg "检测是否可设置 CUDA 内存分配器"
    } else {
        Print-Msg "检测到 disable_set_pytorch_cuda_memory_alloc.txt 配置文件 / -DisableCUDAMalloc 命令行参数, 已禁用自动设置 CUDA 内存分配器"
        return
    }

    $content = "
import os
import importlib.util
import subprocess

#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
    if os.name == 'nt':
        import ctypes

        # Define necessary C structures and types
        class DISPLAY_DEVICEA(ctypes.Structure):
            _fields_ = [
                ('cb', ctypes.c_ulong),
                ('DeviceName', ctypes.c_char * 32),
                ('DeviceString', ctypes.c_char * 128),
                ('StateFlags', ctypes.c_ulong),
                ('DeviceID', ctypes.c_char * 128),
                ('DeviceKey', ctypes.c_char * 128)
            ]

        # Load user32.dll
        user32 = ctypes.windll.user32

        # Call EnumDisplayDevicesA
        def enum_display_devices():
            device_info = DISPLAY_DEVICEA()
            device_info.cb = ctypes.sizeof(device_info)
            device_index = 0
            gpu_names = set()

            while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
                device_index += 1
                # The A API returns ANSI-codepage text; localized adapter names (e.g. GBK on
                # Chinese Windows) are not valid UTF-8 and would otherwise abort detection
                gpu_names.add(device_info.DeviceString.decode('utf-8', errors='ignore'))
            return gpu_names
        return enum_display_devices()
    else:
        gpu_names = set()
        out = subprocess.check_output(['nvidia-smi', '-L'])
        for l in out.split(b'\n'):
            if len(l) > 0:
                gpu_names.add(l.decode('utf-8').split(' (UUID')[0])
        return gpu_names

blacklist = {'GeForce GTX TITAN X', 'GeForce GTX 980', 'GeForce GTX 970', 'GeForce GTX 960', 'GeForce GTX 950', 'GeForce 945M',
                'GeForce 940M', 'GeForce 930M', 'GeForce 920M', 'GeForce 910M', 'GeForce GTX 750', 'GeForce GTX 745', 'Quadro K620',
                'Quadro K1200', 'Quadro K2200', 'Quadro M500', 'Quadro M520', 'Quadro M600', 'Quadro M620', 'Quadro M1000',
                'Quadro M1200', 'Quadro M2000', 'Quadro M2200', 'Quadro M3000', 'Quadro M4000', 'Quadro M5000', 'Quadro M5500', 'Quadro M6000',
                'GeForce MX110', 'GeForce MX130', 'GeForce 830M', 'GeForce 840M', 'GeForce GTX 850M', 'GeForce GTX 860M',
                'GeForce GTX 1650', 'GeForce GTX 1630', 'Tesla M4', 'Tesla M6', 'Tesla M10', 'Tesla M40', 'Tesla M60'
                }


def cuda_malloc_supported():
    try:
        names = get_gpu_names()
    except:
        names = set()
    for x in names:
        if 'NVIDIA' in x:
            for b in blacklist:
                if b in x:
                    return False
    return True


def is_nvidia_device():
    try:
        names = get_gpu_names()
    except:
        names = set()
    for x in names:
        if 'NVIDIA' in x:
            return True
    return False


def get_pytorch_cuda_alloc_conf(is_cuda = True):
    if is_nvidia_device():
        if cuda_malloc_supported():
            if is_cuda:
                return 'cuda_malloc'
            else:
                return 'pytorch_malloc'
        else:
            return 'pytorch_malloc'
    else:
        return None


def main():
    try:
        version = ''
        torch_spec = importlib.util.find_spec('torch')
        for folder in torch_spec.submodule_search_locations:
            ver_file = os.path.join(folder, 'version.py')
            if os.path.isfile(ver_file):
                spec = importlib.util.spec_from_file_location('torch_version_import', ver_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                version = module.__version__
        # Compare the full major component, not just its first character
        if int(version.split('.')[0]) >= 2: #enable by default for torch version 2.0 and up
            if '+cu' in version: #only on cuda torch
                print(get_pytorch_cuda_alloc_conf())
            else:
                print(get_pytorch_cuda_alloc_conf(False))
        else:
            print(None)
    except Exception as _:
        print(None)


if __name__ == '__main__':
    main()
".Trim()

    $status = $(python -c "$content")
    switch ($status) {
        cuda_malloc {
            Print-Msg "设置 CUDA 内存分配器为 CUDA 内置异步分配器"
            $Env:PYTORCH_CUDA_ALLOC_CONF = "backend:cudaMallocAsync"
        }
        pytorch_malloc {
            Print-Msg "设置 CUDA 内存分配器为 PyTorch 原生分配器"
            $Env:PYTORCH_CUDA_ALLOC_CONF = "garbage_collection_threshold:0.9,max_split_size_mb:512"
        }
        Default {
            Print-Msg "显卡非 Nvidia 显卡, 无法设置 CUDA 内存分配器"
        }
    }
}


# 检查 Fooocus 依赖完整性
function Check-Fooocus-Requirements {
    $content = "
'''运行环境检查'''
import re
import os
import sys
import copy
import logging
import argparse
import importlib.metadata
from collections import namedtuple
from pathlib import Path
from typing import Optional, TypedDict, Union


def get_args() -> argparse.Namespace:
    '''获取命令行参数输入参数输入'''
    parser = argparse.ArgumentParser(description='运行环境检查')
    def normalized_filepath(filepath): return str(
        Path(filepath).absolute().as_posix())

    parser.add_argument(
        '--requirement-path', type=normalized_filepath, default=None, help='依赖文件路径')
    parser.add_argument('--debug-mode', action='store_true', help='显示调试信息')

    return parser.parse_args()


COMMAND_ARGS = get_args()


class ColoredFormatter(logging.Formatter):
    '''Logging 格式化'''
    COLORS = {
        'DEBUG': '\033[0;36m',          # CYAN
        'INFO': '\033[0;32m',           # GREEN
        'WARNING': '\033[0;33m',        # YELLOW
        'ERROR': '\033[0;31m',          # RED
        'CRITICAL': '\033[0;37;41m',    # WHITE ON RED
        'RESET': '\033[0m',             # RESET COLOR
    }

    def format(self, record):
        colored_record = copy.copy(record)
        levelname = colored_record.levelname
        seq = self.COLORS.get(levelname, self.COLORS['RESET'])
        colored_record.levelname = '{}{}{}'.format(
            seq, levelname, self.COLORS['RESET'])
        return super().format(colored_record)


def get_logger(
    name: str,
    level: int = logging.INFO,
) -> logging.Logger:
    '''获取 Loging 对象

    参数:
        name (str):
            Logging 名称


    '''
    logger = logging.getLogger(name)
    logger.propagate = False

    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(
            ColoredFormatter(
                '[%(name)s]-|%(asctime)s|-%(levelname)s: %(message)s', '%H:%M:%S'
            )
        )
        logger.addHandler(handler)

    logger.setLevel(level)
    logger.debug('Logger initialized.')

    return logger


logger = get_logger(
    'Env Checker',
    logging.DEBUG if COMMAND_ARGS.debug_mode else logging.INFO
)


# 提取版本标识符组件的正则表达式
# ref:
# https://peps.python.org/pep-0440
# https://packaging.python.org/en/latest/specifications/version-specifiers
VERSION_PATTERN = r'''
    v?
    (?:
        (?:(?P<epoch>[0-9]+)!)?                           # epoch
        (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
        (?P<pre>                                          # pre-release
            [-_\.]?
            (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
            [-_\.]?
            (?P<pre_n>[0-9]+)?
        )?
        (?P<post>                                         # post release
            (?:-(?P<post_n1>[0-9]+))
            |
            (?:
                [-_\.]?
                (?P<post_l>post|rev|r)
                [-_\.]?
                (?P<post_n2>[0-9]+)?
            )
        )?
        (?P<dev>                                          # dev release
            [-_\.]?
            (?P<dev_l>dev)
            [-_\.]?
            (?P<dev_n>[0-9]+)?
        )?
    )
    (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
'''


# 编译正则表达式
package_version_parse_regex = re.compile(
    r'^\s*' + VERSION_PATTERN + r'\s*$',
    re.VERBOSE | re.IGNORECASE,
)


# 定义版本组件的命名元组
VersionComponent = namedtuple(
    'VersionComponent', [
        'epoch',
        'release',
        'pre_l',
        'pre_n',
        'post_n1',
        'post_l',
        'post_n2',
        'dev_l',
        'dev_n',
        'local',
        'is_wildcard'
    ]
)


def parse_version(version_str: str) -> VersionComponent:
    '''解释 Python 软件包版本号

    参数:
        version_str (str):
            Python 软件包版本号

    返回值:
        VersionComponent: 版本组件的命名元组

    异常:
        ValueError: 如果 Python 版本号不符合 PEP440 规范
    '''
    # 检测并剥离通配符
    wildcard = version_str.endswith('.*') or version_str.endswith('*')
    clean_str = version_str.rstrip(
        '*').rstrip('.') if wildcard else version_str

    match = package_version_parse_regex.match(clean_str)
    if not match:
        logger.error(f'未知的版本号字符串: {version_str}')
        raise ValueError(f'Invalid version string: {version_str}')

    components = match.groupdict()

    # 处理 release 段 (允许空字符串)
    release_str = components['release'] or '0'
    release_segments = [int(seg) for seg in release_str.split('.')]

    # 构建命名元组
    return VersionComponent(
        epoch=int(components['epoch'] or 0),
        release=release_segments,
        pre_l=components['pre_l'],
        pre_n=int(components['pre_n']) if components['pre_n'] else None,
        post_n1=int(components['post_n1']) if components['post_n1'] else None,
        post_l=components['post_l'],
        post_n2=int(components['post_n2']) if components['post_n2'] else None,
        dev_l=components['dev_l'],
        dev_n=int(components['dev_n']) if components['dev_n'] else None,
        local=components['local'],
        is_wildcard=wildcard
    )


def compare_version_objects(v1: VersionComponent, v2: VersionComponent) -> int:
    '''Compare two parsed Python package versions using PEP 440 ordering.

    Args:
        v1 (VersionComponent):
            First parsed version (as produced by parse_version)
        v2 (VersionComponent):
            Second parsed version

    Returns:
        int: 1 if v1 > v2, -1 if v1 < v2, 0 if they are equal

    Ordering rules (PEP 440):
        - Epochs are compared first.
        - Release segments compare as if zero-padded, so 0.15 == 0.15.0.
        - Pre-release labels are case-insensitive; alpha == a, beta == b and
          c / pre / preview == rc. A missing number counts as 0 (1.0a == 1.0a0).
        - A developmental release of a final version sorts before all of its
          pre-releases: 1.0.dev0 < 1.0a0 < 1.0rc1 < 1.0 < 1.0.post0.
        - A version with a local label sorts after the same version without
          one (2.3.0 < 2.3.0+cu121). Local segments are compared numerically
          when they are all digits, otherwise case-insensitively; a numeric
          segment sorts after an alphanumeric one, and a longer label sorts
          after a shorter label it starts with.

    The inputs are never modified.
    '''
    pre_label_rank = {
        'a': 0,
        'alpha': 0,
        'b': 1,
        'beta': 1,
        'c': 2,
        'rc': 2,
        'pre': 2,
        'preview': 2,
    }

    def _sort_key(v) -> tuple:
        # Release without trailing zeros: equivalent to zero-padding both sides,
        # and works on a copy so the caller's list is left untouched.
        release = list(v.release)
        while release and release[-1] == 0:
            release.pop()

        # Post-release: either the implicit '-N' form or an explicit label
        has_post = v.post_n1 is not None or bool(v.post_l)
        if v.post_n1 is not None:
            post_number = v.post_n1
        else:
            post_number = v.post_n2 or 0

        if v.pre_l:
            pre_key = (1, pre_label_rank[v.pre_l.lower()], v.pre_n or 0)
        elif v.dev_l and not has_post:
            # X.devN sorts before any pre-release of X
            pre_key = (0,)
        else:
            pre_key = (2,)

        post_key = (1, post_number) if has_post else (0,)
        dev_key = (0, v.dev_n or 0) if v.dev_l else (1,)

        if v.local:
            local_segments = tuple(
                (1, int(seg), '') if seg.isdigit() else (0, 0, seg.lower())
                for seg in re.split(r'[-_.]', v.local)
            )
            local_key = (1, local_segments)
        else:
            local_key = (0,)

        return (v.epoch, tuple(release), pre_key, post_key, dev_key, local_key)

    key1 = _sort_key(v1)
    key2 = _sort_key(v2)
    return (key1 > key2) - (key1 < key2)


def compare_versions(version1: str, version2: str) -> int:
    '''Compare two Python package version strings.

    Args:
        version1 (str):
            First version string
        version2 (str):
            Second version string

    Returns:
        int: positive if version1 > version2, negative if version1 < version2,
        0 if they are equal

    Raises:
        ValueError: if either string is not a valid PEP 440 version
    '''
    # Arguments are evaluated left to right, so version1 is parsed first
    return compare_version_objects(parse_version(version1), parse_version(version2))


def compatible_version_matcher(spec_version: str):
    '''Build a matcher for the PEP 440 compatible-release operator (~=).

    '~= V.N' is equivalent to '>= V.N, == V.*'. A candidate must share every
    release segment of the spec except the last one, and must not be lower
    than the spec. Trailing zeros in the spec are significant: '~=1.0.0'
    means '>=1.0.0, ==1.0.*' and '~=0.15' means '>=0.15, ==0.*'. A
    pre/post/dev suffix on the spec only affects the lower bound. A
    single-segment spec ('~=2', invalid per PEP 440) is tolerated and
    treated as '>=2, ==2.*'.

    Args:
        spec_version (str): version written after the ~= operator

    Returns:
        _is_compatible(version_str: str) -> bool: predicate that returns True
        when version_str satisfies '~= spec_version'

    Raises:
        ValueError: if spec_version is not a valid version
    '''
    spec = parse_version(spec_version)
    # Private copy, so later comparisons can never alter the prefix
    spec_release = list(spec.release)

    if len(spec_release) == 0:
        logger.error('解析到错误的兼容性发行版本号')
        raise ValueError('Invalid version for compatible release clause')

    # Required prefix: every release segment except the last one
    if len(spec_release) == 1:
        prefix_pattern = spec_release
    else:
        prefix_pattern = spec_release[:-1]
    min_version = spec

    def _is_compatible(version_str: str) -> bool:
        target = parse_version(version_str)

        # The '== V.*' part is epoch-aware
        if target.epoch != spec.epoch:
            return False

        # Zero-pad the candidate so that e.g. '1' is compared as '1.0'
        target_release = list(target.release)
        target_release += [0] * (len(prefix_pattern) - len(target_release))
        if target_release[:len(prefix_pattern)] != prefix_pattern:
            return False

        # Lower bound (includes any pre/post/dev suffix of the spec)
        return compare_version_objects(target, min_version) >= 0

    return _is_compatible


def version_match(spec: str, version: str) -> bool:
    '''PEP 440 version matching for the == operator, with an optional prefix wildcard.

    Args:
        spec (str): version matching expression (e.g. '1.1.*' or '2.3.0+cu121')
        version (str): actual version to check (e.g. '1.1a1')

    Returns:
        bool: True if version matches spec. Matching rules:
            - '1.1.*': prefix match on the epoch and release segment. The
              candidate is zero-padded ('1' matches '1.0.*'), and its
              pre/post/dev/local parts are ignored.
            - a spec with a local label ('1.0+abc'): strict comparison of
              the full versions.
            - a plain public spec ('1.0'): strict comparison that ignores the
              candidate's local label, so '1.0+cu121' matches '1.0'.
        An unparseable spec yields False; an unparseable version raises
        ValueError.
    '''
    spec_public, _, spec_local = spec.partition('+')
    spec_main = spec_public.rstrip('.*')  # strip the wildcard
    has_wildcard = spec.endswith('.*') and '+' not in spec

    # Parse the spec without the wildcard
    try:
        spec_ver = parse_version(spec_main)
    except ValueError:
        return False

    # Candidate without its local label
    target_public = version.split('+', 1)[0]
    target_ver = parse_version(target_public)

    if has_wildcard:
        target_release = list(target_ver.release)
        target_release += [0] * (len(spec_ver.release) - len(target_release))
        return (
            target_release[:len(spec_ver.release)] == spec_ver.release
            and target_ver.epoch == spec_ver.epoch
        )

    if spec_local:
        # The spec pins a local label: compare full versions
        return compare_versions(spec, version) == 0

    # Public spec: the candidate's local label is ignored
    return compare_versions(spec_main, target_public) == 0


def is_v1_ge_v2(v1: str, v2: str) -> bool:
    '''Return True when Python package version v1 is greater than or equal to v2.

    Args:
        v1 (str): first Python package version
        v2 (str): second Python package version

    Returns:
        bool: e.g. ('1.1', '1.0') -> True, ('1.0', '1.0') -> True,
        ('0.9', '1.0') -> False
    '''
    return not (compare_versions(v1, v2) < 0)


def is_v1_gt_v2(v1: str, v2: str) -> bool:
    '''Return True when Python package version v1 is strictly greater than v2.

    Args:
        v1 (str): first Python package version
        v2 (str): second Python package version

    Returns:
        bool: e.g. ('1.1', '1.0') -> True, ('1.0', '1.0') -> False
    '''
    ordering = compare_versions(v1, v2)
    return ordering > 0


def is_v1_eq_v2(v1: str, v2: str) -> bool:
    '''Return True when Python package versions v1 and v2 are equal.

    Args:
        v1 (str): first Python package version
        v2 (str): second Python package version

    Returns:
        bool: e.g. ('1.0', '1.0') -> True, ('0.9', '1.0') -> False,
        ('1.1', '1.0') -> False
    '''
    # compare_versions yields 0 exactly when both versions are equal
    return not compare_versions(v1, v2)


def is_v1_lt_v2(v1: str, v2: str) -> bool:
    '''Return True when Python package version v1 is strictly lower than v2.

    Args:
        v1 (str): first Python package version
        v2 (str): second Python package version

    Returns:
        bool: e.g. ('0.9', '1.0') -> True, ('1.0', '1.0') -> False
    '''
    return 0 > compare_versions(v1, v2)


def is_v1_le_v2(v1: str, v2: str) -> bool:
    '''Return True when Python package version v1 is lower than or equal to v2.

    Args:
        v1 (str): first Python package version
        v2 (str): second Python package version

    Returns:
        bool: e.g. ('0.9', '1.0') -> True, ('1.0', '1.0') -> True,
        ('1.1', '1.0') -> False
    '''
    return not (compare_versions(v1, v2) > 0)


def is_v1_c_eq_v2(v1: str, v2: str) -> bool:
    '''Return True when version v2 satisfies the compatible-release clause '~= v1'.

    Args:
        v1 (str): version written after the ~= operator (the spec)
        v2 (str): candidate Python package version

    Returns:
        bool: result of compatible_version_matcher(v1) applied to v2
    '''
    return compatible_version_matcher(v1)(v2)


def version_string_is_canonical(version: str) -> bool:
    '''Check whether a version string is in PEP 440 canonical (normalized) form.

    Local labels ('+cu121') and non-normalized spellings ('01.0',
    '1.0-alpha1') are rejected.

    Args:
        version (str):
            Version string to check

    Returns:
        bool: True if the version is canonical
    '''
    canonical_pattern = (
        r'^([1-9][0-9]*!)?'
        r'(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*'
        r'((a|b|rc)(0|[1-9][0-9]*))?'
        r'(\.post(0|[1-9][0-9]*))?'
        r'(\.dev(0|[1-9][0-9]*))?$'
    )
    return bool(re.match(canonical_pattern, version))


def is_package_has_version(package: str) -> bool:
    '''Check whether a requirement string carries a version specifier.

    Args:
        package (str):
            Requirement string, e.g. 'torch==2.3.0' or 'numpy'

    Returns:
        bool: True if any version comparison operator occurs in the string
    '''
    specifier_operators = ('===', '~=', '!=', '<=', '>=', '<', '>', '==')
    return any(op in package for op in specifier_operators)


def get_package_name(package: str) -> str:
    '''Return the package name part of a requirement, without the version specifier.

    Args:
        package (str):
            Requirement string, e.g. 'torch==2.3.0'

    Returns:
        str: text before the first version operator, stripped (e.g. 'torch')
    '''
    name = package
    # Longer operators first so '<=' is cut before '<' is considered
    for operator in ('===', '~=', '!=', '<=', '>=', '<', '>', '=='):
        name = name.split(operator)[0]
    return name.strip()


def get_package_version(package: str) -> str:
    '''Return the version part of a requirement string.

    Args:
        package (str):
            Requirement string, e.g. 'torch==2.3.0'

    Returns:
        str: text after the last version operator, stripped (e.g. '2.3.0');
        the whole stripped input when there is no operator
    '''
    version = package
    for operator in ('===', '~=', '!=', '<=', '>=', '<', '>', '=='):
        version = version.split(operator)[-1]
    return version.strip()


# Verbose regex for a wheel file name (binary distribution format):
#   {distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl
# It must be used with re.VERBOSE: whitespace and '#' comments inside the
# pattern are ignored by the regex engine.
WHEEL_PATTERN = r'''
    ^                           # 字符串开始
    (?P<distribution>[^-]+)     # 包名 (匹配第一个非连字符段)
    -                           # 分隔符
    (?:                         # 版本号和可选构建号组合
        (?P<version>[^-]+)      # 版本号 (至少一个非连字符段)
        (?:-(?P<build>\d\w*))?  # 可选构建号 (以数字开头)
    )
    -                           # 分隔符
    (?P<python>[^-]+)           # Python 版本标签
    -                           # 分隔符
    (?P<abi>[^-]+)              # ABI 标签
    -                           # 分隔符
    (?P<platform>[^-]+)         # 平台标签
    \.whl$                      # 固定后缀
'''


def parse_wheel_filename(filename: str) -> str:
    '''Return the distribution name encoded in a wheel file name.

    Args:
        filename (str):
            Wheel file name, e.g. pydantic-1.10.15-py3-none-any.whl

    Returns:
        str: distribution name, e.g. pydantic

    Raises:
        ValueError: if the file name does not follow the wheel naming scheme
    '''
    wheel_match = re.fullmatch(WHEEL_PATTERN, filename, flags=re.VERBOSE)
    if wheel_match is not None:
        return wheel_match['distribution']
    logger.error('未知的 Wheel 文件名: %s', filename)
    raise ValueError(f'Invalid wheel filename: {filename}')


def parse_wheel_version(filename: str) -> str:
    '''Return the version encoded in a wheel file name.

    Args:
        filename (str):
            Wheel file name, e.g. pydantic-1.10.15-py3-none-any.whl

    Returns:
        str: version, e.g. 1.10.15

    Raises:
        ValueError: if the file name does not follow the wheel naming scheme
    '''
    wheel_match = re.fullmatch(WHEEL_PATTERN, filename, flags=re.VERBOSE)
    if wheel_match is not None:
        return wheel_match['version']
    logger.error('未知的 Wheel 文件名: %s', filename)
    raise ValueError(f'Invalid wheel filename: {filename}')


def parse_wheel_to_package_name(filename: str) -> str:
    '''Turn a wheel file name into a pinned requirement '<distribution>==<version>'.

    Args:
        filename (str):
            Wheel file name, e.g. pydantic-1.10.15-py3-none-any.whl

    Returns:
        str: pinned requirement, e.g. pydantic==1.10.15

    Raises:
        ValueError: if the file name does not follow the wheel naming scheme
    '''
    # The distribution is parsed (and validated) before the version
    return '=='.join((parse_wheel_filename(filename), parse_wheel_version(filename)))


def remove_optional_dependence_from_package(filename: str) -> str:
    '''Strip extras ('[...]' groups) from a requirement string.

    Args:
        filename (str):
            Requirement string

    Returns:
        str: requirement without extras, e.g.
        diffusers[torch]==0.10.2 -> diffusers==0.10.2
    '''
    extras_pattern = re.compile(r'\[.*?\]')
    return extras_pattern.sub('', filename)


def parse_requirement_list(requirements: list) -> list:
    '''Turn raw requirements.txt lines into a list of normalized package specs.

    Args:
        requirements (list):
            Raw requirement lines, e.g.:
                requirements = [
                    'torch==2.3.0',
                    'diffusers[torch]==0.10.2',
                    'NUMPY',
                    '-e .',
                    '-r extra_requirements.txt',
                    '--index-url https://pypi.python.org/simple',
                    '--extra-index-url https://download.pytorch.org/whl/cu124',
                    '--find-links https://download.pytorch.org/whl/torch_stable.html',
                    '-e git+https://github.com/Nerogar/mgds.git@2c67a5a#egg=mgds',
                    'git+https://github.com/WASasquatch/img2texture.git',
                    'git+https://github.com/deepghs/waifuc.git@v1.0',
                    'https://github.com/Panchovix/pydantic-fixreforge/releases/download/main_v1/pydantic-1.10.15-py3-none-any.whl',
                    'prodigy-plus-schedule-free==1.9.1 # prodigy+schedulefree optimizer',
                    'protobuf<5,>=4.25.3',
                ]

    Returns:
        list: normalized specs with at most one version constraint each. The
        example above becomes:
                [
                    'torch==2.3.0',
                    'diffusers==0.10.2',
                    'numpy',
                    'mgds',
                    'img2texture',
                    'waifuc',
                    'pydantic==1.10.15',
                    'prodigy-plus-schedule-free==1.9.1',
                    'protobuf<5',
                    'protobuf>=4.25.3',
                ]
        The following are dropped:
            - comments, blank lines and lines marked '# skip_verify'
            - pip option lines
            - URLs that are not wheel files
            - entries whose version is not in PEP 440 canonical form
    '''
    package_list = []
    canonical_package_list = []
    for raw_requirement in requirements:
        if raw_requirement is None:
            continue
        requirement = raw_requirement.strip()
        logger.debug('原始 Python 软件包名: %s', requirement)

        if (
            requirement == ''
            or requirement.startswith('#')
            or '# skip_verify' in requirement
        ):
            continue

        # -e git+https://github.com/Nerogar/mgds.git@2c67a5a#egg=mgds -> mgds
        # git+https://github.com/WASasquatch/img2texture.git -> img2texture
        # git+https://github.com/deepghs/waifuc@v1.0 -> waifuc
        if requirement.startswith('-e git+http') or requirement.startswith('git+http'):
            egg_match = re.search(r'egg=([^#&]+)', requirement)
            if egg_match:
                package_list.append(egg_match.group(1).split('-')[0])
                continue

            # Drop the URL fragment and the host, then the '@<revision>' pin,
            # and use the last path component as the repository name.
            url = requirement.split('#', 1)[0]
            repo_path = url.split('://', 1)[-1].partition('/')[2]
            repo_path = repo_path.split('@', 1)[0].rstrip('/')
            package_name = os.path.basename(repo_path)
            if package_name.endswith('.git'):
                package_name = package_name[:-len('.git')]
            package_list.append(package_name)
            continue

        # Remaining pip option lines (-e <path>, -r <file>, -c <file>,
        # --index-url, --extra-index-url, --find-links, ...) do not name a
        # package whose installation can be verified here.
        if requirement.startswith('-'):
            continue

        # https://github.com/Panchovix/pydantic-fixreforge/releases/download/main_v1/pydantic-1.10.15-py3-none-any.whl -> pydantic==1.10.15
        if requirement.startswith('https://') or requirement.startswith('http://'):
            wheel_file = os.path.basename(requirement.split('#', 1)[0].split('?', 1)[0])
            try:
                package_list.append(parse_wheel_to_package_name(wheel_file))
            except ValueError:
                # Not a wheel (e.g. an sdist archive): it cannot be mapped to a
                # name/version. parse_wheel_filename already logged it; skip
                # the entry instead of aborting the whole check.
                pass
            continue

        # diffusers[torch] @ https://... (PEP 508 direct reference) -> diffusers
        direct_reference = re.match(r'^([A-Za-z0-9][A-Za-z0-9._-]*)\s*(\[[^\]]*\])?\s*@', requirement)
        if direct_reference:
            package_list.append(direct_reference.group(1))
            continue

        # Regular requirement; inline comments are removed:
        # prodigy-plus-schedule-free==1.9.1 # prodigy+schedulefree optimizer -> prodigy-plus-schedule-free==1.9.1
        specifiers = re.sub(r'\s*#.*$', '', requirement).strip().split(',')
        if len(specifiers) > 1:
            # protobuf<5,>=4.25.3 -> protobuf<5 and protobuf>=4.25.3
            package_name = get_package_name(specifiers[0].strip())
            for specifier in specifiers:
                version_symbol = specifier.replace(package_name, '', 1)
                package_list.append(
                    remove_optional_dependence_from_package(f'{package_name}{version_symbol}'.strip())
                )
        else:
            package_list.append(remove_optional_dependence_from_package(specifiers[0].strip()))

    # Lower-case every entry and keep it only when it is unversioned or its
    # version is in PEP 440 canonical form
    for package in package_list:
        package = package.lower().strip()
        if not package:
            continue
        logger.debug('预处理后的 Python 软件包名: %s', package)
        if not is_package_has_version(package):
            logger.debug('%s 无版本声明', package)
            canonical_package_list.append(package)
            continue

        if version_string_is_canonical(get_package_version(package)):
            canonical_package_list.append(package)
        else:
            logger.debug('%s 软件包名的版本不符合标准', package)

    return canonical_package_list


def remove_duplicate_object_from_list(origin: list) -> list:
    '''Remove duplicate items from a list while keeping first-seen order.

    Args:
        origin (list):
            Source list (items must be hashable)

    Returns:
        list: new list without duplicates, e.g. [1, 2, 3, 2] -> [1, 2, 3]
    '''
    # dict preserves insertion order, unlike set()
    return list(dict.fromkeys(origin))


def read_packages_from_requirements_file(file_path: Union[str, Path]) -> list:
    '''Read the raw lines of a requirements.txt file.

    Args:
        file_path (str, Path):
            Path to the requirements file

    Returns:
        list: raw lines (with line endings), or an empty list if the file
        cannot be read (the error is logged). A leading UTF-8 BOM, as
        written by some Windows editors, is stripped so that the first
        package name is not corrupted.
    '''
    try:
        # utf-8-sig decodes plain UTF-8 too; it only drops a leading BOM
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            return f.readlines()
    except Exception as e:
        logger.error('打开 %s 时出现错误: %s\n请检查文件是否出现损坏', file_path, e)
        return []


def get_package_version_from_library(package_name: str) -> Union[str, None]:
    '''Return the installed version of a Python package.

    The name is tried as given, then lower-cased, then with '_' replaced by
    '-'. The first lookup that yields a version wins.

    Args:
        package_name (str):
            Distribution name to look up

    Returns:
        (str | None): installed version string, or None if it could not be found
    '''
    candidates = (
        package_name,
        package_name.lower(),
        package_name.replace('_', '-'),
    )
    # dict.fromkeys drops repeated spellings while keeping the try order
    for candidate in dict.fromkeys(candidates):
        try:
            version = importlib.metadata.version(candidate)
        except Exception:
            # e.g. PackageNotFoundError for unknown names, ValueError for ''
            continue
        if version is not None:
            return version

    return None


def is_package_installed(package: str) -> bool:
    '''Check whether a single requirement is satisfied by the current environment.

    Args:
        package (str):
            Normalized requirement with at most one version constraint,
            e.g. 'numpy', 'torch==2.3.0', 'protobuf<5'

    Returns:
        bool: False if the package is not installed or its installed version
        does not satisfy the constraint, otherwise True

    Notes:
        For '==' and '!=' against a public version (no '+local' label), the
        local label of the installed version is ignored, as PEP 440 requires.
        For example, 'torch==2.3.0' is satisfied by an installed '2.3.0+cu121'.
    '''
    # Longer operators first, so '<=' is not mistaken for '<' and '===' not for '=='
    operators = ('===', '~=', '!=', '<=', '>=', '<', '>', '==')
    op = next((o for o in operators if o in package), None)
    if op is None:
        pkg_name, pkg_version = package.strip(), None
    else:
        pkg_name, pkg_version = [x.strip() for x in package.split(op, 1)]

    env_pkg_version = get_package_version_from_library(pkg_name)
    logger.debug(
        '已安装 Python 软件包检测: pkg_name: %s, env_pkg_version: %s, pkg_version: %s',
        pkg_name, env_pkg_version, pkg_version
    )

    if env_pkg_version is None:
        return False

    if pkg_version is None:
        return True

    # Installed version used for ==/!=: drop its local label unless the
    # requirement itself pins one
    if '+' in pkg_version:
        env_match_version = env_pkg_version
    else:
        env_match_version = env_pkg_version.split('+', 1)[0]

    if op in ('===', '=='):
        logger.debug('包含条件: === / ==')
        candidate = env_pkg_version if op == '===' else env_match_version
        if is_v1_eq_v2(candidate, pkg_version):
            logger.debug('%s == %s', env_pkg_version, pkg_version)
            return True
    elif op == '~=':
        logger.debug('包含条件: ~=')
        if is_v1_c_eq_v2(pkg_version, env_pkg_version):
            logger.debug('%s ~= %s', pkg_version, env_pkg_version)
            return True
    elif op == '!=':
        logger.debug('包含条件: !=')
        if not is_v1_eq_v2(env_match_version, pkg_version):
            logger.debug('%s != %s', env_pkg_version, pkg_version)
            return True
    elif op == '<=':
        logger.debug('包含条件: <=')
        if is_v1_le_v2(env_pkg_version, pkg_version):
            logger.debug('%s <= %s', env_pkg_version, pkg_version)
            return True
    elif op == '>=':
        logger.debug('包含条件: >=')
        if is_v1_ge_v2(env_pkg_version, pkg_version):
            logger.debug('%s >= %s', env_pkg_version, pkg_version)
            return True
    elif op == '<':
        logger.debug('包含条件: <')
        if is_v1_lt_v2(env_pkg_version, pkg_version):
            logger.debug('%s < %s', env_pkg_version, pkg_version)
            return True
    elif op == '>':
        logger.debug('包含条件: >')
        if is_v1_gt_v2(env_pkg_version, pkg_version):
            logger.debug('%s > %s', env_pkg_version, pkg_version)
            return True

    logger.debug('%s 需要安装', package)
    return False


def validate_requirements(requirement_path: Union[str, Path]) -> bool:
    '''Check that every requirement listed in a requirements file is satisfied.

    Args:
        requirement_path (str, Path):
            Path to the requirements file

    Returns:
        bool: False as soon as one requirement is missing or unsatisfied,
        otherwise True
    '''
    parsed_requirements = parse_requirement_list(
        read_packages_from_requirements_file(requirement_path)
    )
    # all() stops at the first unsatisfied requirement
    return all(is_package_installed(package) for package in parsed_requirements)


def main() -> None:
    '''Script entry point: print whether the requirements file is satisfied.

    Prints True or False on stdout; the PowerShell caller compares that
    output with the string False. Exits with status 1 when the requirements
    file does not exist.
    '''
    requirement_path = COMMAND_ARGS.requirement_path

    if os.path.isfile(requirement_path):
        logger.debug('检测运行环境中')
        print(validate_requirements(requirement_path))
        logger.debug('环境检查完成')
        return

    logger.error('依赖文件未找到, 无法检查运行环境')
    sys.exit(1)


# Run the check only when the file is executed as a script
# (the PowerShell caller runs it with python and reads its stdout).
if __name__ == '__main__':
    main()
".Trim()

    Print-Msg "检查 Fooocus 内核依赖完整性中"
    if (!(Test-Path "$Env:CACHE_HOME")) {
        New-Item -ItemType Directory -Path "$Env:CACHE_HOME" > $null
    }
    Set-Content -Encoding UTF8 -Path "$Env:CACHE_HOME/check_fooocus_requirement.py" -Value $content

    $dep_path = "$PSScriptRoot/Fooocus/requirements_versions.txt"
    if (!(Test-Path "$dep_path")) {
        $dep_path = "$PSScriptRoot/Fooocus/requirements.txt"
    }
    if (!(Test-Path "$dep_path")) {
        Print-Msg "未检测到 Fooocus 依赖文件, 跳过依赖完整性检查"
        return
    }

    $status = $(python "$Env:CACHE_HOME/check_fooocus_requirement.py" --requirement-path "$dep_path")

    if ($status -eq "False") {
        Print-Msg "检测到 Fooocus 内核有依赖缺失, 安装 Fooocus 依赖中"
        if ($USE_UV) {
            uv pip install -r "$dep_path"
            if (!($?)) {
                Print-Msg "检测到 uv 安装 Python 软件包失败, 尝试回滚至 Pip 重试 Python 软件包安装"
                python -m pip install -r "$dep_path"
            }
        } else {
            python -m pip install -r "$dep_path"
        }
        if ($?) {
            Print-Msg "Fooocus 依赖安装成功"
        } else {
            Print-Msg "Fooocus 依赖安装失败, 这将会导致 Fooocus 缺失依赖无法正常运行"
        }
    } else {
        Print-Msg "Fooocus 无缺失依赖"
    }
}


# Check whether the installed onnxruntime-gpu matches the CUDA / cuDNN versions of PyTorch
function Check-Onnxruntime-GPU {
    # The embedded Python script prints one of cu118 / cu121cudnn8 / cu121cudnn9
    # when onnxruntime-gpu has to be (re)installed, or None otherwise.
    $content = "
import re
import argparse
import importlib.metadata
from pathlib import Path
from enum import Enum


def get_args() -> argparse.Namespace:
    '''Parse the command line arguments.

    :return argparse.Namespace: parsed arguments
    '''
    parser = argparse.ArgumentParser()

    parser.add_argument('--ignore-ort-install', action='store_true', help='忽略 onnxruntime-gpu 未安装的状态, 强制进行检查')

    return parser.parse_args()


def get_onnxruntime_version_file() -> Path | None:
    '''Locate onnxruntime/capi/version_info.py inside the onnxruntime-gpu distribution.

    :return Path | None: path of the file, or None if onnxruntime-gpu is not
        installed or the file is not listed in its metadata
    '''
    package = 'onnxruntime-gpu'
    version_file = 'onnxruntime/capi/version_info.py'
    try:
        util = [
            p for p in importlib.metadata.files(package)
            if version_file in str(p)
        ][0]
        info_path = Path(util.locate())
    except Exception as _:
        info_path = None

    return info_path


def get_onnxruntime_support_cuda_version() -> tuple[str | None, str | None]:
    '''Read the CUDA and cuDNN versions onnxruntime-gpu was built for.

    :return tuple[str | None, str | None]: (CUDA version, cuDNN version); an
        entry is None when it could not be read
    '''
    ver_path = get_onnxruntime_version_file()
    cuda_ver = None
    cudnn_ver = None
    try:
        with open(ver_path, 'r', encoding='utf8') as f:
            for line in f:
                if 'cuda_version' in line:
                    cuda_ver = get_value_from_variable(line, 'cuda_version')
                if 'cudnn_version' in line:
                    cudnn_ver = get_value_from_variable(line, 'cudnn_version')
    except Exception as _:
        pass

    return cuda_ver, cudnn_ver


def get_value_from_variable(content: str, var_name: str) -> str | None:
    '''Extract the numeric value assigned to a variable in a Python code snippet.

    :param content(str): text to search (e.g. one line of version_info.py)
    :param var_name(str): name of the variable
    :return str | None: digits and dots of the assigned value, or None when
        the variable is not assigned in content or its value has no digits
    '''
    pattern = fr'^\s*{var_name}\s*=\s*.*\s*$'
    match = re.findall(pattern, content, flags=re.MULTILINE)
    if not match:
        # e.g. the name only appears in a comment or inside another identifier
        return None
    match_str = ''.join(re.findall(r'[\d.]+', match[0].split('=').pop().strip()))
    return match_str if len(match_str) != 0 else None


def compare_versions(version1: str, version2: str) -> int:
    '''Compare two version strings numerically.

    Letters are removed and '-', '_', '+' are treated like '.'; missing
    segments count as 0.

    :param version1(str): first version
    :param version2(str): second version
    :return int: 1 if version1 is greater, -1 if version2 is greater, 0 if equal
    '''
    # Split each version into its numeric segments
    try:
        nums1 = (
            re.sub(r'[a-zA-Z]+', '', version1)
            .replace('-', '.')
            .replace('_', '.')
            .replace('+', '.')
            .split('.')
        )
        nums2 = (
            re.sub(r'[a-zA-Z]+', '', version2)
            .replace('-', '.')
            .replace('_', '.')
            .replace('+', '.')
            .split('.')
        )
    except Exception as _:
        return 0

    for i in range(max(len(nums1), len(nums2))):
        num1 = int(nums1[i]) if i < len(nums1) else 0  # pad the shorter version with 0
        num2 = int(nums2[i]) if i < len(nums2) else 0  # pad the shorter version with 0

        if num1 == num2:
            continue
        elif num1 > num2:
            return 1  # version1 is greater
        else:
            return -1  # version2 is greater

    return 0  # versions are equal


def get_torch_cuda_ver() -> tuple[str | None, str | None, str | None]:
    '''Return the PyTorch, CUDA and cuDNN versions reported by torch.

    :return tuple[str | None, str | None, str | None]: (torch, CUDA, cuDNN);
        all None when torch cannot be imported
    '''
    try:
        import torch
        torch_ver = torch.__version__
        cuda_ver = torch.version.cuda
        cudnn_ver = torch.backends.cudnn.version()
        return (
            str(torch_ver) if torch_ver is not None else None,
            str(cuda_ver) if cuda_ver is not None else None,
            str(cudnn_ver) if cudnn_ver is not None else None,
        )
    except Exception as _:
        return None, None, None


class OrtType(str, Enum):
    '''Flavour of onnxruntime-gpu to install.

    - CU121CUDNN8: CUDA 12.1 + cuDNN 8
    - CU121CUDNN9: CUDA 12.1 + cuDNN 9
    - CU118: CUDA 11.8
    '''
    CU121CUDNN8 = 'cu121cudnn8'
    CU121CUDNN9 = 'cu121cudnn9'
    CU118 = 'cu118'

    def __str__(self):
        return self.value


def need_install_ort_ver(ignore_ort_install: bool = True) -> OrtType | None:
    '''Decide which onnxruntime-gpu flavour has to be installed, if any.

    :param ignore_ort_install(bool): when True, return None if onnxruntime-gpu
        is not installed instead of suggesting a flavour to install
    :return OrtType | None: flavour to install, or None when nothing has to change
    '''
    # Versions bundled with the installed PyTorch
    torch_ver, cuda_ver, cuddn_ver = get_torch_cuda_ver()
    # Without Torch / CUDA / cuDNN information no decision can be made
    if (
        torch_ver is None
        or cuda_ver is None
        or cuddn_ver is None
    ):
        if not ignore_ort_install:
            try:
                _ = importlib.metadata.version('onnxruntime-gpu')
            except Exception as _:
                # onnxruntime-gpu is not installed
                return OrtType.CU121CUDNN9
        return None

    # onnxruntime records a single-digit cuDNN major version, so keep only the
    # first character of the number reported by torch (e.g. 90100 -> 9)
    cuddn_ver = cuddn_ver[0]

    # CUDA / cuDNN versions onnxruntime-gpu was built for (None when not installed)
    ort_support_cuda_ver, ort_support_cudnn_ver = get_onnxruntime_support_cuda_version()
    # Both values are normally present together, so checking CUDA is enough
    if ort_support_cuda_ver is not None:
        # onnxruntime-gpu is installed

        if compare_versions(cuda_ver, '12.0') >= 0:
            # Torch uses CUDA >= 12.0

            if compare_versions(ort_support_cuda_ver, '12.0') >= 0:
                # Both use CUDA 12.x: compare the cuDNN major versions
                if compare_versions(ort_support_cudnn_ver, cuddn_ver) > 0:
                    # ort cuDNN > torch cuDNN
                    return OrtType.CU121CUDNN8
                elif compare_versions(ort_support_cudnn_ver, cuddn_ver) < 0:
                    # ort cuDNN < torch cuDNN
                    return OrtType.CU121CUDNN9
                else:
                    # Same cuDNN version, nothing to reinstall
                    return None
            else:
                # ort is not built for CUDA 12.x: pick by torch cuDNN version
                if compare_versions(cuddn_ver, '8') > 0:
                    return OrtType.CU121CUDNN9
                else:
                    return OrtType.CU121CUDNN8
        else:
            # Torch uses CUDA < 12.0 (e.g. 11.8)
            if compare_versions(ort_support_cuda_ver, '12.0') < 0:
                return None
            else:
                return OrtType.CU118
    else:
        if ignore_ort_install:
            return None

        if compare_versions(cuda_ver, '12.0') >= 0:
            if compare_versions(cuddn_ver, '8') > 0:
                return OrtType.CU121CUDNN9
            else:
                return OrtType.CU121CUDNN8
        else:
            return OrtType.CU118


if __name__ == '__main__':
    arg = get_args()
    # NOTE: the parsed --ignore-ort-install flag is currently not used; the
    # check always runs with the default ignore_ort_install=True.
    # print(need_install_ort_ver(not arg.ignore_ort_install))
    print(need_install_ort_ver())
".Trim()

    Print-Msg "检查 onnxruntime-gpu 版本问题中"
    $status = $(python -c "$content")

    # Map the flavour printed by the script to the requirement to install
    $need_reinstall_ort = $false
    $need_switch_mirror = $false
    switch ($status) {
        cu118 {
            $need_reinstall_ort = $true
            $ort_version = "onnxruntime-gpu==1.18.1"
        }
        cu121cudnn9 {
            $need_reinstall_ort = $true
            $ort_version = "onnxruntime-gpu>=1.19.0"
        }
        cu121cudnn8 {
            $need_reinstall_ort = $true
            $ort_version = "onnxruntime-gpu==1.17.1"
            # This build is taken from the onnxruntime-cuda-12 package feed
            $need_switch_mirror = $true
        }
        Default {
            $need_reinstall_ort = $false
        }
    }

    if ($need_reinstall_ort) {
        Print-Msg "检测到 onnxruntime-gpu 所支持的 CUDA 版本 和 PyTorch 所支持的 CUDA 版本不匹配, 将执行重装操作"
        if ($need_switch_mirror) {
            # Save the current index settings so they can be restored afterwards
            $tmp_pip_index_url = $Env:PIP_INDEX_URL
            $tmp_pip_extra_index_url = $Env:PIP_EXTRA_INDEX_URL
            $tmp_uv_index_url = $Env:UV_DEFAULT_INDEX
            $tmp_uv_extra_index_url = $Env:UV_INDEX
            $Env:PIP_INDEX_URL = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/"
            $Env:PIP_EXTRA_INDEX_URL = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple"
            $Env:UV_DEFAULT_INDEX = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/"
            $Env:UV_INDEX = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple"
        }

        Print-Msg "卸载原有的 onnxruntime-gpu 中"
        python -m pip uninstall onnxruntime-gpu -y

        Print-Msg "重新安装 onnxruntime-gpu 中"
        if ($USE_UV) {
            uv pip install $ort_version
            if (!($?)) {
                Print-Msg "检测到 uv 安装 Python 软件包失败, 尝试回滚至 Pip 重试 Python 软件包安装"
                python -m pip install $ort_version
            }
        } else {
            python -m pip install $ort_version
        }
        if ($?) {
            Print-Msg "onnxruntime-gpu 重新安装成功"
        } else {
            Print-Msg "onnxruntime-gpu 重新安装失败, 这可能导致部分功能无法正常使用, 如使用反推模型无法正常调用 GPU 导致推理降速"
        }

        if ($need_switch_mirror) {
            # Restore the saved index settings
            $Env:PIP_INDEX_URL = $tmp_pip_index_url
            $Env:PIP_EXTRA_INDEX_URL = $tmp_pip_extra_index_url
            $Env:UV_DEFAULT_INDEX = $tmp_uv_index_url
            $Env:UV_INDEX = $tmp_uv_extra_index_url
        }
    } else {
        Print-Msg "onnxruntime-gpu 无版本问题"
    }
}


# Check the installed Numpy major version
function Check-Numpy-Version {
    # Pin Numpy back to 1.26.4 when a version with major > 1 is installed,
    # since such versions may break some components
    $numpy_pinned = "numpy==1.26.4"
    $content = "
from importlib.metadata import version

try:
    major_version = int(version('numpy').split('.')[0])
except:
    major_version = -1

print(major_version > 1)
".Trim()

    Print-Msg "检查 Numpy 版本中"
    $status = $(python -c "$content")

    if (!($status -eq "True")) {
        Print-Msg "Numpy 无版本问题"
        return
    }

    Print-Msg "检测到 Numpy 版本大于 1, 这可能导致部分组件出现异常, 尝试重装中"
    if ($USE_UV) {
        uv pip install $numpy_pinned
        if (!($?)) {
            Print-Msg "检测到 uv 安装 Python 软件包失败, 尝试回滚至 Pip 重试 Python 软件包安装"
            python -m pip install $numpy_pinned
        }
    } else {
        python -m pip install $numpy_pinned
    }
    if ($?) {
        Print-Msg "Numpy 重新安装成功"
    } else {
        Print-Msg "Numpy 重新安装失败, 这可能导致部分功能异常"
    }
}


# Warn (without stopping the launcher) when vcruntime140_1.dll from the
# Microsoft Visual C++ Redistributable is absent from System32.
# Uses C:/Windows when the SYSTEMROOT environment variable is not set.
function Check-MS-VCPP-Redistributable {
    Print-Msg "检测 Microsoft Visual C++ Redistributable 是否缺失"

    $vc_dll = if ([string]::IsNullOrEmpty($Env:SYSTEMROOT)) {
        "C:/Windows/System32/vcruntime140_1.dll"
    } else {
        "$Env:SYSTEMROOT/System32/vcruntime140_1.dll"
    }

    if (!(Test-Path "$vc_dll")) {
        Print-Msg "检测到 Microsoft Visual C++ Redistributable 缺失, 这可能导致 PyTorch 无法正常识别 GPU 导致报错"
        Print-Msg "Microsoft Visual C++ Redistributable 下载: https://aka.ms/vs/17/release/vc_redist.x64.exe"
        Print-Msg "请下载并安装 Microsoft Visual C++ Redistributable 后重新启动"
        # Short pause so the warning stays visible before startup continues
        Start-Sleep -Seconds 2
        return
    }

    Print-Msg "Microsoft Visual C++ Redistributable 未缺失"
}


# Run the Fooocus runtime environment checks: requirements, PyTorch fix,
# onnxruntime-gpu, NumPy version and the VC++ runtime.
# Skipped entirely when disable_check_env.txt exists next to the script or
# -DisableEnvCheck is passed.
function Check-Fooocus-Env {
    $skip_env_check = (Test-Path "$PSScriptRoot/disable_check_env.txt") -or ($DisableEnvCheck)
    if ($skip_env_check) {
        Print-Msg "检测到 disable_check_env.txt 配置文件 / -DisableEnvCheck 命令行参数, 已禁用 Fooocus 运行环境检测, 这可能会导致 Fooocus 运行环境中存在的问题无法被发现并解决"
        return
    }

    Print-Msg "检查 Fooocus 运行环境中"

    Check-Fooocus-Requirements
    Fix-PyTorch
    Check-Onnxruntime-GPU
    Check-Numpy-Version
    Check-MS-VCPP-Redistributable

    Print-Msg "Fooocus 运行环境检查完成"
}


# Build the HuggingFace mirror launch arguments for Fooocus ("--hf-mirror <url>").
#
# Returns no arguments when:
#   - the Fooocus checkout's `origin` remote is not lllyasviel/Fooocus
#     (NOTE(review): presumably forks may not accept --hf-mirror — confirm),
#   - the mirror is disabled via disable_hf_mirror.txt / -DisableHuggingFaceMirror,
#   - or $Env:HF_ENDPOINT is empty (a bare "--hf-mirror" would be malformed).
#
# The remote URL is normalized before comparison, so HTTPS, scp-style SSH
# (git@host:owner/repo.git) and ssh:// remotes, with or without a trailing
# "/" or ".git", are all recognized.
# Note: PowerShell unrolls the returned ArrayList, so callers receive $null
# (no args) or an object[] of two strings.
function Get-Fooocus-HuggingFace-Mirror-Arg {
    $hf_mirror_arg = New-Object System.Collections.ArrayList

    if ((Get-Command git -ErrorAction SilentlyContinue) -and (Test-Path "$PSScriptRoot/Fooocus/.git")) {
        $git_remote = $(git -C "$PSScriptRoot/Fooocus" remote get-url origin)
        # Strip whitespace, trailing slashes and a trailing ".git" (case-insensitive)
        $remote = ("$git_remote".Trim() -replace "/+$", "") -replace "\.git$", ""
        # Split on "/" and ":" so scp-style SSH remotes also yield owner/repo
        $parts = @($remote -split "[/:]" | Where-Object { $_ -ne "" })
        if ($parts.Count -lt 2) {
            return $hf_mirror_arg
        }
        $branch = "$($parts[-2])/$($parts[-1])"
        if ($branch -ne "lllyasviel/Fooocus") {
            return $hf_mirror_arg
        }
    }

    if ((!(Test-Path "$PSScriptRoot/disable_hf_mirror.txt")) -and (!($DisableHuggingFaceMirror)) -and (!([string]::IsNullOrEmpty($Env:HF_ENDPOINT)))) {
        $hf_mirror_arg.Add("--hf-mirror") | Out-Null
        $hf_mirror_arg.Add("$Env:HF_ENDPOINT") | Out-Null
    }

    return $hf_mirror_arg
}


# Launcher entry point.
# 1. Initializes the installer version/help, proxy, update check (skipped in
#    build mode), HuggingFace mirror, uv and PyPI mirror settings.
# 2. Stops early, waiting for Enter, if the Fooocus folder is missing.
# 3. Runs the environment checks inside the Fooocus folder, then starts
#    launch.py (skipped in build mode).
# The caller's working directory is restored in a finally block. This covers
# a step that throws and an interrupted run (e.g. Ctrl+C while Fooocus or the
# final Read-Host is running), not only a normal exit.
function Main {
    Print-Msg "初始化中"
    Get-Fooocus-Installer-Version
    Get-Fooocus-Installer-Cmdlet-Help
    Set-Proxy
    if ($BuildMode) {
        Print-Msg "Fooocus Installer 构建模式已启用, 跳过 Fooocus Installer 更新检查"
    } else {
        Check-Fooocus-Installer-Update
    }
    Set-HuggingFace-Mirror
    Set-uv
    PyPI-Mirror-Status

    if (!(Test-Path "$PSScriptRoot/Fooocus")) {
        Print-Msg "在 $PSScriptRoot 路径中未找到 Fooocus 文件夹, 请检查 Fooocus 是否已正确安装, 或者尝试运行 Fooocus Installer 进行修复"
        Read-Host | Out-Null
        return
    }

    $launch_args = Get-Fooocus-Launch-Args
    $hf_mirror_arg = Get-Fooocus-HuggingFace-Mirror-Arg
    # Remember the caller's working directory so it can be restored afterwards
    $current_path = $(Get-Location).ToString()
    try {
        Set-Location "$PSScriptRoot/Fooocus"

        Create-Fooocus-Shortcut
        Check-Fooocus-Env
        Set-PyTorch-CUDA-Memory-Alloc
        Print-Msg "启动 Fooocus 中"
        if ($BuildMode) {
            Print-Msg "Fooocus Installer 构建模式已启用, 跳过启动 Fooocus"
        } else {
            python launch.py $launch_args $hf_mirror_arg
            $req = $?
            if ($req) {
                Print-Msg "Fooocus 正常退出"
            } else {
                Print-Msg "Fooocus 出现异常, 已退出"
            }
            Read-Host | Out-Null
        }
    } finally {
        # Runs on normal completion, on exceptions and on Ctrl+C
        Set-Location "$current_path"
    }
}

###################

# Script entry point: run the launcher defined above as Main.
Main
