Initial script for the model validation tool #917

Open · ayissi-msft wants to merge 14 commits into main from t-jayissi/validation_tool
+187
−0
Changes shown from 11 of 14 commits.

Commits
* 60ff5f0: Add initial script for model validation tool (ayissi-msft)
* 4d48841: Removing calculate_perplexity logic (ayissi-msft)
* 2df8e01: Update validation script and config files based on feedback (ayissi-msft)
* 6481452: created a json object, loop, and updated the validate_model method (ayissi-msft)
* 15eb7d8: removed json object, added table to be printed (ayissi-msft)
* e1cb1b7: fix return statement for validate_model (ayissi-msft)
* 6063694: updated validation_tool and add exception messages (ayissi-msft)
* b2b27fb: fixing the config file (ayissi-msft)
* d6e9aed: added the precision and executive provider to the config (ayissi-msft)
* 8f291ba: adding chat template (ayissi-msft)
* 9b94bc7: Add the README.md (ayissi-msft)
* d25d69f: reformatting the config file (ayissi-msft)
* 2c5bad9: updating the chat templates (ayissi-msft)
* e8756ed: updated README (ayissi-msft)
**File: README.md** (new, 33 lines)
# ONNX Runtime GenAI Model Validation Example

## Setup

Clone this repository and navigate to the `tools/python/model_validation` folder.

```bash
git clone https://github.com/microsoft/onnxruntime-genai.git
cd onnxruntime-genai/tools/python/model_validation
```

The model_validation folder contains the validation_tool.py script, the validation_config.json file, and this README.md.

### Current Support
* Gemma
* Llama
* Mistral
* Phi
* Qwen

### Usage - Build the Model
This step creates optimized and quantized ONNX models that run with ONNX Runtime GenAI.

1. In the validation_config.json file, enter the name of a supported model from Hugging Face.
2. Include the path to the output folder, the precision, and the execution provider.

Once the model is built, you can find it in `path_to_output_folder/{model_name}`, which should contain the ONNX model data and tokenizer.
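A quick sanity check on a built model folder can catch build failures before validation runs. This is a minimal sketch, not part of the tool; `genai_config.json` as a builder output file name is an assumption, while `tokenizer_config.json` is the file the validation script reads for the chat template:

```python
import os

# Sketch: verify a built model folder contains the files the validation
# flow relies on. "genai_config.json" is an assumed builder output;
# tokenizer_config.json is read by validation_tool.py.
def has_expected_artifacts(model_dir):
    expected = {"genai_config.json", "tokenizer_config.json"}
    return expected.issubset(os.listdir(model_dir))
```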
### Run the Model Validation Script

> Review comment: the validation tool automatically runs the model builder; the narrative here makes users think this is a separate step.

```bash
python validation_tool.py -j validation_config.json
```
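The script prints a results table and writes it to models.csv with the columns Model Name, Validation Completed, and Exceptions / Failures. A small sketch, assuming that CSV layout, for pulling out the models that did not complete validation (`failed_models` is a hypothetical helper, not part of the tool):

```python
import csv

def failed_models(csv_path):
    # Return the names of models whose "Validation Completed"
    # column is anything other than "True".
    with open(csv_path, newline="") as f:
        rows = csv.DictReader(f)
        return [r["Model Name"] for r in rows
                if r["Validation Completed"] != "True"]
```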
**File: validation_config.json** (new, 18 lines)
```json
{
  "models": [
  ],
  "inputs": [
  ],
  "max_length": 512,
  "min_length": 0,
  "do_sample": false,
  "top_p": 0.0,
  "top_k": 1,
  "temperature": 1.0,
  "repetition_penalty": 1.0,
  "verbose": false,
  "output_directory": "",
  "cache_directory": "",
  "precision": "",
  "executive_provider": ""
}
```
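A hedged sketch of loading this config and failing fast on missing keys; the key list mirrors the fields validation_tool.py indexes into, and `load_config` is a hypothetical helper, not part of the tool:

```python
import json

# Keys the validation script reads directly; a missing one would
# otherwise surface as a KeyError partway through a run.
REQUIRED_KEYS = {
    "models", "inputs", "verbose", "output_directory",
    "cache_directory", "precision", "executive_provider",
}

def load_config(path):
    with open(path) as f:
        cfg = json.load(f)
    missing = REQUIRED_KEYS - cfg.keys()
    if missing:
        raise KeyError(f"config is missing keys: {sorted(missing)}")
    return cfg
```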
**File: validation_tool.py** (new, 126 lines)
```python
import argparse
import json
import os

import onnxruntime_genai as og
import pandas as pd
from onnxruntime_genai.models.builder import create_model


def create_table(output):
    return pd.DataFrame(output, columns=['Model Name', 'Validation Completed', 'Exceptions / Failures'])


def validate_model(config, model_directory):
    if config["verbose"]: print("Loading model...")
    model = og.Model(f'{model_directory}')
    if config["verbose"]: print("Model loaded")

    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()
    if config["verbose"]: print("Tokenizer created")

    # tokenizer_config.json in the model directory provides the chat template.
    chat_template = get_chat_template(model_directory)

    generation_successful = True

    for text in config["inputs"]:
        complete_text = ''
        prompt = chat_template.format(input=text)
        input_tokens = tokenizer.encode(prompt)

        params = og.GeneratorParams(model)
        params.input_ids = input_tokens

        generator = og.Generator(model, params)
        if config["verbose"]: print("Generator created")
        if config["verbose"]: print("Running generation loop ...")

        print()
        print("Output: ", end='', flush=True)

        try:
            while not generator.is_done():
                generator.compute_logits()
                generator.generate_next_token()

                new_token = generator.get_next_tokens()[0]
                value_to_save = tokenizer_stream.decode(new_token)
                complete_text += value_to_save
                print(value_to_save, end='', flush=True)
        except KeyboardInterrupt:
            print(" --control+c pressed, aborting generation--")
            generation_successful = False
        except Exception as e:
            print(f"An error occurred: {e}")
            generation_successful = False

        with open(f'{model_directory}/output.txt', 'a') as file:
            file.write(complete_text)

        # Delete the generator to free the captured graph for the next
        # generator, if graph capture is enabled.
        del generator

    return generation_successful


def get_chat_template(output_directory):
    tokenizer_json = os.path.join(output_directory, 'tokenizer_config.json')
    with open(tokenizer_json, 'r') as file:
        config = json.load(file)
    return config["chat_template"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai")
    parser.add_argument('-j', '--json', type=str, required=True, help='Path to the JSON file containing the arguments')
    args = parser.parse_args()

    with open(args.json, 'r') as file:
        config = json.load(file)

    os.makedirs(config["output_directory"], exist_ok=True)
    os.makedirs(config["cache_directory"], exist_ok=True)

    output = []

    for model in config["models"]:
        print(f"We are validating {model}")
        validation_complete = False  # reset per model so a previous result does not leak
        adjusted_model = model.replace("/", "_")
        output_path = config["output_directory"] + f'/{adjusted_model}'
        cache_path = config["cache_directory"] + f'/{adjusted_model}'

        try:
            create_model(model, '', output_path, config["precision"], config["executive_provider"], cache_path)
        except Exception as e:
            print(f'Failure after create model: {e}')
            output.append([model, validation_complete, e])
            continue

        try:
            validation_complete = validate_model(config, output_path)
            # Record successful validations too, not only failures.
            output.append([model, validation_complete, None])
        except Exception as e:
            print(f'Failure after validating model: {e}')
            output.append([model, validation_complete, e])

    df = create_table(output)
    df.to_csv("models.csv")
    print(df)
```
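`get_chat_template` raises a `KeyError` when tokenizer_config.json has no `"chat_template"` entry, which not every model ships. One possible fallback, sketched here as an illustration; `get_chat_template_safe` and the pass-through `"{input}"` default are assumptions, not part of this PR:

```python
import json
import os

# Sketch: like get_chat_template, but fall back to a plain pass-through
# template when tokenizer_config.json lacks a "chat_template" entry.
# "{input}" matches the placeholder the validation script formats with.
def get_chat_template_safe(model_directory):
    path = os.path.join(model_directory, "tokenizer_config.json")
    with open(path) as f:
        cfg = json.load(f)
    return cfg.get("chat_template", "{input}")
```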
> Review comment: add guidance on how to get the chat template for a model.