From 942952d714ebf425a5adb6483d87e562cdab1a85 Mon Sep 17 00:00:00 2001 From: Moya Chen <72097364+moyapchen@users.noreply.github.com> Date: Thu, 24 Mar 2022 00:48:37 -0400 Subject: [PATCH] TOD project folder (#4437) * [TOD] Projects folder for tod_simulator + scripts + documentation [lots of commits from this being a stacked diff removed cause... no one needs to see all that.] --- parlai/core/tod/tod_agents.py | 13 + parlai/core/tod/tod_core.py | 12 +- parlai/scripts/tod_world_script.py | 10 +- projects/tod_simulator/README.md | 149 +++++ projects/tod_simulator/__init__.py | 5 + projects/tod_simulator/scripts/__init__.py | 5 + .../scripts/cleanup_conversation.py | 157 +++++ .../scripts/do_get_passing_only_on_dir.py | 56 ++ .../scripts/get_al_samples_for_gsgd.py | 204 +++++++ .../tod_simulator/scripts/get_api_data.py | 125 ++++ .../get_interdistinct_on_conversations.py | 70 +++ .../tod_simulator/scripts/get_passing_only.py | 166 ++++++ .../scripts/get_quick_eval_stats.py | 178 ++++++ .../scripts/tod_distributed_uber_script.py | 552 ++++++++++++++++++ projects/tod_simulator/sweeps/pretrain_all.py | 119 ++++ .../tod_world_configs/__init__.py | 5 + .../tod_world_configs/all_human.json | 7 + .../google_sgd_simulation_dump_data.json | 12 + .../tod_simulator/world_metrics/__init__.py | 5 + 19 files changed, 1845 insertions(+), 5 deletions(-) create mode 100644 projects/tod_simulator/README.md create mode 100644 projects/tod_simulator/__init__.py create mode 100644 projects/tod_simulator/scripts/__init__.py create mode 100644 projects/tod_simulator/scripts/cleanup_conversation.py create mode 100644 projects/tod_simulator/scripts/do_get_passing_only_on_dir.py create mode 100644 projects/tod_simulator/scripts/get_al_samples_for_gsgd.py create mode 100644 projects/tod_simulator/scripts/get_api_data.py create mode 100644 projects/tod_simulator/scripts/get_interdistinct_on_conversations.py create mode 100644 projects/tod_simulator/scripts/get_passing_only.py create mode 100644 projects/tod_simulator/scripts/get_quick_eval_stats.py create mode 100644 projects/tod_simulator/scripts/tod_distributed_uber_script.py create mode 100644 projects/tod_simulator/sweeps/pretrain_all.py create mode 100644 projects/tod_simulator/tod_world_configs/__init__.py create mode 100644 projects/tod_simulator/tod_world_configs/all_human.json create mode 100644 projects/tod_simulator/tod_world_configs/google_sgd_simulation_dump_data.json create mode 100644 projects/tod_simulator/world_metrics/__init__.py diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index dcb6ee260aa..fb712e4b24d 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -736,6 +736,19 @@ def __init__(self, opt, shared=None): self._num_examples_cache = sum([len(x.rounds) for x in self.episodes]) self._num_episodes_cache = len(self.episodes) + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser = super().add_cmdline_args(parser, partial_opt) + parser.add_argument( + "--api-schemas", + type="bool", + default=False, + help="Preempt first turn with intents + required/optional parameters as key/value for given domain. NOOP for this teacher, but including to make sweeps easier", + ) + return parser + def setup_data(self, fold): for episode in self.generate_episodes(): if len(episode.rounds) < 1: diff --git a/parlai/core/tod/tod_core.py b/parlai/core/tod/tod_core.py index 587407b6d80..0f29420de8e 100644 --- a/parlai/core/tod/tod_core.py +++ b/parlai/core/tod/tod_core.py @@ -168,7 +168,7 @@ def delex(cls, text, slots): def inner_list_join(cls, values): if isinstance(values, str): return values - return ", ".join(sorted([v.strip() for v in values])) + return ", ".join(sorted([str(v).strip() for v in values])) @classmethod def inner_list_split(cls, s): @@ -185,12 +185,18 @@ def inner_list_split(cls, s): def maybe_inner_list_join(cls, values): if type(values) is dict: return str(values) - if isinstance(values, str) or isinstance(values, int): + if ( + isinstance(values, str) + or isinstance(values, int) + or isinstance(values, float) + ): return values elif isinstance(values, Iterable): return SerializationHelpers.inner_list_join(values) else: - raise RuntimeError("invalid type of argument for maybe_inner_list_join") + raise RuntimeError( + f"invalid type of argument for maybe_inner_list_join: {values}; type {type(values)}" + ) @classmethod def api_dict_to_str(cls, apidict): diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index edfd23bdafc..04dcb4d3d4f 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -58,8 +58,14 @@ def _is_batch_world(self, world): def _log_batch(self, world): batch_acts = world.get_batch_acts() for i, acts in enumerate(batch_acts): - # filter out for empty - acts = [act for act in acts if act["id"] != "" and act["text"] != ""] + acts = [ + act for act in acts if act is not None and "id" in act and "text" in act + ] + acts = [ + act + for act in acts + if act["id"] != "" and (act["text"] != "" or "Human" in act["id"]) + ] if len(acts) > 0: self._add_msgs(acts, idx=i) if world.episode_done(): diff --git a/projects/tod_simulator/README.md b/projects/tod_simulator/README.md new file mode 100644 index 00000000000..5bae1de3f65 --- /dev/null +++ b/projects/tod_simulator/README.md @@ -0,0 +1,149 @@ +# Task Oriented Dialogue (TOD): Agents, Worlds, Scripts, etc + +### _Teaching Models new APIs: Domain-Agnostic Simulators for Task Oriented Dialogue_ + +Moya Chen, Paul A. Crook, Stephen Roller + +## Abstract + +We demonstrate that large language models are able to simulate Task Oriented Dialogues in novel domains, provided only with an API implementation and a list of goals. We show these simulations can formulate online, automatic metrics that correlate well with human evaluations. Furthermore, by checking for whether the User's goals are met, we can use simulation to repeatedly generate training data and improve the quality of simulations themselves. With no human intervention or domain-specific training data, our simulations bootstrap end-to-end models which achieve a 37% error reduction in previously unseen domains. By including as few as 32 domain-specific conversations, bootstrapped models can match the performance of a fully-supervised model with 10× more data. To our knowledge, this is the first time simulations have been shown to be effective at bootstrapping models without explicitly requiring any domain-specific training data, rule-engineering, or humans-in-the-loop. + +## Paper + +[Link to arXiv](https://arxiv.org/abs/2110.06905) + +# Explanation of content in project + +This directory contains code for executing conversations for task-oriented dialogue (ex. setting an alarm, asking for the time) in a structured format. We introduce this structured format then go into the operational details for our setup: dataset generation + model training, simulation script usage, then give an overview of scripts in this folder. We then go into details of the specific datasets that we use as well as how to download and interact with our pre-trained models. + +As a terminology note, while the paper uses "Assistant" throughout, the same speaker is generally referred to as the "System" throughout code and documentation. + +## Conversation structure + +In task oriented dialogue, we have a user (with some goal) that requests some form of action out of an assistant system. This assistant system normally has some external knowledge-base with to which it can interact with via APIs. + +To model this, we begin each episode with a grounding stage where: +1. an api schema agent gives a description string of the API to an api call and api response agents +2. a goal agent gives a target goal string to a user utterance agent to start the conversation + +During the 'parlay' or normal conversational phase, we have four agents that speak in looped turns: +1. User utt agent +2. System API call agent +3. API response agent +4. System utt agent + +In analogy to more traditional TOD-setups, one can think of the api call agent as dialogue state tracking and the system utt agent as natural language generation. Since many TOD-systems these days combine both dialogue state tracking and natural language generation into one model, we assume that the api call and system agents are the same. + +To prevent leakage of information between agents during the parlay phase, each agent only observes only its own output and that of the agent which speaks immediately before. + +## Dataset setup + Model Training + +See `parlai/core/tod/tod_agents.py` for information on how to build agents and teachers for a specific dataset. + +Of the agents described in the conversation, only the User and System need to be trained with generative models. These can be trained as normal ParlAI models (ie.`parlai train_model -t -mf -m `) using System- and UserSimulator- Teachers created via the documentation in the `tod_agents.py` file mentioned above. + +## Simulation Script Usage +Use `python parlai/scripts/tod_world_script.py` or `parlai tod_world_script` (or the corresponding `distributed_` prefixed versions) to generate model-model chats. Arguments to the script are listed in file. Note that it is oftentimes preferable to use the `python ..` rather than `parlai ..` form of this command, especially if one has model or agent specific flags, due to argument order parsing. + +As a quick example, we provide + +`parlai tod_world_script -o projects/tod_simulator/tod_world_configs/google_sgd_simulation_dump_data.json` + +as an example of printing the validation data from Google SGD Out of Domain through the simulation script. + +Additionally, use this to specify a conversation where all of the agents take human input from the command line: + +``` +parlai tod_world_script --system-model parlai.agents.local_human.local_human:LocalHumanAgent --user-model parlai.agents.local_human.local_human:LocalHumanAgent --api-resp-model parlai.agents.local_human.local_human:LocalHumanAgent --api-schema-grounding-model parlai.agents.local_human.local_human:LocalHumanAgent --goal-grounding-model parlai.agents.local_human.local_human:LocalHumanAgent +``` + +(which is the same as `parlai tod_world_script -o projects/tod_simulator/tod_world_configs/all_human.json`, included for convenience) + +Defaults are provided for the grounding agents but must be specified for the rest. Pretrained model locations can also be specified for the user and system with `--user-model-file` and `--system-model-file` arguments respectively. Since the system agent + api call agent are assumed to be the same, we only specify the 5 distinct agents, rather than 6. + +Further documentation of the simulation world and simulation world metrics are described in `parlai/core/tod/tod_world.py` and `parlai/core/tod/world_metrics.py`, respectively. + +## Scripts in `script` directory of this folder + +**cleanup\_conversation.py** +As a convenience, we also add a script for parsing the output conversation of the TOD Script into a format slightly more ameniable to ACUTE-Eval. While the raw output of the TOD Script could be used as well, the provided cleanup script does things like remove API utterances + Goals. + +**do\_get\_passing\_only\_on\_dir.py** +Uses `get_passing_only.py` internaly to run on a directory + +**get\_al\_samples\_for\_gsgd.py** +Gets active learning samples out of Google SGD's OutDomainSystemTeacher train set based on worst-performing API calls as extracted from `get_passing_only.py`. + +**get\_api\_data.py** +For models trained with `tod_distributed_uber_script.py` that have `--api-jga-record` set to `True`, this will automatically pull per-api Google SGD Out-of-Domain JGA and simulation success statistics. + +**get\_interdistinct\_on\_conversations.py** +Deprecated script to calculate interdistinct metrics for simulation conversations. (Included for completeness.) + +**get\_passing\_only.py** +Given a conversation generated from `tod_world_script`, outputs statistics about performance of different APIs. + +**get\_quick\_eval\_stats.py** +For models trained with `tod_distributed_uber_script.py`, this quickly grabs evaluation and model-model simulation data into a comma-separated format. + +**tod\_distributed\_uber\_multiwoz\_script.py** +Version of `tod_distributed_uber_script.py` but with MultiWoz v2.2 as the primary task rather than Google SGD Out-of-Domain. (Included for completeness.) + +**tod\_distributed\_uber\_script.py** +Multi-step train, evaluation, and data generation script used in Simulations paper. Uses Google SGD Out-of-Domain as primary dataset; note "STANDALONE\_API\_FILE\_PATH" that needs to be set in file. Makes use of `do_get_passing_only_on_dir.py` and `get_al_samples_for_gsgd.py`; use `get_passing_only.py` and `get_api_data.py` after the fact for analysis. + +Note that this script is intended to be run in a SLURM environment matching that of the Simulations paper authors. It is unknown how the script performs in other settings but is included as a reference. + +## Tasks used in the paper + +See the appendix of [the paper](https://arxiv.org/abs/2110.06905) (or the description of the task in ParlAI Task List) for explanations of these datasets. Below, we include the dataset name, the command to run the `SystemTeacher` relevant for each of the datasets, and any other notable details. Other agents and teachers for the dataset are specified in the relevant task `agent.py` files. + +### Pretraining Tasks + +* Google SGD In-Domain + * `parlai dd -t google_sgd_simulation_splits:InDomainSystemTeacher` +* MetalWoz + * `parlai dd -t metalwoz:SystemTeacher` +* MSR_E2E + * `parlai dd -t msr_e2e:SystemTeacher` + * Note that due to the lack of annotations in this dataset, this System Teacher *only* includes utterance turns +* Multidogo + * `parlai dd -t multidogo:SystemTeacher` +* MultiWoz + * We use a fb-internal pre-processing of MultiWoz, based on MultiWoz v2.1 and do not open source it at this time. +* Taskmaster + * `parlai dd -t taskmaster:SystemTeacher` +* Taskmaster2 + * `parlai dd -t taskmaster2:SystemTeacher` +* Taskmaster3 (TicketTalk) + * `parlai dd -t taskmaster3:SystemTeacher` + +### Experimentation Tasks + +* Google SGD Out-of-Domain + * `parlai dd -t google_sgd_simulation_splits:OutDomainSystemTeacher` +* MultiWoz (not currently included in paper) + * `parlai dd -t multiwoz_v22:SystemTeacher` + * This is a preprocessing of the dataset based on MultiWoz v2.2. Though utterances are the same as used for pre-training, API Call and API Response structures aer different. + +See "scripts in project directory" for scripts associated with training, evaluation, and data generation. + +## Pretrained models + +We release Schema-Aware and Schema-Agnostic version of our intermediate task-pretraining. One can see the outputs of these models by running + +``` +parlai dd -t google_sgd_simulation_splits:OutDomainSystemTeacher -mf zoo:tod/tod_base_yes_api/model --skip-generation false --api-schemas true +``` + +for the Schema-Aware version of the model and + +``` +parlai dd -t google_sgd_simulation_splits:OutDomainSystemTeacher -mf zoo:tod/tod_base_no_api/model --skip-generation false --api-schemas false +``` + +for the Schema-Agnostic version. + +Note the path names of the model files; they are `zoo:tod/tod_base_{yes,no}_api/mode` where "yes" corresponds to Schema-Aware and "no" corresponding to Schema-Agnostic. Care must be taken to specify `--api-schemas` correctly since task-setting flags are parsed from teacher-specific flags and not from model files. + +These models are both based on a BART-large (400 million paramater) base model. Hyperparameters for training can be found in the paper; tasks are listed in "Pretraining Tasks" above. diff --git a/projects/tod_simulator/__init__.py b/projects/tod_simulator/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/projects/tod_simulator/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/projects/tod_simulator/scripts/__init__.py b/projects/tod_simulator/scripts/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/projects/tod_simulator/scripts/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/projects/tod_simulator/scripts/cleanup_conversation.py b/projects/tod_simulator/scripts/cleanup_conversation.py new file mode 100644 index 00000000000..1bc33816116 --- /dev/null +++ b/projects/tod_simulator/scripts/cleanup_conversation.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Script for making light modifications to conversations from tod chats such that they are +ready for the ACUTE format. + +Notably, this does things that are slightly too much of a pain in the butt to do with +regexes like "add suffixes to ids when multiple ids might have the same string" (and +change metadata appropriately). + +For example, the following + +``` +python cleanup_conversation.py --source_file _conversations.jsonl --report-path .json --agent-suffixes user_utt_model _BASE_USER system_utt_model _BASE_SYSTEM --included-speakers goal_grounding_model user_utt_model system_utt_model +``` + +strips the API call related turns and adds "_BASE_USER" and "_BASE_SYSTEM" (which otherwise would be the model type name, ex BART) to the latter two, respecitvely. +""" + +from parlai.core.params import ParlaiParser +from parlai.utils.conversations import Conversations, Metadata +from parlai.utils.io import PathManager +from parlai.core.script import ParlaiScript, register_script + +from parlai.core.tod.tod_core import TodAgentType, TOD_AGENT_TYPE_TO_PREFIX + +import json + + +@register_script("conversation_cleanup") +class ConversationCleanup(ParlaiScript): + @classmethod + def setup_args(cls): + parser = ParlaiParser( + False, + False, + "Script for simplying conversations output from TOD. Input expected to be in conversations format, as is output", + ) + # Following params are same as the `eval_model` script + parser.add_argument( + "--source-file", + type=str, + required=True, + help="Source file in conversations format, generated from `tod_world_script.py`", + ) + parser.add_argument( + "--out-file", type=str, default=None, help="Output location." + ) + parser.add_argument( + "--included-speakers", + nargs="*", + type=str, + choices=[e.value for e in TodAgentType], + default=[TodAgentType.USER_UTT_AGENT, TodAgentType.SYSTEM_UTT_AGENT], + help="Which of the speakers to not remove. Should match those in `tod_world`", + ) + parser.add_argument( + "--agent-suffixes", + nargs="*", + type=str, + default=[ + TodAgentType.USER_UTT_AGENT, + "_USER", + TodAgentType.SYSTEM_UTT_AGENT, + "_SYSTEM", + ], + help="List of pairs. Speaker type should match those in `TodAgentType`; outputs (if included) will have the suffix added to the ID. This is useful when using multiple of the same out model (ex. Bart model for both the user and the system)", + ) + parser.add_argument( + "--num-conversations", + default=400, + help="Number of conversations to include. -1 for all", + ) + parser.add_argument( + "--report-path", + required=True, + help="path of the report saved from the tod_metrics_script", + ) + return parser + + def _get_turn_type(self, turn): + for agent_type, prefix in TOD_AGENT_TYPE_TO_PREFIX.items(): + if prefix in turn["text"]: + return agent_type + + def run(self): + opt = self.opt + if int(len(self.opt["agent_suffixes"])) % 2 != 0: + raise RuntimeError("Agent suffix input should be even") + suffixes = {} + for i in range(int(len(self.opt["agent_suffixes"]) / 2)): + agent = self.opt["agent_suffixes"][2 * i] + suffix = self.opt["agent_suffixes"][2 * i + 1] + suffixes[agent] = suffix + + with PathManager.open(opt["report_path"]) as r: + report = json.load(r)["report"] + tod_metrics = report["tod_metrics"] + + if opt["num_conversations"] > -1: + tod_metrics = tod_metrics[: opt["num_conversations"]] + + source = self.opt["source_file"].replace(".jsonl", "") + if self.opt["out_file"]: + out = self.opt["out_file"] + else: + if ( + "conversations" in source + ): # just to make sure we don't overwrite anything... + out = source.replace("conversations", "cleaned_conversations") + else: + out = "cleaned_" + source + + speakers = [] + with PathManager.open(out + ".jsonl", "w") as f: + conversations = Conversations(source + ".jsonl") + for i, conversation in enumerate(conversations): + if opt["num_conversations"] >= 0 and i >= opt["num_conversations"]: + break + cleaned_dialog = [] + for parlay_round in conversation.episode["dialog"]: + cleaned_parlay_round = [] + for turn in parlay_round: + turn_type = self._get_turn_type(turn) + if turn_type in self.opt["included_speakers"]: + if turn_type in suffixes: + turn["id"] += suffixes[turn_type] + if turn["id"] not in speakers: + speakers.append(turn["id"]) + cleaned_parlay_round.append(turn) + if len(cleaned_parlay_round) > 0: + cleaned_dialog.append(cleaned_parlay_round) + convo = {} + convo["dialog"] = cleaned_dialog + convo["metadata_path"] = Metadata._get_path(out) + convo["context"] = [ + { + "synthetic_task_success": tod_metrics[i][ + "synthetic_task_success" + ], + "goal_text": tod_metrics[i]["goal"]["text"], + } + ] + json_convo = json.dumps(convo) + f.write(json_convo + "\n") + + old_meta = Metadata(source + ".jsonl") + Metadata.save_metadata( + out, old_meta.opt, old_meta.self_chat, speakers, **old_meta.extra_data + ) + + +if __name__ == "__main__": + ConversationCleanup.main() diff --git a/projects/tod_simulator/scripts/do_get_passing_only_on_dir.py b/projects/tod_simulator/scripts/do_get_passing_only_on_dir.py new file mode 100644 index 00000000000..29a75484fbe --- /dev/null +++ b/projects/tod_simulator/scripts/do_get_passing_only_on_dir.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base script for running TOD model-model chats. +""" + +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript, register_script + +from projects.tod_simulator.scripts.get_passing_only import GetPassingOnlyScript + +import glob + + +@register_script("do_get_passing_only_on_dir") +class DoGetPassingOnlyOnDirScript(ParlaiScript): + @classmethod + def setup_args(cls): + parser = ParlaiParser(False, False) + parser.add_argument("-p", "--path", required=True, type=str) + parser.add_argument( + "--filter-call-attempts", + default=True, + help="when True, only counts as 'passing' if System made exactly same # of api calls as goals", + ) + return parser + + def run(self): + opt = self.opt + path = opt["path"] + + # assumes standard naming from the `model_consts` set of scripts + base_paths = [ + x.replace("_conversations.jsonl", "") + for x in glob.glob(f"{path}/*_conversations.jsonl") + ] + + for to_run in base_paths: + convo_path = to_run + "_conversations.jsonl" + report_path = to_run + ".json" + + here_opt = { + "convo_path": convo_path, + "report_path": report_path, + "print_to_file": True, + "filter_call_attempts": opt["filter_call_attempts"], + } + GetPassingOnlyScript._run_kwargs(here_opt) + print("done with ", convo_path) + + +if __name__ == "__main__": + DoGetPassingOnlyOnDirScript.main() diff --git a/projects/tod_simulator/scripts/get_al_samples_for_gsgd.py b/projects/tod_simulator/scripts/get_al_samples_for_gsgd.py new file mode 100644 index 00000000000..cd83a90a07d --- /dev/null +++ b/projects/tod_simulator/scripts/get_al_samples_for_gsgd.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Quick script for dumping out relevant conversation ids from GoogleSGD. +""" + +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript, register_script +from parlai.tasks.google_sgd_simulation_splits.agents import GoogleSgdOutDomainParser +from parlai.core.tod.tod_agents import TodStructuredDataParser + +import parlai +import os +import json +import random + +PARLAI_DATA_PATH = os.path.dirname(os.path.dirname(parlai.__file__)) + "/data" + + +class GrabEpisodes(GoogleSgdOutDomainParser, TodStructuredDataParser): + def get_agent_type_suffix(self): + return "GrabEpisodes" + + +def setup_args(parser=None): + if not parser: + parser = ParlaiParser(False, False) + group = parser.add_argument_group("Get active learning samples script") + group.add_argument( + "--find-random-al-samples", + type=bool, + default=False, + help="Get active learning samples randomly or few shot", + ) + group.add_argument( + "--input-processed-stats", + type=str, + help="Processed stats file from the `get_passing_only` script", + ) + group.add_argument( + "--processed-stats-section", + type=str, + default="MISSES", + help="Which section from the `get_passing_only` script will we use to rank. (MISSES and FRACTIONAL current options)", + ) + group.add_argument( + "--num-apis-to-get", + default=8, + type=int, + help="Number of api descriptions we want to find", + ) + group.add_argument( + "--existing-al-files", + nargs="*", + type=str, + help="Existing active learning files (ie, for running multiple iterations of learning)", + ) + group.add_argument( + "--cumulative-al", + type=bool, + default=True, + help="Uses active learning files from '--existing-al-files' cumulatively (as in, will append all prior dialog ids for next, rather than excluding)", + ) + group.add_argument( + "--al-output-file", + default=None, + help="Output file. Will put into 'results' in active run directory otherwise.", + ) + return parser + + +@register_script("get_al_samples_for_gsgd_script") +class GetAlSamplesForGsgdScript(ParlaiScript): + @classmethod + def setup_args(cls): + return setup_args() + + def run(self): + existing_al_ids = self.get_existing_al_ids() + # NOTE: The inidivdual get_al_samples funcitons are responsible for dealing with cumulative + if self.opt["find_random_al_samples"]: + save_me = self.get_al_samples_random(existing_al_ids) + else: + save_me = self.get_al_samples_from_processed(existing_al_ids) + out_file = self.opt.get("al_output_file") + if not out_file: + out_file = "result" + with open(out_file, "w+") as f: + json.dump(save_me, f, indent=4) + print("Saved AL samples to ", out_file) + + def get_al_samples_random(self, existing_al_ids): + save_me = {} + for datatype in ["train", "valid", "test"]: + unfiltered_episodes = self.get_gsgd_episodes_for_datatype(datatype) + filtered_episodes = [ + x + for x in unfiltered_episodes + if x.extras["dialogue_id"] not in existing_al_ids + ] + samples = random.Random(42).sample( + filtered_episodes, self.opt["num_apis_to_get"] + ) + if self.opt["cumulative_al"]: + old = [ + x + for x in unfiltered_episodes + if x.extras["dialogue_id"] in existing_al_ids + ] + samples.extend(old) + save_me[datatype] = { + episode.extras["dialogue_id"]: episode.goal_calls_utt + for episode in samples + } + return save_me + + def get_al_samples_from_processed(self, existing_al_ids): + wanted_apis = self.get_wanted_apis() + save_me = {} + for datatype in ["train", "valid", "test"]: + found = [False] * len(wanted_apis) + save_me[datatype] = {} + for episode in self.get_gsgd_episodes_for_datatype(datatype): + if episode.extras["dialogue_id"] in existing_al_ids.get(datatype, []): + if self.opt["cumulative_al"]: + save_me[datatype][ + episode.extras["dialogue_id"] + ] = episode.goal_calls_utt + continue + here = False + for idx, val in enumerate(found): + if not val and not here: + here = True + for api_name_slots in wanted_apis[idx]: + api_name = api_name_slots[0] + slots = api_name_slots[1] + if api_name in episode.goal_calls_utt: + for slot in slots: + if slot not in episode.goal_calls_utt: + here = False + else: + here = False + if here: + found[idx] = True + save_me[datatype][ + episode.extras["dialogue_id"] + ] = episode.goal_calls_utt + return save_me + + def get_gsgd_episodes_for_datatype(self, datatype): + datapath = PARLAI_DATA_PATH + if "datapath" in self.opt: + datapath = self.opt["datapath"] + elif "parlai_home" in self.opt: + datapath = self.opt["parlai_home"] + "/data" + opt = { + "datatype": datatype, + "datapath": datapath, + "gsgd_domains": "all", + "n_shot": -1, + "episodes_randomization_seed": -1, + } + return GrabEpisodes(opt).episodes + + def get_existing_al_ids(self): + existing_al_files = self.opt.get("existing_al_files") + if existing_al_files is None: + return {} + existing_al_ids = {} + for existing in existing_al_files: + with open(existing) as f: + data = json.load(f) + for datatype, blob in data.items(): + if datatype not in existing_al_ids: + existing_al_ids[datatype] = [] + for dialog_id in blob: + existing_al_ids[datatype].append(dialog_id) + return existing_al_ids + + def get_wanted_apis(self): + unprocessed_lines = [] + with open(self.opt["input_processed_stats"]) as f: + lines = f.readlines() + idx = 0 + while self.opt["processed_stats_section"] not in lines[idx]: + idx += 1 + unprocessed_lines = lines[idx + 1 : idx + 1 + self.opt["num_apis_to_get"]] + + processed = [] + for row_raw in unprocessed_lines: + print(row_raw) + row = json.loads(json.loads(row_raw.strip())[0]) + slots = [x for x in row if "api_name" not in x] + for x in row: + if "api_name" in x: + processed.append((x.replace("api_name:", ""), slots)) + return processed + + +if __name__ == "__main__": + GetAlSamplesForGsgdScript.main() diff --git a/projects/tod_simulator/scripts/get_api_data.py b/projects/tod_simulator/scripts/get_api_data.py new file mode 100644 index 00000000000..a7bae4184b6 --- /dev/null +++ b/projects/tod_simulator/scripts/get_api_data.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Prints validation statistics about API calls from a directory generated by +`tod_distributed_uber_script`. + +This file assumes that there exists a "eval_stats.json" that includes validation +statistics from a valid run on OutDomainSystemTeacher of Google SGD's simulation +scripts; also assumes a file of the format `mm_eval*pro*st*` generated from validation +goals of the same source conversations, split by single goals. +""" + +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript +import glob +import json + +API_CALL_NAMES = [ + "ShareLocation", + "RequestPayment", + "MakePayment", + "FindApartment", + "ScheduleVisit", + "FindHomeByArea", + "GetCarsAvailable", + "ReserveCar", +] + + +class GetQuickEvalStatsDirScript(ParlaiScript): + @classmethod + def setup_args(cls): + parser = ParlaiParser(False, False) + parser.add_argument("-p", "--path", required=False, type=str) + return parser + + def get_apis_from_eval_stats(self): + report = self.eval_data["report"] + result = {} + for key in report: + for call in API_CALL_NAMES: + if call in key: + if call not in result: + result[call] = {} + result[call][key] = report[key] + return result + + def get_apis_from_processed_stats(self): + cand_files = glob.glob(f"{self.path}/mm_eval*pro*st*") # eval stats + if len(cand_files) == 0: + cand_files = glob.glob( + f"{self.path}/*pro*st*" + ) # legacy before nucleus added + if len(cand_files) == 0: + print("NO PROCESSED STATS FILE FOUND!") + return {} + + with open(cand_files[0]) as f: + lines = [line.strip() for line in f.readlines()] + for i in range(len(lines)): + if "DELTAS" in lines[i]: + want = lines[i + 1 :] + break + + result = {} + for line in want: + key, value = line.strip().split('", [') + for call in API_CALL_NAMES: + if call in key: + if call not in result: + result[call] = {} + result[call][key] = value + return result + + def get_align_apis(self): + api_eval = self.get_apis_from_eval_stats() + api_processed_stats = self.get_apis_from_processed_stats() + if len(api_eval) != len(api_processed_stats): + print( + "LENGTH OF FILES NOT THE SAME: api_eval", + len(api_eval), + "api_processed_stats", + len(api_processed_stats), + ), + merged = [] + for key in sorted(API_CALL_NAMES): + if ( + key == "FindHomeByArea" and len(api_eval) > 0 + ): # this one has an extra api call type in train with 6 samples + if len(api_eval[key]) != len(api_processed_stats[key]): + api_eval[key][ + "api-FindHomeByArea--area-has_garage-in_unit_laundry-intent-number_of_baths-number_of_beds" + ] = "NA" + api_eval_calls = sorted(api_eval.get(key, {}).keys()) + api_processed_stats_here = sorted(api_processed_stats.get(key, {}).items()) + for i in range(max(len(api_eval_calls), len(api_processed_stats_here))): + save_me = [] + if len(api_eval_calls) > i: + api_call = api_eval_calls[i] + save_me.append(api_call) + save_me.append(str(api_eval[key][api_call])) + if len(api_processed_stats_here) > i: + save_me.append(api_processed_stats_here[i][1].replace("]", "")) + merged.append(", ".join(save_me)) + for entry in merged: + print(entry) + + def run(self): + opt = self.opt + path = opt["path"] + self.path = path + + eval_stats = f"{path}/eval_stats.json" + with open(eval_stats) as f: + eval_data = json.load(f) + self.eval_data = eval_data + + self.get_align_apis() + + +if __name__ == "__main__": + GetQuickEvalStatsDirScript.main() diff --git a/projects/tod_simulator/scripts/get_interdistinct_on_conversations.py b/projects/tod_simulator/scripts/get_interdistinct_on_conversations.py new file mode 100644 index 00000000000..f220e724036 --- /dev/null +++ b/projects/tod_simulator/scripts/get_interdistinct_on_conversations.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base script for running TOD model-model chats. +""" +from parlai.core.metrics import InterDistinctMetric + +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript, register_script + +import glob +import json + + +@register_script("get_interdistinct_on_conversations") +class GetInterdistinctOnConversationsScript(ParlaiScript): + @classmethod + def setup_args(cls): + parser = ParlaiParser(False, False) + parser.add_argument("-p", "--path", required=True, type=str) + return parser + + def run(self): + opt = self.opt + paths = glob.glob(f"{opt['path']}/*_conversations.jsonl") + + for path in paths: + name = path.split("/")[-1] + sys_utt_one = InterDistinctMetric.compute("", 1) + sys_utt_two = InterDistinctMetric.compute("", 2) + user_utt_one = InterDistinctMetric.compute("", 1) + user_utt_two = InterDistinctMetric.compute("", 2) + + with open(path) as f: + for line_raw in f: + line = json.loads(line_raw)["dialog"] + for turn in line: + if len(turn) < 4: + continue + user_utt = turn[0]["text"] + if user_utt.startswith("APIS: "): + continue + user_utt_one += InterDistinctMetric.compute(user_utt, 1) + user_utt_two += InterDistinctMetric.compute(user_utt, 2) + + sys_utt = turn[3]["text"] + sys_utt_one += InterDistinctMetric.compute(sys_utt, 1) + sys_utt_two += InterDistinctMetric.compute(sys_utt, 2) + + print( + ",".join( + [ + str(x) + for x in [ + name, + user_utt_one.value(), + user_utt_two.value(), + sys_utt_one.value(), + sys_utt_two.value(), + ] + ] + ) + ) + + +if __name__ == "__main__": + GetInterdistinctOnConversationsScript.main() diff --git a/projects/tod_simulator/scripts/get_passing_only.py b/projects/tod_simulator/scripts/get_passing_only.py new file mode 100644 index 00000000000..57dc917e094 --- /dev/null +++ b/projects/tod_simulator/scripts/get_passing_only.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base script for running TOD model-model chats. +""" + +from collections import defaultdict +from copy import deepcopy +import json +import sys +from shutil import copyfile + +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript, register_script +from parlai.core.tod.tod_core import STANDARD_DONE +from parlai.utils.io import PathManager + +from parlai.core.tod.tod_core import SerializationHelpers + + +@register_script("get_passing_only") +class GetPassingOnlyScript(ParlaiScript): + @classmethod + def setup_args(cls): + parser = ParlaiParser(False, False) + parser.add_argument("--cut-first-400", default=False) + parser.add_argument("--convo-path", required=True) + parser.add_argument( + "--report-path", + required=True, + help="path of the report saved from the tod_metrics_script", + ) + parser.add_argument( + "--print-to-file", + default=False, + help="save the results to a file (by hackishly redirecting stdout)", + ) + parser.add_argument( + "--filter-call-attempts", + default=True, + help="when True, only counts as 'passing' if System made exactly the same # of api calls as # of goals", + ) + return parser + + def run(self): + opt = self.opt + + with PathManager.open(opt["report_path"]) as r: + print(opt["report_path"]) + report = json.load(r)["report"] + tod_metrics = report["tod_metrics"] + + if opt["cut_first_400"]: + tod_metrics = tod_metrics[400:] + + infile_base = opt["convo_path"].replace(".jsonl", "") + outfile_base = infile_base + "_processed" + + copyfile(f"{infile_base}.metadata", f"{outfile_base}.metadata") + + if opt["print_to_file"]: + print_file_path = infile_base.replace("_conversations", "_processed_stats") + print_file = open(print_file_path, "w+") + orig_stdout = sys.stdout + sys.stdout = print_file + + found = defaultdict(lambda: 0) + not_found = defaultdict(lambda: 0) + + with PathManager.open(f"{outfile_base}.jsonl", "w") as out: + with PathManager.open(opt["convo_path"]) as c: + lines = c.readlines() + if opt["cut_first_400"]: + lines = lines[400:] + print("len lines: ", len(lines)) + print("len tod_metrics: ", len(tod_metrics)) + for i, l in enumerate(lines): + if i >= len(tod_metrics): + break + goals = SerializationHelpers.str_to_goals( + tod_metrics[i]["goal"]["text"][len(STANDARD_DONE) :] + ) + if len(goals) == 0: + print(goals) + continue + if tod_metrics[i]["synthetic_task_success"] == 1.0 and ( + not self.opt["filter_call_attempts"] + or tod_metrics[i].get("api_call_attempts", 1) # legacy + == len(goals) + ): + print(goals, tod_metrics[i]["api_call_attempts"], l) + self.api_string_add(found, goals, 1) + self.api_string_add(not_found, goals, 0) + out.write(l) + else: + self.api_string_add(not_found, goals, 1) + self.api_string_add(found, goals, 0) + print("count found", sum(found.values())) + print("count notfound", sum(not_found.values())) + print( + "============================ FOUND ======================\n", + [ + (k, v) + for k, v in sorted(found.items(), key=lambda x: x[1], reverse=True) + ], + ) + print( + "============================= NOT FOUND ====================\n", + [ + (k, v) + for k, v in sorted( + not_found.items(), key=lambda x: x[1], reverse=True + ) + ], + ) + + biggest_misses = deepcopy(not_found) + for f in found: + biggest_misses[f] -= found[f] + print( + "======================= BIGGEST MISSES (# not found - # found) ======================\n", + "\n".join( + [ + json.dumps((k, v)) + for k, v in sorted( + biggest_misses.items(), key=lambda x: x[1], reverse=True + ) + ] + ), + ) + fraction = {} + for k in not_found: + total = not_found[k] + found[k] + fraction[k] = (float(not_found[k]) / total, total) + print( + "=========================== BIGGEST FRACTIONAL DELTAS (not_found / total # ) ====================\n", + "\n".join( + [ + json.dumps((k, v)) + for k, v in sorted( + fraction.items(), key=lambda x: x[1], reverse=True + ) + ] + ), + ) + + if opt["print_to_file"]: + print_file.close() + sys.stdout = orig_stdout + + def api_string_add(self, found_list, api_strings, val): + for api_string in api_strings: + if "api_name" not in api_string: + continue + api_name = api_string["api_name"] + api_string["api_name:" + api_name] = "" + api_name = api_string.pop("api_name") + found_list[json.dumps(sorted(list(api_string)))] += val + api_string["api_name"] = api_name + + +if __name__ == "__main__": + GetPassingOnlyScript.main() diff --git a/projects/tod_simulator/scripts/get_quick_eval_stats.py b/projects/tod_simulator/scripts/get_quick_eval_stats.py new file mode 100644 index 00000000000..fe75f20628a --- /dev/null +++ b/projects/tod_simulator/scripts/get_quick_eval_stats.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base script for running TOD model-model chats. +""" + +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript +import glob +import json +import os + + +class GetQuickEvalStatsDirScript(ParlaiScript): + @classmethod + def setup_args(cls): + parser = ParlaiParser(False, False) + parser.add_argument("-p", "--path", required=True, type=str) + return parser + + def run(self): + opt = self.opt + path = opt["path"] + + stuff = {} + + for eval_stats in glob.glob(f"{path}/*/eval_stats.json"): + with open(eval_stats) as f: + eval_data = json.load(f) + + base_path = os.path.abspath(eval_stats).replace("eval_stats.json", "/") + + base = "" + multitask = "" + if not os.path.isfile(base_path + "run.sh"): + continue + with open(base_path + "run.sh") as run_file: + for line in run_file: + if "zoo:bart" in line: + base = "BartOnly" + + with open(base_path + "model.opt") as opt_file: + print(base_path) + opt = json.load(opt_file) + lr = opt["learningrate"] + multitask = opt.get("multitask_weights", "") + yes_api = opt.get("api_schemas", False) + if multitask != "": + multitask = "_" + "".join([str(x) for x in multitask]) + + orig_path = "/".join(os.path.abspath(base_path).split("/")[-3:]) + + ppl = eval_data["report"].get("ppl", "") + token_em = eval_data["report"].get("token_em", "") + jga = eval_data["report"].get("jga", "") + jga_n = eval_data["report"].get("jga_noempty", "") + jga_e = eval_data["report"].get("jga_empty", "") + + if yes_api: # hack cause I'm annoyed at things being in wrong alpha order + yes_api = "True" + else: + yes_api = "false" + + maybe_nshot = "" + root = base_path + eval_file = glob.glob(f"{root}/*mm_eval*.json") + if len(eval_file) > 0 and ( + "percent" in eval_file[0] or "nshot" in eval_file[0] + ): + maybe_nshot = eval_file[0][len(root + "mm_eval_") : -len(".json")] + + base += f"{len(opt.get('task').split(','))}-{yes_api}-APIS__{maybe_nshot}{base}_{lr}{multitask}" + metrics = [orig_path, base, ppl, token_em, jga, jga_n, jga_e] + + if os.path.isfile(eval_stats.replace("eval_stats", "user_eval_stats")): + with open(eval_stats.replace("eval_stats", "user_eval_stats")) as f: + metrics.append(json.load(f)["report"]["ppl"]) + else: + metrics.append("DNE") + + metrics.append(" ") + + root = base_path + maybe_mm_stats = glob.glob(f"{root}/mm_eval") + if len(maybe_mm_stats) == 0: + maybe_mm_stats = glob.glob(f"{root}/*mm_eval*.json") + if len(maybe_mm_stats) == 0: + maybe_mm_stats = glob.glob(f"{root}/*Mult*.json") + if len(maybe_mm_stats) > 0: + with open(maybe_mm_stats[0]) as f: + report = json.load(f)["report"] + if "synthetic_task_success" in report: + tsr = report["synthetic_task_success"] + elif "all_goals_hit" in report: + tsr = report["all_goals_hit"] + else: + tsr = "DNE" + metrics.append(tsr) + else: + metrics.append("DNE") + + metrics.append(" ") + + if os.path.isfile(eval_stats.replace("eval_stats", "in_eval_stats")): + with open(eval_stats.replace("eval_stats", "in_eval_stats")) as f: + blob = json.load(f)["report"] + metrics.append(blob.get("ppl", "")) + metrics.append(blob.get("jga", "")) + metrics.append(blob.get("jga_noempty", "")) + metrics.append(blob.get("jga_empty", "")) + else: + metrics.append("DNE") + metrics.append("DNE") + metrics.append("DNE") + metrics.append("DNE") + + if os.path.isfile(eval_stats.replace("eval_stats", "in_user_eval_stats")): + with open(eval_stats.replace("eval_stats", "in_user_eval_stats")) as f: + metrics.append(json.load(f)["report"]["ppl"]) + else: + metrics.append("DNE") + + if os.path.isfile(eval_stats.replace("eval_stats", "test_eval_stats")): + with open(eval_stats.replace("eval_stats", "test_eval_stats")) as f: + blob = json.load(f)["report"] + metrics.append(blob.get("ppl", "")) + metrics.append(blob.get("jga", "")) + metrics.append(blob.get("jga_noempty", "")) + metrics.append(blob.get("jga_empty", "")) + else: + metrics.append("DNE") + metrics.append("DNE") + metrics.append("DNE") + metrics.append("DNE") + + if os.path.isfile(eval_stats.replace("eval_stats", "test_user_eval_stats")): + with open( + eval_stats.replace("eval_stats", "test_user_eval_stats") + ) as f: + metrics.append(json.load(f)["report"]["ppl"]) + else: + metrics.append("DNE") + + stuff[base] = ",".join([str(x) for x in metrics]) + + ordering = [ + "orig_path", + "base", + "sys_ppl", + "token_em", + "jga", + "jga_noempty", + "jga_empty", + "user_ppl", + "space", + "sTsr", + "space", + "in_sys_ppl", + "in_jga", + "in_jga_noempty", + "in_jga_empty", + "in_user_ppl", + "test_sys_ppl", + "test_jga", + "test_jga_noempty", + "test_jga_empty", + "test_user_ppl", + ] + print(",".join(ordering)) + for key in sorted(stuff.keys()): + print(stuff[key]) + + +if __name__ == "__main__": + GetQuickEvalStatsDirScript.main() diff --git a/projects/tod_simulator/scripts/tod_distributed_uber_script.py b/projects/tod_simulator/scripts/tod_distributed_uber_script.py new file mode 100644 index 00000000000..ee40810022d --- /dev/null +++ b/projects/tod_simulator/scripts/tod_distributed_uber_script.py @@ -0,0 +1,552 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Script for training, evaluating, and running model model chats for Task Oriented Dialog. + +Note that the code below does assume running in a SLURM environment + +Use +``` +parlai train -t parlai.tasks.google_sgd.agents:StandaloneApiTeacher --standalone-api-file standalone_api_file -m parlai.core.tod.tod_agents:StandaloneApiAgent -mf -eps 5 +``` +to define STANDALONE_API_FILE_PATH +""" +from parlai.core.script import ParlaiScript, register_script +from parlai.scripts.train_model import TrainModel +from parlai.scripts.distributed_train import ( + DistributedTrain, + setup_args as train_setup_args, +) +from parlai.scripts.eval_model import EvalModel +from parlai.scripts.distributed_eval import DistributedEval +from parlai.scripts.tod_world_script import TodWorldScript +from parlai.scripts.distributed_tod_world_script import DistributedTodWorldScript +from projects.tod_simulator.scripts.get_passing_only import GetPassingOnlyScript +from projects.tod_simulator.scripts.get_al_samples_for_gsgd import ( + GetAlSamplesForGsgdScript, + setup_args as setup_al_sample_script_args, +) + +import copy +import os +from shutil import copyfile +import json +import random + +STANDALONE_API_FILE_PATH = ( + "/checkpoint//projects/user_simulator/standalone_api_data/google_sgd" +) + + +@register_script("tod_distributed_uber_script") +class TodDistributedUberScript(ParlaiScript): + def get_mm_opt(self, datatype, port, mm_prefix="", mm_suffix=""): + model_model_opt = {} + model_model_opt["exact_api_call"] = True + model_model_opt["api_schemas"] = self.opt["api_schemas"] + model_model_opt["display_examples"] = False # we'll see these later + model_model_opt["episodes_randomization_seed"] = 42 + if "nucleus" in mm_prefix or "nucleus" in mm_suffix: + model_model_opt["episodes_randomization_seed"] = random.randint( + 0, 1000000000000 + ) + model_model_opt["skip_generation"] = False + model_model_opt["batchsize"] = 32 + model_model_opt["num_episodes"] = -1 + model_model_opt["datatype"] = datatype + model_model_opt["log_keep_fields"] = "all" + # Do this so that the agents get the ride settings + model_model_opt["override"] = copy.deepcopy(model_model_opt) + + if self.opt["api_schemas"]: + grounding = "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainSingleApiSchemaAgent" + else: + grounding = "parlai.core.tod.tod_agents_and_teachers:TodEmptyApiSchemaAgent" + + model_model_opt["api_schema_grounding_model"] = grounding + model_model_opt[ + "goal_grounding_model" + ] = "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainSingleGoalAgent" + model_model_opt[ + "api_resp_model" + ] = "parlai.core.tod.tod_agents_and_teachers:TodStandaloneApiAgent" + model_model_opt["standalone_api_file"] = STANDALONE_API_FILE_PATH + + model_model_opt["system_model_file"] = self.opt["model_file"] + model_model_opt["user_model_file"] = self.opt["model_file"] + model_model_opt["distributed_world_size"] = self.opt["distributed_world_size"] + model_model_opt["ddp_backend"] = self.opt["ddp_backend"] + model_model_opt["save_format"] = "conversations" + model_model_opt["log_every_n_seconds"] = 30 + + if self.opt["custom_model_model_name"] is None: + model_model_name = "" + if "zoo:bart" in self.opt["init_model"]: + model_model_name = "BartOnlyNoApi" + else: + model_model_name = self.opt["custom_model_model_name"] + + model_model_name += "_ApiSchemas" + str(self.opt["api_schemas"]) + lr = self.opt["learningrate"] + multitask = "Multitask-" + str(self.opt["multitask_weights"][0]) + model_model_name = f"{model_model_name}_{lr}_{multitask}" + + base_path_name = os.path.dirname(self.opt["model_file"]) + + model_model_opt["report_filename"] = os.path.join( + base_path_name, mm_prefix + model_model_name + mm_suffix + ) + model_model_opt["world_logs"] = os.path.join( + base_path_name, mm_prefix + model_model_name + mm_suffix + ) + model_model_opt["port"] = port + + pretty_print = "" + + for key in model_model_opt: + if key == "override": + continue + pretty_print += "--" + key.replace("_", "-") + pretty_print += " " + pretty_print += str(model_model_opt[key]) + pretty_print += " " + self.dist_print(pretty_print) + + return model_model_opt + + @classmethod + def setup_args(cls): + parser = train_setup_args() + # Convenience + group = parser.add_argument_group("Tod distributed uber script") + group.add_argument( + "--custom-model-model-name", + default=None, + type=str, + help="model model name. Set to empty string to derive", + ) + group.add_argument( + "--existing-train-files", + nargs="*", + required=True, + help="Path to previous model-model generated (and processed) train files", + ) + group.add_argument( + "--rl-level", + type=int, + required=True, + help="Which level of RL are we running? Base pretrain is 0. Using JSON once is 1. 'existing-train-files' and 'existing-al-files' will be truncated to this length.", + ) + group.add_argument( + "--skip-al-generation", + type=bool, + default=False, + help="Skip AL geenration, ie for IN-JSON rains", + ) + group.add_argument( + "--skip-train-convo-generation", + type=bool, + default=False, + help="Skip train-convo geenration, ie for ones that don't use json", + ) + group.add_argument( + "--nucleus-mm-topp", + type=float, + nargs="*", + default=[], + help="List of coefficients of nucleus used for train data generation", + ) + group.add_argument( + "--api-schemas", + type=bool, + help="Is this a yes API Schemas or a no API Schemas model?", + ) + setup_al_sample_script_args(parser) + + return parser + + def run(self): + ###### + # This is a gigantic function that sets necessary a priori state (notably, if we are in a distributed setting), then generates a bunch of opts, then runs those opts. + ##### + + # Do this manually since we are not yet in a distributed context yet at this piece of code and cannot use distributed.py + if "SLURM_PROCID" in os.environ: + self.rank = int(os.environ["SLURM_PROCID"]) + else: + self.rank = -1 + + self.dist_print(f"Using distributed: {self.is_slurm_distributed()}") + + # Setup all of the args first, then see if any issues. + model_file = self.opt["model_file"] + base_path_name = os.path.dirname(model_file) + api_schemas = self.opt["api_schemas"] + + #################### EVAL OPTS + # Grab necessary default args and set them + eval_argparser = DistributedEval.setup_args() + eval_opt = eval_argparser.parse_args( + [ + "--task", + "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainSystemTeacher", + "--model", + self.opt["model"], + ] + ) + # ...but also make sure to use the right settings passed in via run_grid (ie distributed opts) + for key in eval_opt: + if key in self.opt: + eval_opt[key] = self.opt[key] + # Now reset opts that we need here + eval_opt[ + "task" + ] = "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainSystemTeacher" + eval_opt["model_file"] = model_file + eval_opt["api_schemas"] = api_schemas + eval_opt["batchsize"] = 32 + eval_opt["skip_generation"] = False + eval_opt["report_filename"] = os.path.join(base_path_name, "eval_stats.json") + eval_opt["datatype"] = "valid" + eval_opt["port"] = self.opt["port"] + 1 + eval_opt["distributed_world_size"] = self.opt["distributed_world_size"] + eval_opt["override"] = copy.deepcopy(eval_opt) + + ### OUT User valid eval + user_eval_opt = copy.deepcopy(eval_opt) + user_eval_opt[ + "task" + ] = "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserSimulatorTeacher" + user_eval_opt["override"] = copy.deepcopy(user_eval_opt) + user_eval_opt["report_filename"] = os.path.join( + base_path_name, "user_eval_stats.json" + ) + user_eval_opt["port"] = self.opt["port"] + 2 + + ### OUT Test eval (system + user) + test_eval_opt = copy.deepcopy(eval_opt) + test_eval_opt["datatype"] = "test" + test_eval_opt["report_filename"] = os.path.join( + base_path_name, "test_eval_stats.json" + ) + test_eval_opt["port"] = self.opt["port"] + 5 + + user_test_eval_opt = copy.deepcopy(eval_opt) + user_test_eval_opt["datatype"] = "test" + user_test_eval_opt[ + "task" + ] = "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserSimulatorTeacher" + user_test_eval_opt["report_filename"] = os.path.join( + base_path_name, "test_user_eval_stats.json" + ) + user_test_eval_opt["port"] = self.opt["port"] + 6 + + ### IN valid eval (system + user) + in_eval_opt = copy.deepcopy(eval_opt) + in_eval_opt["datatype"] = "valid" + in_eval_opt[ + "task" + ] = "parlai.tasks.google_sgd_simulation_splits.agents:InDomainSystemTeacher" + in_eval_opt["report_filename"] = os.path.join( + base_path_name, "in_eval_stats.json" + ) + in_eval_opt["port"] = self.opt["port"] + 7 + + user_in_eval_opt = copy.deepcopy(eval_opt) + user_in_eval_opt["datatype"] = "valid" + user_in_eval_opt[ + "task" + ] = "parlai.tasks.google_sgd_simulation_splits.agents:InDomainUserSimulatorTeacher" + user_in_eval_opt["report_filename"] = os.path.join( + base_path_name, "in_user_eval_stats.json" + ) + user_in_eval_opt["port"] = self.opt["port"] + 8 + + eval_model_model_opt = self.get_mm_opt( + "valid", self.opt["port"] + 3, "mm_eval_" + ) + + train_model_model_opt = self.get_mm_opt( + "train", self.opt["port"] + 4, mm_prefix="", mm_suffix="_greedy" + ) + + # At this point, everything above is distributed and everything below is non distributed... + # ...except for nucleus stuff, so we have some of its own code there. + NEXT_FREE_PORT = 9 + + #### For processing everything after train data + if ( + self.opt["skip_train_convo_generation"] + and len(self.opt["nucleus_mm_topp"]) > 0 + ): + raise RuntimeError( + "Makes no sense to do nucleus generation if we're not making convos" + ) + + mm_report_filename = ( + train_model_model_opt["report_filename"].replace(".json", "") + ".json" + ) + mm_convo_filename = f"{train_model_model_opt['world_logs']}_{train_model_model_opt['save_format']}.jsonl" + + ### Passing only script args + passing_only_parser = GetPassingOnlyScript.setup_args() + passing_only_opt = passing_only_parser.parse_args( + [ + "--convo-path", + mm_convo_filename, + "--report-path", + mm_report_filename, + "--print-to-file", + str(True), + ] + ) + + # NOTE: Following line needs to be kept in sync with the get_passing_only script + infile_base = mm_convo_filename.replace(".jsonl", "") + passing_only_stats_file = infile_base.replace( + "_conversations", "_processed_stats" + ) + passing_only_convo_file = infile_base + "_processed.jsonl" + passing_only_convo_metadata = infile_base + "_processed.metadata" + + # Args for generating nucleus... bit of a mess. + nucleus_mm_opts = [] + nucleus_passing_only_opts = [] + nucleus_processed_convo_filenames = [] + if len(self.opt["nucleus_mm_topp"]) > 0: + if self.opt["skip_train_convo_generation"]: + raise RuntimeError( + "Makes no sense to do nucleus generation if we're not making convos" + ) + for i, topp in enumerate(self.opt["nucleus_mm_topp"]): + nucleus_opt = self.get_mm_opt( + "train", 0, mm_prefix="", mm_suffix=f"-nucleus-{topp}-{i}" + ) + nucleus_opt["inference"] = "nucleus" + nucleus_opt["topp"] = topp + nucleus_opt["override"]["inference"] = "nucleus" + nucleus_opt["override"]["topp"] = topp + nucleus_opt["port"] = self.opt["port"] + NEXT_FREE_PORT + i + nucleus_mm_opts.append(nucleus_opt) + nucleus_report_filename = ( + nucleus_opt["report_filename"].replace(".json", "") + ".json" + ) + nucleus_convo_filename = ( + f"{nucleus_opt['world_logs']}_{nucleus_opt['save_format']}.jsonl" + ) + nucleus_passing_only_opt = passing_only_parser.parse_args( + [ + "--convo-path", + nucleus_convo_filename, + "--report-path", + nucleus_report_filename, + "--print-to-file", + str(True), + ] + ) + nucleus_passing_only_opts.append(nucleus_passing_only_opt) + nucleus_processed_convo = ( + nucleus_convo_filename.replace(".jsonl", "") + "_processed.jsonl" + ) + nucleus_processed_convo_filenames.append(nucleus_processed_convo) + + # Cumulative converseation args + cumulative_convo = os.path.join(base_path_name, "processed_cumulative.jsonl") + cumulative_metadata = os.path.join( + base_path_name, "processed_cumulative.metadata" + ) + + noncumulative_convo = os.path.join( + base_path_name, "processed_noncumulative.jsonl" + ) + noncumulative_metadata = os.path.join( + base_path_name, "processed_noncumulative.metadata" + ) + + ### Active learning script args + al_sample_opt = {} + al_sample_opt["find_random_al_samples"] = self.opt["find_random_al_samples"] + al_sample_opt["processed_stats_section"] = self.opt["processed_stats_section"] + al_sample_opt["num_apis_to_get"] = self.opt["num_apis_to_get"] + al_sample_opt["existing_al_files"] = self.opt["existing_al_files"] + if "datapath" in self.opt: + al_sample_opt["datapath"] = self.opt["datapath"] + elif "parlai_home" in self.opt: + al_sample_opt["datapath"] = self.opt["parlai_home"] + "/data" + + # Manually override what we need to manually override + al_sample_opt["input_processed_stats"] = passing_only_stats_file + # To save on deciding if we're going to do cumulative runs together or separately, just do both + al_sample_noncumulative_opt = copy.deepcopy(al_sample_opt) + al_sample_noncumulative_opt["cumulative_al"] = False + al_sample_noncumulative_opt["al_output_file"] = os.path.join( + base_path_name, "al_noncumulative.json" + ) + + al_sample_cumulative_opt = copy.deepcopy(al_sample_opt) + al_sample_cumulative_opt["cumulative_al"] = True + al_sample_cumulative_opt["al_output_file"] = os.path.join( + base_path_name, "al_cumulative.json" + ) + + ####### Run everything, skipping if we've already finished it + if not os.path.isfile(model_file + ".test"): + self.dist_print("RUNNING TRAIN", self.opt) + train_result = self.train_class()(self.opt).run() + self.dist_print("train result: ", train_result) + + # Required necessary things + if not os.path.isfile(eval_model_model_opt["report_filename"] + ".json"): + self.dist_print("RUNNING VALID MODEL-MODEL", eval_model_model_opt) + model_model_result = self.tod_world_class()(eval_model_model_opt).run() + self.dist_print("eval_model_model_result: ", model_model_result) + + if not os.path.isfile(eval_opt["report_filename"]): + self.dist_print("RUNNING SYSTEM EVAL", eval_opt) + eval_result = self.eval_class()(eval_opt).run() + self.dist_print("eval_result: ", eval_result) + + if not os.path.isfile(user_eval_opt["report_filename"]): + self.dist_print("RUNNING USER SIM EVAL", user_eval_opt) + eval_result = self.eval_class()(user_eval_opt).run() + self.dist_print("user_eval_result: ", eval_result) + + # All the other evals + if not os.path.isfile(in_eval_opt["report_filename"]): + self.dist_print("RUNNING SYSTEM _IN_ EVAL", in_eval_opt) + in_eval_result = self.eval_class()(in_eval_opt).run() + self.dist_print("in_eval_result: ", in_eval_result) + + if not os.path.isfile(user_in_eval_opt["report_filename"]): + self.dist_print("RUNNING USER SIM _IN_ EVAL", user_in_eval_opt) + in_eval_result = self.eval_class()(user_in_eval_opt).run() + self.dist_print("user_in_eval_result: ", in_eval_result) + + if not os.path.isfile(test_eval_opt["report_filename"]): + self.dist_print("RUNNING SYSTEM TEST EVAL", test_eval_opt) + test_eval_result = self.eval_class()(test_eval_opt).run() + self.dist_print("test_eval_result: ", test_eval_result) + + if not os.path.isfile(user_test_eval_opt["report_filename"]): + self.dist_print("RUNNING USER SIM TEST EVAL", user_test_eval_opt) + test_eval_result = self.eval_class()(user_test_eval_opt).run() + self.dist_print("user_test_eval_result: ", test_eval_result) + + # convo generation + if api_schemas and not self.opt["skip_train_convo_generation"]: + if not os.path.isfile(train_model_model_opt["report_filename"] + ".json"): + self.dist_print("RUNNING TRAIN MODEL-MODEL", train_model_model_opt) + model_model_result = self.tod_world_class()(train_model_model_opt).run() + self.dist_print("train_model_model_result: ", model_model_result) + + for nucleus_mm_opt in nucleus_mm_opts: + if os.path.isfile( + nucleus_mm_opt["report_filename"].replace(".json", "") + ".json" + ): + continue + self.dist_print( + "RUNNING NUCLEUS MODEL MODEL FOR TOPP: ", nucleus_mm_opt["topp"] + ) + model_model_result = self.tod_world_class()(nucleus_mm_opt).run() + self.dist_print("nucleus_mm_result: ", model_model_result) + + if self.is_main_worker(): + self.dist_print("DOING GETTING PASSING ONLY") + GetPassingOnlyScript(passing_only_opt).run() + for nucleus_passing_only_opt in nucleus_passing_only_opts: + GetPassingOnlyScript(nucleus_passing_only_opt).run() + + if len(nucleus_processed_convo_filenames) > 0: + concatinate_me = [ + passing_only_convo_file + ] + nucleus_processed_convo_filenames + concat_conversations( + concatinate_me, + destination=noncumulative_convo, + source_metadata=passing_only_convo_metadata, + destination_metadata=noncumulative_metadata, + ) + + concatinate_me = ( + self.opt["existing_train_files"] + nucleus_processed_convo_filenames + ) + if len(concatinate_me) > 0: + concatinate_me = [passing_only_convo_file] + concatinate_me + concat_conversations( + concatinate_me, + destination=cumulative_convo, + source_metadata=passing_only_convo_metadata, + destination_metadata=cumulative_metadata, + ) + + # Al generation + if api_schemas and not self.opt["skip_al_generation"] and self.is_main_worker(): + self.dist_print( + "Getting AL samples (printing only cumulative opts)", + al_sample_cumulative_opt, + ) + GetAlSamplesForGsgdScript(al_sample_noncumulative_opt).run() + GetAlSamplesForGsgdScript(al_sample_cumulative_opt).run() + + def is_slurm_distributed(self): + return self.rank < 0 + + def is_main_worker(self): + return self.rank <= 0 + + def dist_print(self, *args): + # distributed aware print + if self.is_main_worker(): + print(*args) + + def train_class(self): + if self.is_slurm_distributed(): + return TrainModel + return DistributedTrain + + def eval_class(self): + if self.is_slurm_distributed(): + return EvalModel + return DistributedEval + + def tod_world_class(self): + if self.is_slurm_distributed(): + return TodWorldScript + return DistributedTodWorldScript + + +def concat_conversations( + concatinate_me, destination, source_metadata=None, destination_metadata=None +): + if source_metadata: + copyfile(source_metadata, destination_metadata) + seen_lines = set() + file_stats = {} + with open(destination, "w+") as destfile: + for source in concatinate_me: + print(source) + new_in_file = 0 + dup_in_file = 0 + with open(source) as infile: + for line in infile: + convo_raw = json.loads(line)["dialog"] + convo_formatted = [x["text"] for round in convo_raw for x in round] + convo = json.dumps(convo_formatted) + if convo not in seen_lines: + destfile.write(line) + seen_lines.add(convo) + new_in_file += 1 + else: + dup_in_file += 1 + file_stats[source] = {"new": new_in_file, "dup": dup_in_file} + print("Concat + filter dupe stats") + print(json.dumps(file_stats, indent=4)) + with open(destination.replace(".jsonl", "_concat_stats.json"), "w+") as f: + json.dump(file_stats, f, indent=4) + + +if __name__ == "__main__": + TodDistributedUberScript.main() diff --git a/projects/tod_simulator/sweeps/pretrain_all.py b/projects/tod_simulator/sweeps/pretrain_all.py new file mode 100644 index 00000000000..655a6704af0 --- /dev/null +++ b/projects/tod_simulator/sweeps/pretrain_all.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Note: This sweep file is presented as an example of the pretraining used. Note that it relies on internal versions of these datasets and uses internal sweep scripts, so it will not work in practice. +""" + + +from parlai_internal.projects.param_sweep_utils.param_sweep import run_grid +import time +import os + + +SCRIPT_NAME = os.path.basename(__file__).replace(".py", "") +TODAY = format(time.asctime().replace(":", "-").replace(" ", "_")[:-14]) +SWEEP_NAME = f"{SCRIPT_NAME}{TODAY}" + +here_path = os.path.realpath(__file__).replace(".py", "") +projects = here_path[here_path.find("/projects") :] +SAVEROOT = "/checkpoint/" + projects + TODAY + +HOURS = 23 +GPUS = 8 + +TEACHERS_NO_GSGD_GOOD = [ + "fb:taskmaster1:SystemTeacher", + "parlai_fb.tasks.taskmaster2.formatted_agents:SystemTeacher", + "fb:taskmaster3:SystemTeacher", + "fb:msr_e2e:SystemTeacher", + "parlai_fb.tasks.taskmaster2.formatted_agents:UserSimulatorTeacher", + "fb:taskmaster3:UserSimulatorTeacher", + "fb:msr_e2e:UserSimulatorTeacher", + "fb:multiwoz_tod:UserSimulatorTeacher", + "fb:multidogo:UserSimulatorTeacher", +] + +TEACHERS_NO_GSGD_FUNKY = [ + "fb:metalwoz_internal:SystemTeacher", # also without the STANDARD_ whatevers, so could be interesting. + "fb:multiwoz_tod:SystemTeacher", # API responses makes no sense + "fb:multidogo:SystemTeacher", # API responses make no sense + "fb:metalwoz_internal:UserSimulatorTeacher", # also without the STANDARD_ whatevers, so could be interesting. + "fb:taskmaster1:UserSimulatorTeacher", # no goals +] + +TEACHER_GSGD = [ + "parlai_fb.tasks.google_sgd_rl_splits.agents:InDomainUserSimulatorTeacher", + "parlai_fb.tasks.google_sgd_rl_splits.agents:InDomainSystemTeacher", +] + +ALL_TEACHERS = TEACHER_GSGD + TEACHERS_NO_GSGD_GOOD + TEACHERS_NO_GSGD_FUNKY +ALL_TEACHERS_NO_GSGD = TEACHERS_NO_GSGD_GOOD + TEACHERS_NO_GSGD_FUNKY +ALL_GOOD_TEACHERS = TEACHER_GSGD + TEACHERS_NO_GSGD_GOOD + +TEACHER_OPTIONS = [ + ",".join(ALL_TEACHERS), + # ",".join(ALL_TEACHERS_NO_GSGD), + ",".join(ALL_GOOD_TEACHERS), + # ",".join(TEACHERS_NO_GSGD_GOOD), + ",".join(TEACHER_GSGD), +] + +print(TEACHER_OPTIONS[0]) + +# Define param grid +grid = { + # dataset params + "-t": TEACHER_OPTIONS, + "--api-descriptions": [True, False], + # other params + "--model": ["parlai_fb.agents.bart.r3f:R3fFirstTurnHistoryRepeatAgent"], + "--fp16": [True], + "--label-truncate": [512], + "--log-every-n-secs": [30], + "--lr-scheduler": ["invsqrt"], + "--max-lr-steps": [-1], + "--max-train-steps": [-1], + "--optimizer": ["adam"], + "--save-after-valid": [True], + "--text-truncate": [512], + "--warmup-updates": [1000], + "--fp16-impl": ["mem_efficient"], + "--gradient-clip": [0.1], + "--skip-generation": [True], + "-vp": [8], + "--max-train-time": [HOURS * 60 * 60 - 30 * 60], + "--load-from-checkpoint": ["true"], + "-vmt": ["token_em -vmm max"], + "--multitask-weights": ["stochastic"], + # Sweeping params + "--batchsize": [4], + "--update-freq": [8], + "-lr": [1e-4], + "-vstep": [1000], +} + + +if __name__ == "__main__": + run_grid( + grid=grid, + name_keys={}, + sweep_name=SWEEP_NAME, + saveroot=SAVEROOT, + prefix="python -u -m parlai.scripts.distributed_train", + partition="learnlab", + jobtime=f"{HOURS}:00:00", + gpus=8, + nodes=1, + create_model_file=True, + requeue=True, + include_job_id=False, + volta32=True, + hashname=True, + mem_gb=400, + email_updates=True, + wandb=True, + ) diff --git a/projects/tod_simulator/tod_world_configs/__init__.py b/projects/tod_simulator/tod_world_configs/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/projects/tod_simulator/tod_world_configs/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/projects/tod_simulator/tod_world_configs/all_human.json b/projects/tod_simulator/tod_world_configs/all_human.json new file mode 100644 index 00000000000..6c66c52b704 --- /dev/null +++ b/projects/tod_simulator/tod_world_configs/all_human.json @@ -0,0 +1,7 @@ +{ +"api_schema_grounding_model": "local_human", +"goal_grounding_model": "local_human", +"user_model": "local_human", +"system_model": "local_human", +"api_resp_model": "local_human" +} diff --git a/projects/tod_simulator/tod_world_configs/google_sgd_simulation_dump_data.json b/projects/tod_simulator/tod_world_configs/google_sgd_simulation_dump_data.json new file mode 100644 index 00000000000..db8f762d024 --- /dev/null +++ b/projects/tod_simulator/tod_world_configs/google_sgd_simulation_dump_data.json @@ -0,0 +1,12 @@ +{ +"api_schema_grounding_model": "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiSchemaAgent", +"goal_grounding_model": "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainGoalAgent", +"user_model": "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserUttAgent", +"user_model_file": "", +"system_model":"parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiCallAndSysUttAgent", +"system_model_file": "", +"api_resp_model": "parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiResponseAgent", +"display_examples": "True", +"datatype": "valid", +"num_episodes": -1 +} diff --git a/projects/tod_simulator/world_metrics/__init__.py b/projects/tod_simulator/world_metrics/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/projects/tod_simulator/world_metrics/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.