Skip to content

Commit

Permalink
address pr comments - tmpdir -> tmp_path, cleaner code iteration, remove debugging prints.
Browse files Browse the repository at this point in the history
  • Loading branch information
bgenchel committed Jul 31, 2024
1 parent bea0ba7 commit 9f5067a
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions tests/data/test_maestro.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ def create_mock_midi(output_fpath: str) -> None:
logging.info(f"Mock MIDI file '{output_fpath}' created successfully.")


def test_maestro_to_tf_example(tmpdir: str) -> None:
mock_maestro_home = pathlib.Path(tmpdir) / "maestro"
def test_maestro_to_tf_example(tmp_path: pathlib.Path) -> None:
mock_maestro_home = tmp_path / "maestro"
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)

create_mock_wav(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".wav")), 3)
create_mock_midi(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".midi")))

output_dir = pathlib.Path(tmpdir) / "outputs"
output_dir = tmp_path / "outputs"
output_dir.mkdir(parents=True, exist_ok=True)

input_data: List[str] = [TRAIN_TRACK_ID]
Expand All @@ -106,25 +106,24 @@ def test_maestro_to_tf_example(tmpdir: str) -> None:
| "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(str(output_dir)))
)

assert len(os.listdir(str(output_dir))) == 1
print("PASSED THIS POINT")
assert os.path.splitext(os.listdir(str(output_dir))[0])[-1] == ".tfrecord"
print("PASSED THIS OTHER POINT")
listdir = os.listdir(str(output_dir))
assert len(listdir) == 1
assert os.path.splitext(listdir[0])[-1] == ".tfrecord"
with open(os.path.join(str(output_dir), os.listdir(str(output_dir))[0]), "rb") as fp:
data = fp.read()
assert len(data) != 0


def test_maestro_invalid_tracks(tmpdir: str) -> None:
mock_maestro_home = pathlib.Path(tmpdir) / "maestro"
def test_maestro_invalid_tracks(tmp_path: pathlib.Path) -> None:
mock_maestro_home = tmp_path / "maestro"
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)

create_mock_wav(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".wav")), 3)
create_mock_wav(str(mock_maestro_ext / (VALID_TRACK_ID.split("/")[1] + ".wav")), 3)
create_mock_wav(str(mock_maestro_ext / (TEST_TRACK_ID.split("/")[1] + ".wav")), 3)

input_data = [(TRAIN_TRACK_ID, "train"), (VALID_TRACK_ID, "validation"), (TEST_TRACK_ID, "test")]

for track_id, _ in input_data:
create_mock_wav(str(mock_maestro_ext / (track_id.split("/")[1] + ".wav")), 3)

split_labels = set([e[1] for e in input_data])
with TestPipeline() as p:
splits = (
Expand All @@ -137,23 +136,24 @@ def test_maestro_invalid_tracks(tmpdir: str) -> None:
(
getattr(splits, split)
| f"Write {split} to text"
>> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="")
>> beam.io.WriteToText(str(tmp_path / f"output_{split}.txt"), shard_name_template="")
)

for track_id, split in input_data:
with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp:
with open(str(tmp_path / f"output_{split}.txt"), "r") as fp:
assert fp.read().strip() == track_id


def test_maestro_invalid_tracks_over_15_min(tmpdir: str) -> None:
def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path) -> None:
"""
The track id used here is a real track id in maestro, and it is part of the train split, but we mock the data so as
not to store a large file in git, hence the variable name.
"""

mock_maestro_home = pathlib.Path(tmpdir) / "maestro"
mock_maestro_home = tmp_path / "maestro"
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)

mock_fpath = mock_maestro_ext / (GT_15M_TRACK_ID.split("/")[1] + ".wav")
create_mock_wav(str(mock_fpath), 16)

Expand All @@ -170,11 +170,11 @@ def test_maestro_invalid_tracks_over_15_min(tmpdir: str) -> None:
(
getattr(splits, split)
| f"Write {split} to text"
>> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="")
>> beam.io.WriteToText(str(tmp_path / f"output_{split}.txt"), shard_name_template="")
)

for _, split in input_data:
with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp:
with open(str(tmp_path / f"output_{split}.txt"), "r") as fp:
assert fp.read().strip() == ""


Expand Down

0 comments on commit 9f5067a

Please sign in to comment.