
Running custom encoder-decoder models in onnxruntime-genai #875

Open
KarelZe opened this issue Sep 5, 2024 · 3 comments

Comments

@KarelZe

KarelZe commented Sep 5, 2024

background

My question is about running encoder-decoder models with onnxruntime-genai. My goal is to convert the DONUT transformer (https://arxiv.org/abs/2111.15664), a sequence-to-sequence transformer for document understanding with a Swin encoder and an mBART decoder, to ONNX and run it with onnxruntime-genai.

I managed to convert the individual components to ONNX. Now I'm stuck at writing a genai_config.json suitable for encoder-decoder models.

steps

I started with the Hugging Face implementation of DONUT (https://huggingface.co/docs/transformers/v4.42.0/en/model_doc/donut#overview) and converted the encoder and the merged decoder (with kv-cache) to ONNX using optimum (https://huggingface.co/docs/optimum/index). I converted the DONUT processor, which consists of an image processor and a SentencePiece tokenizer, to ONNX using onnxruntime-extensions (https://github.com/microsoft/onnxruntime-extensions/blob/main/onnxruntime_extensions/tools/). I merged the image processor and the Swin encoder into a single graph. I can provide the conversion scripts if needed.

My components have the following input/output shapes:

tokenizer:
[screenshot: tokenizer input/output names and shapes]

encoder with image processor:
[screenshot: encoder input/output names and shapes]

decoder:
[screenshots: decoder input/output names and shapes (more past keys + values not shown)]
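
For reference, the exact names and (symbolic) shapes can be dumped from the exported graphs; a minimal sketch, assuming the output/ file layout produced by the build script further below:

import onnx

# Print input/output names and shapes of each exported component.
# The paths are assumptions matching the output/ layout of build.py below.
for path in [
    "output/tokenizer.onnx",
    "output/encoder_with_img_processor.onnx",
    "output/decoder_model_merged.onnx",
]:
    model = onnx.load(path, load_external_data=False)
    print(f"== {path}")
    for vi in list(model.graph.input) + list(model.graph.output):
        dims = [d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim]
        print(" ", vi.name, dims)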

question

However, I am stuck at manually writing/loading a suitable genai_config.json that feeds the encoder's hidden states into the decoder's cross-attention. I'm aware of https://onnxruntime.ai/docs/genai/reference/config.html for writing configs, but it seemingly focuses on decoder-only models. I'm also aware of make_genai_config in https://github.com/microsoft/onnxruntime-genai/blob/c7eba3c63a454edd6662eb007ff397d1146cc081/src/python/py/models/builder.py for auto-generating configs for supported models.

I tried the following config:

{
    "model": {
        "bos_token_id": 1,
        "context_length": 1536,
        "decoder": {
            "session_options": {
                "log_id": "onnxruntime-genai",
                "provider_options": []
            },
            "filename": "decoder_model_merged.onnx",
            "head_size": 96,
            "hidden_size": 1024,
            "inputs": {
                "input_ids": "input_ids",
                "encoder_hidden_states": "encoder_hidden_states",
                "past_key_names": "past_key_values.%d.decoder.key",
                "past_value_names": "past_key_values.%d.decoder.value"
            },
            "outputs": {
                "logits": "logits",
                "present_key_names": "present.%d.decoder.key",
                "present_value_names": "present.%d.decoder.value"
            },
            "num_attention_heads": 16,
            "num_hidden_layers": 12,
            "num_key_value_heads": 32
        },
        "encoder": {
            "filename": "encoder_with_img_processor.onnx",
            "inputs": {
                "pixel_values": "pixel_values"
            },
            "outputs": {
                "encoder_hidden_states": "last_hidden_state"
            }
        },
        "eos_token_id": 2,
        "pad_token_id": 1,
        "type": "sentencepiece",
        "vocab_size": 57654
    },
    "search": {
        "diversity_penalty": 0.0,
        "do_sample": false,
        "early_stopping": true,
        "length_penalty": 1.0,
        "max_length": 1536,
        "min_length": 0,
        "no_repeat_ngram_size": 0,
        "num_beams": 1,
        "num_return_sequences": 1,
        "past_present_share_buffer": true,
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_k": 50,
        "top_p": 1.0
    }
}
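
For completeness, this is how I load it (a minimal sketch with the onnxruntime-genai Python API; "output/" is the folder holding genai_config.json and the exported .onnx files):

import onnxruntime_genai as og

# "output/" contains genai_config.json plus the exported .onnx files.
model = og.Model("output")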

When loading the config I receive the error RuntimeError: Error encountered while parsing 'output/genai_config.json' JSON Error: Unknown value: encoder_hidden_states at line 15 index 64. As far as I can tell from the source, it's currently not possible to pass encoder hidden states to the model as inputs.

Do you have any plans to extend onnxruntime-genai to encoder-decoder models? Could you please give me a hint/some advice on how to work around this?
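
In the meantime, my fallback is to skip onnxruntime-genai and drive the raw optimum export with optimum's ONNX Runtime classes instead; a rough sketch (assuming the unmodified export in output/model_init_export and a hypothetical task prompt):

from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import DonutProcessor
from PIL import Image

# Uses the plain optimum export (before my custom merging/quantization steps).
model = ORTModelForVision2Seq.from_pretrained("output/model_init_export")
processor = DonutProcessor.from_pretrained("/path/to/model/")

image = Image.open("/path/to/test/img.png")
pixel_values = processor(image, return_tensors="pt").pixel_values

task_prompt = "<s_name>"  # hypothetical task prompt for my fine-tuned model
decoder_input_ids = processor.tokenizer(
    task_prompt, add_special_tokens=False, return_tensors="pt"
).input_ids

generated = model.generate(
    pixel_values=pixel_values,
    decoder_input_ids=decoder_input_ids,
    max_length=1536,
)
print(processor.tokenizer.batch_decode(generated, skip_special_tokens=True))

This works, but it gives up the genai features (token streaming, shared past/present buffers, etc.), which is why native encoder-decoder support would be great.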

Thank you for your assistance.

@KarelZe KarelZe changed the title Running custom encoder-decoder models in onnx-genai-runtime Running custom encoder-decoder models in onnxruntime-genai Sep 5, 2024
@yufenglee
Member

We don't have encoder-decoder support yet. Let me take a deeper look at your model and see how we can best support it.

@KarelZe
Author

KarelZe commented Sep 6, 2024

@yufenglee Thanks for the update. It would be really cool to see this happen 🤗

I'll provide a build script for the ONNX conversion as a head start. Please let me know if I can help with implementation/testing.

@KarelZe
Author

KarelZe commented Sep 8, 2024

@yufenglee Here's the build.py script I used for the conversion. Please let me know if you have any questions. 👍

"""Build script for DONUT transformer.

Partly adapted from:
https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cuda/blob/main/onnx/builder.py
"""

import logging
import shutil
import subprocess
import sys
from pathlib import Path

import numpy as np
import onnx
import onnxruntime as ort
from onnxruntime_extensions import gen_processing_models, get_library_path
from onnxruntime_extensions.tools.pre_post_processing import (
    ChannelsLastToChannelsFirst,
    ImageBytesToFloat,
    LetterBox,
    Normalize,
    PrePostProcessor,
    Resize,
    Unsqueeze,
    create_named_value,
)
from PIL import Image
from transformers import AutoConfig, DonutProcessor, VisionEncoderDecoderModel

logger = logging.getLogger(__name__)

output_dir = Path("output/")
cache_dir = Path("cache_dir/")
path_model = Path("/path/to/model/")
path_test_img = Path("/path/to/test/img.png")

precision = "fp16"
# The pre/post-processing pipeline in onnxruntime-extensions requires at least opset 16; 18 works better.
opset = 18


def export_model():
    """Export encoder + decoder.

    see:
    https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model
    """
    subprocess.run([
        "optimum-cli",
        "export",
        "onnx",
        "--model",
        path_model,
        output_dir / "model_init_export",
        "--task",
        "image-to-text-with-past",
        "--framework",
        "pt",
        "--opset",
        str(opset),
    ])


def optimize_encoder():
    """Optimize encoder."""
    filename = "encoder_model.onnx"
    temp_folder_1 = output_dir / "model_init_export"
    fpath_1 = temp_folder_1 / filename

    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    temp_folder_2 = output_dir / "encoder_after_export"
    temp_folder_2.mkdir(exist_ok=True)
    fpath_2 = temp_folder_2 / filename

    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )

    temp_folder_3 = output_dir / "encoder_after_opt"
    temp_folder_3.mkdir(exist_ok=True)
    fpath_3 = temp_folder_3 / filename

    subprocess.run([
        f"{sys.executable}",
        "-m",
        "onnxruntime.transformers.optimizer",
        "--input",
        fpath_2,
        "--output",
        fpath_3,
        "--model_type",
        "swin",
        "--num_heads",
        str(0),  # config lists 4, 8, 16, 32 heads; 0 = auto-discover from the graph
        "--hidden_size",
        str(0),  # 0 = auto-discover
        "--use_external_data_format",
        "--opt_level",
        str(0),
    ])
    shutil.rmtree(temp_folder_2)

    fpath_4 = output_dir / filename
    cmd = [
        f"{sys.executable}",
        "-m",
        "onnxruntime.quantization.matmul_4bits_quantizer",
        "--input_model",
        fpath_3,
        "--output_model",
        fpath_4,
        "--block_size",
        str(32),
    ]
    if precision == "fp32":
        cmd.extend(["--accuracy_level", str(4)])

    subprocess.run(cmd)

    shutil.rmtree(temp_folder_3)


def optimize_decoder():
    """Optimize decoder.

    Adapted from:
    https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cuda/blob/main/onnx/builder.py
    """
    filename = "decoder_model_merged.onnx"
    temp_folder_1 = output_dir / "model_init_export"
    fpath_1 = temp_folder_1 / filename

    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    temp_folder_2 = output_dir / "decoder_after_export"
    temp_folder_2.mkdir(exist_ok=True)
    fpath_2 = temp_folder_2 / filename

    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )

    temp_folder_3 = output_dir / "decoder_after_opt"
    temp_folder_3.mkdir(exist_ok=True)
    fpath_3 = temp_folder_3 / filename

    subprocess.run([
        f"{sys.executable}",
        "-m",
        "onnxruntime.transformers.optimizer",
        "--input",
        fpath_2,
        "--output",
        fpath_3,
        "--model_type",
        "bart",
        "--num_heads",
        str(config.decoder.decoder_attention_heads),
        "--hidden_size",
        str(config.decoder.d_model),
        "--use_external_data_format",
        "--opt_level",
        str(0),
    ])
    shutil.rmtree(temp_folder_2)

    fpath_4 = output_dir / filename
    cmd = [
        f"{sys.executable}",
        "-m",
        "onnxruntime.quantization.matmul_4bits_quantizer",
        "--input_model",
        fpath_3,
        "--output_model",
        fpath_4,
        "--block_size",
        str(32),
    ]
    if precision == "fp32":
        cmd.extend(["--accuracy_level", str(4)])

    subprocess.run(cmd)

    shutil.rmtree(temp_folder_3)


def build_tokenizer() -> None:
    """Get sentence piece tokenizer from DONUT processor."""
    onnx_tokenizer = gen_processing_models(tokenizer, opset=opset, pre_kwargs={})[0]
    fpath_tokenizer = output_dir / "tokenizer.onnx"
    with fpath_tokenizer.open(mode="wb") as f:
        f.write(onnx_tokenizer.SerializeToString())


def build_img_preprocessor() -> None:
    """Build image processor.

    Adapted from:
    https://github.com/microsoft/onnxruntime-extensions/pull/478/files#diff-8f875d92e23f555946efe7bec0ccdefde80c06a4b74b595071961ec4e0f84f5d

    For operations and their order see:
    https://github.com/huggingface/transformers/blob/v4.42.0/src/transformers/models/donut/image_processing_donut.py#L54


    Raises
    ------
        NotImplementedError: do_thumbnail not yet implemented
        NotImplementedError: do_align_long_axis not yet implemented
        NotImplementedError: resizing methods !=2 not implemented

    """
    pixel_values_in = [
        create_named_value("pixel_values", onnx.TensorProto.UINT8, ["h", "w", 3])
    ]
    pipeline = PrePostProcessor(pixel_values_in, onnx_opset=opset)

    steps = []

    size = (
        image_processor_config["size"]["height"],
        image_processor_config["size"]["width"],
    )

    if image_processor_config["do_align_long_axis"]:
        logger.warning(
            "'do_align_long_axis' is not yet implemented. This will lead to a performance degradation for rotated images."
        )

    if image_processor_config["do_resize"]:
        # 2 = BILINEAR
        # see: https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Resampling.BICUBIC
        if image_processor_config["resample"] != 2:
            logger.warning(
                "resampling method '%s' not supported. Resampling with 2=BILINEAR.",
                image_processor_config["resample"],
            )

        steps.append(
            Resize(
                size,
                layout="HWC",
                name="do_resize",
                policy="not_larger",
            )
        )

    if image_processor_config["do_thumbnail"]:
        logger.warning("'do_thumbnail' is not yet implemented.")
    if image_processor_config["do_pad"]:
        steps.append(
            LetterBox(target_shape=size, fill_value=0, name="do_pad", layout="HWC")
        )

    if image_processor_config["do_rescale"]:
        steps.append(
            ImageBytesToFloat(
                image_processor_config["rescale_factor"], name="do_rescale"
            )
        )

    if image_processor_config["do_normalize"]:
        mean_std = list(
            zip(
                image_processor_config["image_mean"],
                image_processor_config["image_std"],
            )
        )
        steps.append(Normalize(mean_std, layout="HWC", name="do_normalize"))

    steps.extend([
        ChannelsLastToChannelsFirst(name="RGBImageCHW"),  # HWC to CHW
        Unsqueeze([0], name="unsqueeze"),  # add batch dimension, CHW --> 1CHW
    ])
    pipeline.add_pre_processing(steps)

    pixel_values_out = [
        onnx.helper.make_tensor_value_info(
            "pixel_values_out", onnx.TensorProto.FLOAT, [1, 3, *size]
        )
    ]

    g = onnx.helper.make_graph(
        [onnx.helper.make_node("Identity", ["pixel_values"], ["pixel_values_out"])],
        "empty",
        pixel_values_in,
        pixel_values_out,
    )
    onnx_import = onnx.helper.make_operatorsetid("", opset)
    ir_version = onnx.helper.find_min_ir_version_for([onnx_import])
    model = onnx.helper.make_model_gen_version(
        g, opset_imports=[onnx_import], ir_version=ir_version
    )

    new_model = pipeline.run(model)
    new_model.doc_string = "Donut-like image pre-processor."
    new_model.graph.doc_string = ""

    temp_folder_1 = output_dir / "img_processor_after_export"
    temp_folder_1.mkdir(exist_ok=True)

    filename = "preprocessor.onnx"

    fpath_1 = temp_folder_1 / filename
    onnx.save_model(new_model, fpath_1)
    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)

    subprocess.run([
        f"{sys.executable}",
        "-m",
        "onnxoptimizer",
        fpath_1,
        output_dir / filename,
    ])


def test_components():
    """Test tokenizer, image processor, and fused image processor with encoder."""
    input_text = "<s_name>YOUR_NAME</s_name>"

    sess_options = ort.SessionOptions()
    sess_options.register_custom_ops_library(get_library_path())
    session = ort.InferenceSession(
        output_dir / "tokenizer.onnx",
        sess_options=sess_options,
        providers=["CPUExecutionProvider"],
    )
    input_feed = {"inputs": np.asarray([input_text])}

    outputs = session.run(["token_indices"], input_feed)
    print("token_ids", outputs[0])
    print(tokenizer(input_text))

    session = ort.InferenceSession(
        output_dir / "preprocessor.onnx",
        sess_options,
        providers=["CPUExecutionProvider"],
    )
    outputs = session.run(
        ["pixel_values_out"], {"pixel_values": np.array(Image.open(path_test_img))}
    )
    # remove batch dim, CHW -> HWC
    # [-1, 1] -> [0, 1] * 255 for visualization
    img = np.squeeze(outputs[0]).transpose((1, 2, 0))
    img = np.uint8((img - img.min()) / (img.max() - img.min()) * 255)
    Image.fromarray(img).save(
        output_dir / "test_img_from_onnx_processor.png", format="PNG"
    )
    session = ort.InferenceSession(
        output_dir / "encoder_with_img_processor.onnx",
        sess_options,
        providers=["CPUExecutionProvider"],
    )
    outputs = session.run(
        ["last_hidden_state"], {"pixel_values": np.array(Image.open(path_test_img))}
    )
    print("last_hidden_state", outputs[0])
    print("last_hidden-state (shape)", outputs[0].shape)


def merge_img_processor_encoder():
    """Merge image processor and encoder into a single graph."""
    model_preprocessor = onnx.load(output_dir / "preprocessor.onnx")
    model_encoder = onnx.load(output_dir / "encoder_model.onnx")

    output_name_from_img_processor = model_preprocessor.graph.output[0].name
    input_name_of_encoder = model_encoder.graph.input[0].name

    merged_model = onnx.compose.merge_models(
        model_preprocessor,
        model_encoder,
        io_map=[(output_name_from_img_processor, input_name_of_encoder)],
    )

    filename = "encoder_with_img_processor.onnx"

    onnx.save(
        merged_model,
        output_dir / filename,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )


if __name__ == "__main__":
    config = AutoConfig.from_pretrained(path_model)

    processor = DonutProcessor.from_pretrained(path_model)
    image_processor = processor.image_processor
    image_processor_config = image_processor.to_dict()
    tokenizer = processor.tokenizer
    model = VisionEncoderDecoderModel.from_pretrained(path_model)

    export_model()
    build_tokenizer()
    optimize_encoder()
    optimize_decoder()
    build_img_preprocessor()
    merge_img_processor_encoder()
    test_components()
