Skip to content

Commit

Permalink
Merge branch 'master' into improve_scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
mafrahm committed Aug 4, 2023
2 parents 8ee3a6f + fa31954 commit f2a6c24
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
19 changes: 16 additions & 3 deletions hbw/config/defaults_and_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,23 @@
from columnflow.tasks.framework.base import RESOLVE_DEFAULT


def default_selector(cls, container, task_params):
if container.has_tag("is_sl"):
selector = "sl"
elif container.has_tag("is_dl"):
selector = "dl"

return selector


def default_ml_model(cls, container, task_params):
""" Function that chooses the default_ml_model based on the inference_model if given """
# for most tasks, do not use any default ml model
default_ml_model = None

# the ml_model parameter is only used by `MLTraining` and `MLEvaluation`, therefore use some default
# NOTE: default_ml_model does not work for the MLTraining task
if cls.task_family in ("cf.MLTraining", "cf.MLEvaulation"):
if cls.task_family in ("cf.MLTraining", "cf.MLEvaulation", "cf.MergeMLEvents", "cf.PrepareMLEvents"):
# TODO: we might want to distinguish between two default ML models (sl vs dl)
default_ml_model = "dense_default"

Expand All @@ -34,10 +43,14 @@ def default_ml_model(cls, container, task_params):

def default_producers(cls, container, task_params):
""" Default producers chosen based on the Inference model and the ML Model """
dataset_inst = task_params.get("dataset_inst", None)

# per default, use the ml_inputs and event_weights
# TODO: we might need two ml_inputs producers in the future (sl vs dl)
default_producers = ["ml_inputs", "event_weights"]
default_producers = ["ml_inputs"]
if dataset_inst and dataset_inst.is_mc:
# run event weights producer only if it's a MC dataset
default_producers.append("event_weights")

# check if a ml_model has been set
ml_model = task_params.get("mlmodel", None) or task_params.get("mlmodels", None)
Expand Down Expand Up @@ -87,7 +100,7 @@ def set_config_defaults_and_groups(config_inst):
# TODO: the default dataset is currently still being set up by the law.cfg
config_inst.x.default_dataset = default_signal_dataset = f"{default_signal_process}_{signal_generator}"
config_inst.x.default_calibrator = "skip_jecunc"
config_inst.x.default_selector = f"{signal_tag}"
config_inst.x.default_selector = default_selector
config_inst.x.default_producer = default_producers
config_inst.x.default_ml_model = default_ml_model
config_inst.x.default_inference_model = "default"
Expand Down
4 changes: 2 additions & 2 deletions hbw/ml/dense_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ def training_configs(self, requested_configs: Sequence[str]) -> list[str]:
return list(requested_configs)
else:
# use config_2017 per default
return ["config_2017"]
return ["c17"]

def training_calibrators(self, config_inst: od.Config, requested_calibrators: Sequence[str]) -> list[str]:
# fix MLTraining Phase Space
return ["skip_jecunc"]

def training_selector(self, config_inst: od.Config, requested_selector: str) -> str:
# fix MLTraining Phase Space
return "default"
return "sl" if self.config_inst.has_tag("is_sl") else "dl"

def training_producers(self, config_inst: od.Config, requested_producers: Sequence[str]) -> list[str]:
# fix MLTraining Phase Space
Expand Down

0 comments on commit f2a6c24

Please sign in to comment.