|
|
import os |
|
|
from rkllm.api import RKLLM |
|
|
|
|
|
def convert_model(model_path, output_name, do_quantization=False):
    """Convert a single Hugging Face model directory into an RKLLM file.

    The conversion runs three steps in order on a fresh ``RKLLM`` instance:
    ``load_huggingface`` (on CPU, without LoRA), ``build`` (targeting
    ``rk3588`` with 3 NPU cores) and ``export_rkllm``. Progress and failure
    messages are printed to stdout.

    Args:
        model_path: Path of the model directory handed to
            ``load_huggingface``.
        output_name: File name passed to ``export_rkllm``.
        do_quantization: Forwarded unchanged to ``build``.

    Returns:
        ``0`` when every step succeeds; otherwise the non-zero status
        returned by the first step that failed.
    """
    converter = RKLLM()

    # Step 1: load the source model.
    print(f"正在加载模型: {model_path}")
    status = converter.load_huggingface(
        model=model_path,
        model_lora=None,
        device='cpu',
    )
    if status != 0:
        print(f'加载模型失败: {model_path}')
        return status

    # Step 2: build. Only ``do_quantization`` varies between calls; every
    # other build option is fixed here.
    print(f"正在构建模型: {output_name} (量化: {do_quantization})")
    build_options = {
        'do_quantization': do_quantization,
        'optimization_level': 1,
        'quantized_dtype': 'w8a8',
        'quantized_algorithm': 'normal',
        'target_platform': 'rk3588',
        'num_npu_core': 3,
        'extra_qparams': None,
    }
    status = converter.build(**build_options)
    if status != 0:
        print(f'构建模型失败: {output_name}')
        return status

    # Step 3: write the converted model to disk.
    print(f"正在导出模型: {output_name}")
    status = converter.export_rkllm(output_name)
    if status != 0:
        print(f'导出模型失败: {output_name}')
        return status

    print(f"成功转换: {output_name}")
    return 0
|
|
|
|
|
def main():
    """Convert every model folder found in the current working directory.

    Each non-hidden sub-directory of ``'.'`` is treated as a model folder.
    For each one, ``convert_model`` is called twice: first with
    ``do_quantization=False`` (output ``<name>_f16.rkllm``), then with
    ``do_quantization=True`` (output ``<name>_w8a8.rkllm``). If either call
    returns a non-zero status, a failure message is printed and processing
    moves on to the next folder; a failed first conversion therefore skips
    the second one for that folder.
    """
    root = '.'

    # Collect visible sub-directories only; names starting with '.' are skipped.
    model_dirs = []
    for entry in os.listdir(root):
        if entry.startswith('.'):
            continue
        if os.path.isdir(os.path.join(root, entry)):
            model_dirs.append(entry)

    print(f"找到 {len(model_dirs)} 个模型文件夹: {model_dirs}")

    banner = '=' * 50
    for folder in model_dirs:
        source_path = os.path.join(root, folder)

        # Replace path separators so the derived output names are flat.
        stem = folder.replace('/', '_').replace('\\', '_')
        w8a8_name = f"{stem}_w8a8.rkllm"
        f16_name = f"{stem}_f16.rkllm"

        print(f"\n{banner}")
        print(f"处理模型文件夹: {folder}")
        print(f"{banner}")

        # Non-quantized conversion comes first.
        print(f"\n--- 转换非量化版本 ---")
        if convert_model(source_path, f16_name, do_quantization=False) != 0:
            print(f"非量化版本转换失败: {folder}")
            continue

        # Quantized conversion runs only after the first one succeeded.
        print(f"\n--- 转换量化版本 ---")
        if convert_model(source_path, w8a8_name, do_quantization=True) != 0:
            print(f"量化版本转换失败: {folder}")
            continue

        print(f"\n✓ {folder} 模型转换完成!")
        print(f" - 非量化版本: {f16_name}")
        print(f" - 量化版本: {w8a8_name}")
|
|
|
|
|
# Run the batch conversion only when this file is executed as a script,
# so importing the module does not trigger any conversion.
if __name__ == "__main__":
    main()
|
|
|