Initial script for the model validation tool #917

Open · wants to merge 14 commits into main

18 changes: 18 additions & 0 deletions tools/python/model_validation/validation_config.json
@@ -0,0 +1,18 @@
{
    "models": [
    ],
    "inputs": [
    ],
"max_length": 512,
"min_length": 0,
"do_sample": false,
"top_p": 0.0,
"top_k": 1,
"temperature": 1.0,
"repetition_penalty": 1.0,
"verbose": false,
"output_directory": "",
"cache_directory": "",
"precision": "",
"executive_provider": ""
}
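
For reference, here is what a populated config might look like. This is a minimal sketch: the model id, prompt, paths, and the "int4"/"cpu" values are illustrative assumptions, not values shipped with this PR. With do_sample false and top_k 1, generation is effectively greedy.

{
    "models": [
        "microsoft/Phi-3-mini-4k-instruct"
    ],
    "inputs": [
        "What is an ONNX Runtime execution provider?"
    ],
    "max_length": 512,
    "min_length": 0,
    "do_sample": false,
    "top_p": 0.0,
    "top_k": 1,
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "verbose": true,
    "output_directory": "./validation_output",
    "cache_directory": "./validation_cache",
    "precision": "int4",
    "execution_provider": "cpu"
}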
126 changes: 126 additions & 0 deletions tools/python/model_validation/validation_tool.py
@@ -0,0 +1,126 @@
import argparse
import json
import os

import pandas as pd

import onnxruntime_genai as og
from onnxruntime_genai.models.builder import create_model

def create_table(output):
    df = pd.DataFrame(output, columns=['Model Name', 'Validation Completed', 'Exceptions / Failures'])
    return df

def validate_model(config, model_directory):
    """Build a prompt for each configured input, run generation, and report success."""
    if config["verbose"]: print("Loading model...")

    model = og.Model(model_directory)

    if config["verbose"]: print("Model loaded")
    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()
    if config["verbose"]: print("Tokenizer created")
    if config["verbose"]: print()

    chat_template = get_chat_template(model_directory)

    for text in config["inputs"]:

        complete_text = ''

        prompt = chat_template.format(input=text)

        input_tokens = tokenizer.encode(prompt)

        params = og.GeneratorParams(model)
        params.input_ids = input_tokens

        generator = og.Generator(model, params)
        if config["verbose"]: print("Generator created")

        if config["verbose"]: print("Running generation loop ...")

        print()
        print("Output: ", end='', flush=True)

        generation_successful = True

        try:
            while not generator.is_done():
                generator.compute_logits()
                generator.generate_next_token()

                new_token = generator.get_next_tokens()[0]

                # Decode once; the stream is stateful, so decoding the same token twice would corrupt it.
                value_to_save = tokenizer_stream.decode(new_token)

                complete_text += value_to_save

                print(value_to_save, end='', flush=True)

        except KeyboardInterrupt:
            print(" --control+c pressed, aborting generation--")
            generation_successful = False
        except Exception as e:
            print(f"An error occurred: {e}")
            generation_successful = False

        with open(f'{model_directory}/output.txt', 'a') as file:
            file.write(complete_text)

        # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
        del generator

    return generation_successful

def get_chat_template(output_directory):
    # tokenizer_config.json in the model's output directory holds the chat template.
    tokenizer_json = os.path.join(output_directory, 'tokenizer_config.json')
    with open(tokenizer_json, 'r') as file:
        config = json.load(file)
    return config["chat_template"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="Model validation tool for onnxruntime-genai")

    parser.add_argument('-j', '--json', type=str, required=True, help='Path to the JSON file containing the arguments')

    args = parser.parse_args()

    with open(args.json, 'r') as file:
        config = json.load(file)

    os.makedirs(config["output_directory"], exist_ok=True)
    os.makedirs(config["cache_directory"], exist_ok=True)

    output = []

    for model in config["models"]:
        # Reset the flag per model so a previous result does not leak into this row.
        validation_complete = False

print(f"We are validating {model}")
adjusted_model = model.replace("/", "_")
output_path = config["output_directory"] + f'/{adjusted_model}'
# From the output directory, there exist a file named tokenizer_config.json which contains the chat
cache_path = config["cache_directory"] + f'/{adjusted_model}'

try:
create_model(model, '', output_path, config["precision"], config["executive_provider"], cache_path)
except Exception as e:
print(f'Failure after create model {e}')
output.append([model, validation_complete, e])
continue
try:
validation_complete = validate_model(config, output_path)
except Exception as e:
print(f'Failure after validation model {e}')
output.append([model, validation_complete, e])


    df = create_table(output)

    df.to_csv("models.csv")

    print(df)

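The tool is driven entirely by the config file: python validation_tool.py -j validation_config.json exports each listed model with create_model, streams a completion for every prompt, appends the generated text to output.txt in the model's output directory, and writes the summary table to models.csv.

Prompt construction depends on the chat_template field of the exported tokenizer_config.json containing a literal {input} placeholder, since the script fills it with str.format. Below is a minimal sketch of that mechanism, using a hypothetical Phi-style template; note that many models ship Jinja2 chat templates ({{ ... }} / {% ... %}), which str.format cannot render, so this assumption may need checking per model.

import json

# Hypothetical tokenizer_config.json contents; real templates vary by model.
config_text = '{"chat_template": "<|user|>\\n{input} <|end|>\\n<|assistant|>"}'
chat_template = json.loads(config_text)["chat_template"]

# Fill the placeholder exactly as validate_model does.
prompt = chat_template.format(input="What is an ONNX Runtime execution provider?")
print(prompt)
# <|user|>
# What is an ONNX Runtime execution provider? <|end|>
# <|assistant|>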