1. 前言
最近在测试VLM模型,发现官方的网页demo,代码中视频与图片分辨率可能由于高并发设置的很小,导致达不到预期效果,于是自己研究了一下,搞了一个简单的前端部署,自己在服务器部署了下UI界面,方便在本地笔记本进行测试。
2.代码
import streamlit as st
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import os
# 加载模型和处理器 (只加载一次)
@st.cache_resource  # 这个装饰器会缓存模型和处理器
def load_model():
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "../qwen2_vl/model_7b/", torch_dtype=torch.float16, device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("../qwen2_vl/model_7b/")
    return model, processor
# 加载模型和处理器
model, processor = load_model()
def load_image(image_file):
    img = Image.open(image_file)
    return img
# Function to load and resize image to fixed height
def resize_image_to_height(image, height):
    # Resize image keeping the aspect ratio
    width = int(image.width * height / image.height)
    return image.resize((width, height))
# 处理输入
def process_input(messages):
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    # Inference
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # Clear all intermediate variables and free GPU memory
    del text, image_inputs, video_inputs, inputs, generated_ids, generated_ids_trimmed
    torch.cuda.empty_cache()
    return output_text
# Streamlit UI
st.title("VLM 视觉内容理解")
# 选择文件上传
uploaded_file = st.file_uploader("上传图片或视频", type=["jpg", "jpeg", "png", "mp4"])
# 判断文件是否上传
if uploaded_file is not None:
    # 保存文件到本地
    upload_dir = "uploads"  # 上传文件保存目录
    if not os.path.exists(upload_dir):
        os.makedirs(upload_dir)
    file_path = os.path.join(upload_dir, uploaded_file.name)
    # 保存文件
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    # 判断文件类型
    if uploaded_file.type.startswith("image"):
        # 加载并显示图片
        img = load_image(file_path)
        # 设置固定高度
        fixed_height = 300  # 设置固定高度为300px
        # 调整图片的大小,使高度固定,宽度按比例调整
        img_resized = resize_image_to_height(img, fixed_height)
        st.image(img_resized, use_container_width=False)
        # 输入台词部分
        st.subheader("输入一句提示词")
        user_input = st.text_input("请输入提示词,并回车确认:")
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": file_path,  # 使用本地保存的文件路径
                        "max_pixels": 1024 * 960
                    },
                    {"type": "text", "text": user_input},
                ],
            }
        ]
    elif uploaded_file.type.startswith("video"):
        # 设置视频固定高度
        fixed_height = 300  # 设置固定高度为300px
        # 显示视频
        st.video(file_path)
        st.subheader("输入一句提示词")
        user_input = st.text_input("请输入提示词,并回车确认:")
        # 通过 Markdown 来调整视频的显示样式
        st.markdown(
            f"""
            <style>
                video {{
                    height: {fixed_height}px;
                    width: auto;
                }}
            </style>
            """,
            unsafe_allow_html=True
        )
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": file_path,  # 使用本地保存的文件路径
                        "max_pixels": 960 * 480,
                        "fps": 1.0,
                    },
                    {"type": "text", "text": user_input},
                ],
            }
        ]
    # 调用模型进行推理
    result = process_input(messages)
    if result:
        # 使用 st.markdown 和 CSS 来自动换行
        st.markdown("### 模型推理结果:")
        # 将输出格式化为代码块样式,并通过 CSS 实现自动换行
        st.markdown(
            f'<pre style="white-space: pre-wrap; word-wrap: break-word;">{result[0]}</pre>',
            unsafe_allow_html=True
        )
    else:
        st.markdown("### 模型推理结果:无结果。")
    # 推理完成后删除本地文件
    try:
        os.remove(file_path)
    except Exception as e:
        pass
结论
主要是利用streamlit,进行UI的搭建,涉及本地文件的上传与下载到服务器中,推理完删除。
 



















