ONNX Runner: Quantized Model Execution Error #866

Open

Dave86ch opened this issue Sep 3, 2024 · 1 comment

Dave86ch commented Sep 3, 2024

Description
The ONNX Runner application works correctly with the non-quantized version of the Qwen2-0.5B-Instruct model but encounters an error when trying to use the quantized version.

Working Code (Non-Quantized Version)
The following code works correctly with the non-quantized model (note that MODEL_PATH below is shown pointing at the quantized model_q4f16.onnx file, the configuration that triggers the error described later):


import ai.onnxruntime.*;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

public class App {
    // private static final String MODEL_PATH = "/home/davesoma/llms/OnnxRunner/app/src/main/resources/QwenOnnxMobile/qwen_mobile.onnx";
    private static final String MODEL_PATH = "/home/davesoma/llms/OnnxRunner/app/build/resources/main/Qwen2-0.5B-Instruct/model_q4f16.onnx";
    private static final int MAX_LENGTH = 500;  // Adjusted for potentially longer responses
    private static OrtEnvironment env;
    private static OrtSession session;

    public static void main(String[] args) {
        String prompt = "System: You are an AI assistant. Answer the following question concisely.\n" +
                        "Human: What is the meaning of life?\n" +
                        "AI:";

        try {
            env = OrtEnvironment.getEnvironment();
            OrtSession.SessionOptions options = new OrtSession.SessionOptions();
            session = env.createSession(MODEL_PATH, options);

            CompletableFuture<Void> future = CompletableFuture.supplyAsync(() -> tokenizePrompt(prompt))
                .thenAccept(inputIds -> generateResponse(inputIds))
                .exceptionally(ex -> {
                    ex.printStackTrace();
                    return null;
                });

            System.out.println("AI: ");
            future.join();  // Wait for the future to complete
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            cleanupResources();
        }
    }

    private static void generateResponse(long[] inputIds) {
        if (inputIds == null || inputIds.length == 0) {
            System.err.println("No valid input tokens received.");
            return;
        }

        try {
            ArrayList<Long> generatedIds = new ArrayList<>();
            for (long id : inputIds) {
                generatedIds.add(id);
            }
            StringBuilder generatedText = new StringBuilder();
            boolean isFirstToken = true;

            for (int i = 0; i < MAX_LENGTH; i++) {
                Map<String, OnnxTensor> inputs = prepareInputs(generatedIds);
                
                try (OrtSession.Result results = session.run(inputs)) {
                    float[][][] logits = (float[][][]) results.get(0).getValue();
                    long nextToken = argmax(logits[0][logits[0].length - 1]);
                    
                    if (nextToken == 0) {
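                        // NOTE: treating token id 0 as end-of-sequence is an assumption;
                        // the actual EOS id depends on the model's tokenizer vocabulary.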
                        break;  // End of sequence
                    }
                    
                    generatedIds.add(nextToken);
                    String decodedToken = decodeTokens(new long[]{nextToken});

                    // Properly handle spacing and word breaks
                    if (!isFirstToken) {
                        if (decodedToken.startsWith(" ") || decodedToken.startsWith("\n") || 
                            generatedText.toString().endsWith(" ") || generatedText.toString().endsWith("\n")) {
                            // Do nothing, space already exists
                        } else if (Character.isLetterOrDigit(decodedToken.charAt(0)) &&
                                   Character.isLetterOrDigit(generatedText.charAt(generatedText.length() - 1))) {
                            System.out.print(" ");
                            generatedText.append(" ");
                        }
                    }

                    generatedText.append(decodedToken);
                    System.out.print(decodedToken);
                    System.out.flush();

                    isFirstToken = false;
                } finally {
                    // Close this iteration's input tensors to avoid leaking native memory
                    for (OnnxTensor tensor : inputs.values()) {
                        tensor.close();
                    }
                }
            }
            System.out.println();  // New line at the end of the response
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private static Map<String, OnnxTensor> prepareInputs(ArrayList<Long> inputIds) throws OrtException {
        long[] ids = inputIds.stream().mapToLong(l -> l).toArray();
        
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, new long[][]{ids});

        long[] attentionMask = new long[ids.length];
        Arrays.fill(attentionMask, 1);
        OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(env, new long[][]{attentionMask});

        Map<String, OnnxTensor> inputs = new HashMap<>();
        inputs.put("input_ids", inputTensor);
        inputs.put("attention_mask", attentionMaskTensor);

        int numLayers = 24;
        int numKeyValueHeads = 2;
        int headSize = 64;
        int batchSize = 1;

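        // Zero-filled past key/value inputs for every transformer layer; the
        // dimensions match Qwen2-0.5B (24 layers, 2 KV heads, head size 64).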
        for (int i = 0; i < numLayers; i++) {
            float[][][][] pastKey = new float[batchSize][numKeyValueHeads][ids.length][headSize];
            float[][][][] pastValue = new float[batchSize][numKeyValueHeads][ids.length][headSize];
            inputs.put(String.format("past_key_values.%d.key", i), OnnxTensor.createTensor(env, pastKey));
            inputs.put(String.format("past_key_values.%d.value", i), OnnxTensor.createTensor(env, pastValue));
        }

        return inputs;
    }

    private static long argmax(float[] array) {
        int maxIndex = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    private static long[] tokenizePrompt(String prompt) {
        try {
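            // Tokenization is delegated to an external Python script via a subprocess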
            ProcessBuilder processBuilder = new ProcessBuilder(
                "/home/davesoma/llms/OnnxRunner/app/src/main/python/venv/bin/python",
                "/home/davesoma/llms/OnnxRunner/app/src/main/python/Tokenizer.py",
                "tokenize", prompt);
            processBuilder.redirectErrorStream(true);
            Process process = processBuilder.start();

            StringBuilder output = new StringBuilder();
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    output.append(line);
                }
            }

            String jsonOutput = output.toString().trim();
            if (jsonOutput.isEmpty()) {
                System.err.println("No output from tokenizer script.");
                return null;
            }

            ObjectMapper mapper = new ObjectMapper();
            return mapper.readValue(jsonOutput, long[].class);
        } catch (Exception e) {
            e.printStackTrace();
            System.err.println("Failed to tokenize prompt: " + e.getMessage());
        }
        return null;
    }

    private static String decodeTokens(long[] tokenIds) {
        try {
            String tokenIdsJson = new ObjectMapper().writeValueAsString(tokenIds);
            
            ProcessBuilder processBuilder = new ProcessBuilder(
                "/home/davesoma/llms/OnnxRunner/app/src/main/python/venv/bin/python",
                "/home/davesoma/llms/OnnxRunner/app/src/main/python/Tokenizer.py",
                "decode", tokenIdsJson);
            processBuilder.redirectErrorStream(true);
            Process process = processBuilder.start();

            StringBuilder output = new StringBuilder();
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    output.append(line);
                }
            }

            return output.toString().trim();
        } catch (Exception e) {
            e.printStackTrace();
            System.err.println("Failed to decode tokens: " + e.getMessage());
        }
        return null;
    }

    private static void cleanupResources() {
        if (session != null) {
            try {
                session.close();
                
            } catch (OrtException e) {
                e.printStackTrace();
            }
        }
        if (env != null) {
            env.close();
        }
    }
}

Error with Quantized Version
When trying to use the quantized model (model_q4f16.onnx), the following error is encountered:

ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Non-zero status code returned while running Reshape node. Name:'/model/Reshape' Status Message: /onnxruntime_src/include/onnxruntime/core/framework/op_kernel_context.h:42 const T* onnxruntime::OpKernelContext::Input(int) const [with T = onnxruntime::Tensor] Missing Input: position_ids
    at ai.onnxruntime.OrtSession.run(Native Method)
    at ai.onnxruntime.OrtSession.run(OrtSession.java:395)
    at ai.onnxruntime.OrtSession.run(OrtSession.java:242)
    at ai.onnxruntime.OrtSession.run(OrtSession.java:210)
    at onnxrunner.App.generateResponse(App.java:294)
    at onnxrunner.App.lambda$main$1(App.java:265)
    at java.base/java.util.concurrent.CompletableFuture$UniAccept.tryFire(CompletableFuture.java:718)
    at java.base/java.util.concurrent.CompletableFuture.postComplete(CompletableFuture.java:510)
    at java.base/java.util.concurrent.CompletableFuture$AsyncSupply.run(CompletableFuture.java:1773)
    at java.base/java.util.concurrent.CompletableFuture$AsyncSupply.exec(CompletableFuture.java:1760)
    at java.base/java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:387)
    at java.base/java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1312)
    at java.base/java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1843)
    at java.base/java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1808)
    at java.base/java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:188)
    
The error suggests that the quantized model might require a "position_ids" input that isn't provided in the current implementation.
The non-quantized version works without requiring this input.
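
If the quantized export does declare such an input, one possible fix would be to supply sequential token positions alongside input_ids and attention_mask. A minimal sketch of an addition to prepareInputs, assuming the graph expects an int64 "position_ids" tensor of shape [batch_size, seq_len] (this should be verified against the actual model):

// Hypothetical addition inside prepareInputs: positions 0..seq_len-1
long[] positionIds = new long[ids.length];
for (int p = 0; p < ids.length; p++) {
    positionIds[p] = p;
}
inputs.put("position_ids", OnnxTensor.createTensor(env, new long[][]{positionIds}));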


@yufenglee
Member

@Dave86ch, could you please check if the quantized model has a "position_ids" input? If so, you need to add "position_ids" under the "input" section.

[screenshot of the model's input section]

If you still have issues, could you please share the non-quantized and quantized models? And how did you generate the quantized model?
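
For reference, the declared inputs can also be listed programmatically rather than inspected visually. A small sketch using the standard ONNX Runtime Java API (the model path here is a placeholder):

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public class InspectInputs {
    public static void main(String[] args) throws OrtException {
        String modelPath = args.length > 0 ? args[0] : "model_q4f16.onnx";  // placeholder path
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession session = env.createSession(modelPath, new OrtSession.SessionOptions())) {
            // Print every declared graph input with its type and shape
            session.getInputInfo().forEach((name, info) ->
                System.out.println(name + " -> " + info.getInfo()));
        }
    }
}

Running this against both models should show whether position_ids appears only in the quantized export.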
