import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import tempfile
import torch
import subprocess
import sys
import importlib.util
from tqdm import tqdm
# Check if nougat-ocr is installed
NOUGAT_AVAILABLE = importlib.util.find_spec("nougat") is not None
if not NOUGAT_AVAILABLE:
    print("Warning: nougat-ocr is not installed. PDF to Markdown conversion will not be available.")
    print("To install, run: pip install -U 'git+https://github.com/facebookresearch/nougat.git'")
# Read the Hugging Face access token from the environment, if one is set
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Set CUDA environment variables for better GPU performance with Nougat
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Meta Llama3 8B with Nougat PDF Processing</h1>
<p>This Space demonstrates the instruction-tuned model <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama3 8b Chat</b></a>. Meta Llama3 is the new open LLM and comes in two sizes: 8b and 70b. Feel free to play with it, or duplicate to run privately!</p>
<p>🔎 For more details about the Llama3 release and how to use the model with <code>transformers</code>, take a look <a href="https://huggingface.co/blog/llama3">at our blog post</a>.</p>
<p>🦕 Looking for an even more powerful model? Check out the <a href="https://huggingface.co/chat/"><b>Hugging Chat</b></a> integration for Meta Llama 3 70b</p>
<p>📝 <b>PDF processing:</b> This app uses <a href="https://github.com/facebookresearch/nougat">Nougat</a> for high-quality PDF-to-Markdown conversion. Nougat preserves the original layout, mathematical formulas, and tables well, giving the best PDF document processing experience.</p>
</div>
'''
LICENSE = """
<p/>
---
Built with Meta Llama 3
"""
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://ysharma-dummy-chat-app.hf.space/file=/tmp/gradio/8e75e61cc9bab22b7ce3dec85ab0e6db1da5d107/Meta_lockup_positive%20primary_RGB.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Meta llama3</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("nuojohnchen/shuishanllm")
# Load model weights but don't initialize CUDA in main process
model = AutoModelForCausalLM.from_pretrained("nuojohnchen/shuishanllm", device_map=None)
# Ensure eos_token_id is not None
if tokenizer.eos_token_id is None:
    tokenizer.eos_token_id = 2  # 2 is commonly the ID of the </s> token, a reasonable default
# Collect termination token IDs
terminators = []
if tokenizer.eos_token_id is not None:
    terminators.append(tokenizer.eos_token_id)
# Try to add the special end-of-turn token, if it exists
try:
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id != tokenizer.unk_token_id:  # make sure it is not the unknown token
        terminators.append(eot_id)
except Exception:
    pass
# If terminators is empty, fall back to a default value
if not terminators:
    terminators = [2]  # use the common </s> token ID as the default
# PDF processing function that runs Nougat on CUDA
def process_pdf_with_nougat_gpu(pdf_path, output_dir=None):
    """Process a PDF file by running Nougat on the GPU."""
    try:
        # If no output directory is given, use the directory containing the PDF
        if output_dir is None:
            output_dir = os.path.dirname(pdf_path)
        # Set CUDA environment variables
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "0"  # use the first GPU
        env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        # Run the Nougat command with GPU support
        print(f"Running Nougat on GPU: {pdf_path}")
        cmd = ["nougat", pdf_path, "-o", output_dir, "--device", "cuda"]
        # Execute the command and capture its output
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            env=env,
            timeout=300  # 5-minute timeout
        )
        # Check the command's exit status
        if result.returncode != 0:
            print(f"Nougat GPU processing failed: {result.stderr}")
            return None, result.stderr
        # Derive the path of the generated Markdown file
        base_name = os.path.basename(pdf_path)
        name_without_ext = os.path.splitext(base_name)[0]
        markdown_path = os.path.join(output_dir, f"{name_without_ext}.mmd")
        # Verify that the Markdown file was produced
        if not os.path.exists(markdown_path):
            return None, "Nougat finished, but the generated Markdown file was not found"
        # Read the Markdown content
        with open(markdown_path, "r", encoding="utf-8") as f:
            markdown_content = f.read()
        return markdown_content, None
    except subprocess.TimeoutExpired:
        return None, "Nougat processing timed out"
    except Exception as e:
        import traceback
        error = f"Nougat processing error: {str(e)}\n{traceback.format_exc()}"
        print(error)
        return None, error
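# Illustrative sketch only (not wired into the app): how the CLI wrapper above
# is intended to be called. The path below is hypothetical; Nougat writes a
# <name>.mmd file into the output directory, which the wrapper reads back.
def _example_cli_usage(pdf_path="/tmp/paper.pdf"):
    content, err = process_pdf_with_nougat_gpu(pdf_path)
    if err is not None:
        print(f"Conversion failed: {err}")
    else:
        print(content[:500])  # preview the first 500 characters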
# GPU processing via the Nougat Python API
@spaces.GPU(stateless=True)
def process_pdf_with_nougat_api(pdf_path):
    """Process a PDF file with the Nougat Python API on the GPU."""
    try:
        # Import the required libraries
        from pathlib import Path
        from nougat import NougatModel
        from nougat.utils.checkpoint import get_checkpoint
        from nougat.dataset.rasterize import rasterize_paper
        from PIL import Image
        import torch
        # Make sure a GPU is available
        if not torch.cuda.is_available():
            return None, "No GPU available; cannot process the PDF with the Nougat API"
        # Report GPU information
        device_count = torch.cuda.device_count()
        device_name = torch.cuda.get_device_name(0) if device_count > 0 else "Unknown"
        print(f"Using GPU: {device_name}, available GPUs: {device_count}")
        # Initialize the model and move it to the GPU
        ckpt = get_checkpoint()
        model = NougatModel.from_pretrained(ckpt)
        device = torch.device("cuda")
        model = model.to(device)
        model.eval()
        # Process the PDF. Following the Nougat reference app:
        # rasterize_paper(..., return_pil=True) yields one in-memory PNG per
        # page, and model.inference(image=...) returns a dict whose
        # "predictions" list holds the page's Markdown.
        markdown_content = ""
        pages = rasterize_paper(pdf=Path(pdf_path), return_pil=True)
        # Show progress with tqdm
        for page_idx, page in enumerate(tqdm(pages, desc="Processing PDF pages")):
            output = model.inference(image=Image.open(page))
            markdown = output["predictions"][0]
            markdown_content += f"--- Page {page_idx+1} ---\n{markdown}\n\n"
        return markdown_content, None
    except Exception as e:
        import traceback
        error = f"Nougat API processing error: {str(e)}\n{traceback.format_exc()}"
        print(error)
        return None, error
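# Illustrative sketch only (not wired into the app): the API-based converter
# above takes a filesystem path (hypothetical here) and returns
# (markdown, error), mirroring the CLI wrapper's contract.
def _example_api_usage(pdf_path="/tmp/paper.pdf"):
    content, err = process_pdf_with_nougat_api(pdf_path)
    print(err if err is not None else content[:500])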
# PDF-to-Markdown conversion function
def convert_pdf_to_markdown(pdf_file):
    """Convert a PDF to Markdown with Nougat (GPU-optimized)."""
    if pdf_file is None:
        return "", "No PDF uploaded"
    # Check whether Nougat is available
    if not NOUGAT_AVAILABLE:
        return "", "Error: Nougat is not installed. Run 'pip install -U \"git+https://github.com/facebookresearch/nougat.git\"' and try again."
    try:
        # Create a temporary directory for the PDF and the output files
        with tempfile.TemporaryDirectory() as temp_dir:
            # Write the binary PDF data to a temporary file
            temp_pdf_path = os.path.join(temp_dir, "temp.pdf")
            with open(temp_pdf_path, "wb") as f:
                f.write(pdf_file)
            # Method 1: try the command-line GPU path first
            print("Method 1: trying the command-line GPU path...")
            markdown_content, error = process_pdf_with_nougat_gpu(temp_pdf_path, temp_dir)
            if markdown_content is not None:
                # Limit the text length
                if len(markdown_content) > 20000:
                    markdown_content = markdown_content[:20000] + "\n\n...(Markdown content truncated)"
                status = f"PDF successfully converted to Markdown (GPU command line): {len(markdown_content)} characters generated"
                return markdown_content, status
            # Method 2: if the command line fails, fall back to the Python API
            print(f"Method 1 failed: {error}")
            print("Method 2: trying the Python API GPU path...")
            markdown_content, api_error = process_pdf_with_nougat_api(temp_pdf_path)
            if markdown_content is not None:
                # Limit the text length
                if len(markdown_content) > 20000:
                    markdown_content = markdown_content[:20000] + "\n\n...(Markdown content truncated)"
                status = f"PDF successfully converted to Markdown (GPU API): {len(markdown_content)} characters generated"
                return markdown_content, status
            # All methods failed
            return "", f"PDF conversion failed: all GPU methods failed\nCommand-line error: {error}\nAPI error: {api_error}"
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Nougat conversion error: {str(e)}\n{error_details}")
        return "", f"Markdown conversion error: {str(e)}"
@spaces.GPU(duration=120, stateless=True)
def chat_llama3_8b(message, history, temperature, max_new_tokens, markdown_content=""):
"""
Generate a streaming response using the llama3-8b model.
Args:
message (str): The input message.
history (list): The conversation history used by ChatInterface.
temperature (float): The temperature for generating the response.
max_new_tokens (int): The maximum number of new tokens to generate.
markdown_content (str): Optional Markdown content converted by Nougat to include in the context.
Returns:
str: The generated response.
"""
try:
conversation = []
for user, assistant in history:
# 确保所有内容都是字符串类型
user_msg = str(user) if user is not None else ""
assistant_msg = str(assistant) if assistant is not None else ""
conversation.extend([
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_msg}
])
# 确保message是字符串
message = str(message) if message is not None else ""
# 如果有Markdown内容,将其添加到用户消息中
if markdown_content and isinstance(markdown_content, str) and markdown_content.strip():
message = f"""
Please improve the selected content based on the following. Act as an expert model for improving articles **PAPER_CONTENT**.
The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.
Focus on clear, concise, and evidence-based improvements that align with the overall context of the paper.
<PAPER_CONTENT>
{markdown_content}
</PAPER_CONTENT>
<QUESTION>
{message}
</QUESTION>
"""
print(f"加入Markdown的message", message)
conversation.append({"role": "user", "content": message})
# 使用简单的文本拼接方式构建提示
prompt = ""
for item in conversation:
role = item["role"]
content = item["content"]
if role == "user":
prompt += f"用户: {content}\n"
else:
prompt += f"助手: {content}\n"
prompt += "助手: " # 添加最后的提示符
# 编码提示
# 在stateless GPU环境中将模型移到CUDA设备
global model
device = torch.device("cuda")
model = model.to(device)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
)
# 只有当terminators非空时才添加eos_token_id
if terminators:
generate_kwargs['eos_token_id'] = terminators
# This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
if temperature == 0:
generate_kwargs['do_sample'] = False
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"生成错误: {str(e)}\n{error_details}")
yield f"生成文本时出错: {str(e)}\n\n请尝试使用不同的参数或输入。"
# Gradio block
with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    # State holding the Nougat-converted Markdown content
    markdown_content_state = gr.State("")
    with gr.Row():
        with gr.Column(scale=1):
            # PDF upload area
            pdf_file = gr.File(
                label="Upload a PDF document (optional)",
                file_types=[".pdf"],
                type="binary"
            )
            pdf_status = gr.Textbox(
                label="PDF status",
                value="No PDF uploaded",
                interactive=False
            )
            clear_pdf_btn = gr.Button("Clear PDF")
            if NOUGAT_AVAILABLE:
                nougat_info = """
                <div style="margin-top: 10px; margin-bottom: 10px;">
                <p><b>Nougat PDF processing:</b> The system uses Nougat to convert uploaded PDFs into high-quality Markdown. Nougat preserves the original layout, mathematical formulas, and tables well, far better than traditional PDF text extraction.</p>
                </div>
                """
            else:
                nougat_info = """
                <div style="margin-top: 10px; margin-bottom: 10px; color: #d32f2f;">
                <p><b>Nougat is not installed:</b> PDF processing requires Nougat. Run <code>pip install -U 'git+https://github.com/facebookresearch/nougat.git'</code> and try again.</p>
                </div>
                """
            gr.Markdown(nougat_info)
            # Collapsible viewer for the Markdown content
            with gr.Accordion("View Markdown content", open=False):
                markdown_content_display = gr.Textbox(
                    label="Markdown converted by Nougat",
                    lines=10,
                    interactive=False
                )
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
            with gr.Row():
                with gr.Column(scale=8):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Type your question...",
                        container=False
                    )
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button("Send")
            with gr.Accordion("⚙️ Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.95,
                    label="Temperature"
                )
                max_new_tokens = gr.Slider(
                    minimum=128,
                    maximum=4096,
                    step=1,
                    value=512,
                    label="Max new tokens"
                )
            gr.Examples(
                examples=[
                    ['How to setup a human base on Mars? Give short answer.'],
                    ['What is 9,000 * 9,000?'],
                    ['Write a pun-filled happy birthday message to my friend Alex.'],
                    ['Justify why a penguin might make a good king of the jungle.']
                ],
                inputs=msg
            )
    # Handle PDF uploads: convert with Nougat, then mirror the result into the
    # viewer. Chaining with .then() guarantees the display updates only after
    # the state has been written.
    pdf_file.change(
        fn=convert_pdf_to_markdown,
        inputs=[pdf_file],
        outputs=[markdown_content_state, pdf_status]
    ).then(
        fn=lambda content: content,
        inputs=[markdown_content_state],
        outputs=[markdown_content_display]
    )
    # Clear the PDF state, status, and Markdown display in one handler
    clear_pdf_btn.click(
        fn=lambda: ("", "PDF cleared", ""),
        inputs=[],
        outputs=[markdown_content_state, pdf_status, markdown_content_display]
    )
    # Chat functionality
    chat_interface = gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        textbox=msg,
        submit_btn=submit_btn,
        additional_inputs=[temperature, max_new_tokens, markdown_content_state],
        additional_inputs_accordion=None,
    )
    gr.Markdown(LICENSE)
if __name__ == "__main__":
    demo.launch()