Skip to content

Commit

Permalink
Disable graph capture for embedding model (#930)
Browse files Browse the repository at this point in the history
The new embedding model for the vision pipeline contains `if` nodes, so
graph capture is not possible.
  • Loading branch information
PatriceVignola authored Sep 27, 2024
1 parent 7013224 commit 53e3ac9
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 23 deletions.
27 changes: 13 additions & 14 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ State::State(const GeneratorParams& params, const Model& model)
params_{params.shared_from_this()} {}

void State::Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size) {
auto captured_graph_info = GetCapturedGraphInfo();

if (first_run_) {
if (params_->use_cuda_graph) {
if (captured_graph_info) {
model_.run_options_->AddConfigEntry("gpu_graph_id", "-1");
}
first_run_ = false;
} else if (params_->use_cuda_graph && new_batch_size != current_batch_size_) {
assert(GetCapturedGraphInfo() != nullptr);
} else if (captured_graph_info && new_batch_size != current_batch_size_) {
current_batch_size_ = new_batch_size;
auto annotation_id = std::to_string(GetCapturedGraphInfo()->GenerateUniqueAnnotationID(new_batch_size));
auto annotation_id = std::to_string(captured_graph_info->GenerateUniqueAnnotationID(new_batch_size));
model_.run_options_->AddConfigEntry("gpu_graph_id", annotation_id.c_str());
}

Expand Down Expand Up @@ -288,7 +289,8 @@ void Model::InitDeviceAllocator([[maybe_unused]] OrtSession& session) {

void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_session_options,
OrtSessionOptions& session_options,
bool is_primary_session_options) {
bool is_primary_session_options,
bool disable_graph_capture) {
// Default to a limit of 16 threads to optimize performance
constexpr int min_thread_nums = 1;
constexpr int max_thread_nums = 16;
Expand Down Expand Up @@ -433,16 +435,13 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_

dml_pooled_upload_heap_ = std::make_unique<DmlPooledUploadHeap>(dml_objects_.d3d12_device.Get(), dml_execution_context_.get());
dml_readback_heap_ = std::make_unique<DmlReadbackHeap>(dml_objects_.d3d12_device.Get(), dml_execution_context_.get());
}

// The vision model doesn't support graph capture because of dynamic shapes, so don't enable graph capture for it
if (!vision_session_options_ && !config_->model.vision.filename.empty()) {
vision_session_options_ = session_options.Clone();
p_dml_api_->SessionOptionsAppendExecutionProvider_DML1(vision_session_options_.get(), dml_device_.Get(), dml_objects_.command_queue.Get());
}
if (!disable_graph_capture) {
session_options.AddConfigEntry("ep.dml.enable_graph_capture", "1");
session_options.AddConfigEntry("ep.dml.disable_memory_arena", "1");
}

session_options.AddConfigEntry("ep.dml.enable_graph_capture", "1");
session_options.AddConfigEntry("ep.dml.disable_memory_arena", "1");
p_dml_api_->SessionOptionsAppendExecutionProvider_DML1(&session_options, dml_device_.Get(), dml_objects_.command_queue.Get());

if (is_primary_session_options)
Expand All @@ -463,12 +462,12 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_

void Model::CreateSessionOptions() {
session_options_ = OrtSessionOptions::Create();
CreateSessionOptionsFromConfig(config_->model.decoder.session_options, *session_options_, true);
CreateSessionOptionsFromConfig(config_->model.decoder.session_options, *session_options_, true, false);

for (auto& pipeline_model : config_->model.decoder.pipeline) {
if (pipeline_model.session_options.has_value()) {
auto emplaced = pipeline_session_options_.emplace(pipeline_model.model_id, OrtSessionOptions::Create());
CreateSessionOptionsFromConfig(*pipeline_model.session_options, *emplaced.first->second, false);
CreateSessionOptionsFromConfig(*pipeline_model.session_options, *emplaced.first->second, false, false);
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {

std::unique_ptr<Config> config_;
std::unique_ptr<OrtSessionOptions> session_options_;
std::unique_ptr<OrtSessionOptions> vision_session_options_;
std::unique_ptr<OrtRunOptions> run_options_;

cuda_stream_holder cuda_stream_;
Expand All @@ -156,10 +155,10 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
void InitDeviceAllocator(OrtSession& session);
void CreateSessionOptions();

private:
void CreateSessionOptionsFromConfig(const Config::SessionOptions& config_session_options,
OrtSessionOptions& session_options,
bool is_primary_session_options);
bool is_primary_session_options,
bool disable_graph_capture);

#if USE_DML
mutable DmlObjects dml_objects_;
Expand Down
15 changes: 9 additions & 6 deletions src/models/multi_modal_vision_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ int64_t GetNumImageTokens(const std::vector<GeneratorParams::Input>& extra_input

MultiModalVisionModel::MultiModalVisionModel(std::unique_ptr<Config> config, OrtEnv& ort_env)
: Model{std::move(config)} {
embedding_session_ = OrtSession::Create(
ort_env, (config_->config_path / fs::path(config_->model.embedding.filename)).c_str(), session_options_.get());
// The embedding and vision models don't support graph capture because of control flow nodes, so disable graph capture for them
auto vision_session_options = OrtSessionOptions::Create();
CreateSessionOptionsFromConfig(config_->model.decoder.session_options, *vision_session_options, true, true);

// Use a custom vision session if available; otherwise, fall back to the generic options
auto* vision_session_options = vision_session_options_ ? vision_session_options_.get() : session_options_.get();
auto embedding_session_options = OrtSessionOptions::Create();
CreateSessionOptionsFromConfig(config_->model.decoder.session_options, *embedding_session_options, true, true);

embedding_session_ = OrtSession::Create(
ort_env, (config_->config_path / fs::path(config_->model.embedding.filename)).c_str(), embedding_session_options.get());
vision_session_ = OrtSession::Create(
ort_env, (config_->config_path / fs::path(config_->model.vision.filename)).c_str(), vision_session_options);
ort_env, (config_->config_path / fs::path(config_->model.vision.filename)).c_str(), vision_session_options.get());
decoder_session_ = OrtSession::Create(
ort_env, (config_->config_path / fs::path(config_->model.decoder.filename)).c_str(), session_options_.get());

Expand Down Expand Up @@ -144,7 +147,7 @@ MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& mo
model_{model},
num_image_tokens_{GetNumImageTokens(params_->extra_inputs, model_.config_->model.vision.inputs.pixel_values, model_.config_->model.vision.inputs.image_sizes)},
captured_graph_info_{model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)} {
embedding_state_ = std::make_unique<EmbeddingState>(model, params, captured_graph_info_.get(), num_image_tokens_);
embedding_state_ = std::make_unique<EmbeddingState>(model, params, nullptr, num_image_tokens_);
vision_state_ = std::make_unique<VisionState>(model_, params, num_image_tokens_);
decoder_state_ = std::make_unique<DecoderState>(model_, sequence_lengths_unk, params, captured_graph_info_.get());
}
Expand Down

0 comments on commit 53e3ac9

Please sign in to comment.