Slow data loading of large arrays from sharded dataset #372

Open
rainx0r opened this issue Mar 30, 2024 · 1 comment

rainx0r commented Mar 30, 2024

Hi. First off, I'm unsure whether I should post this issue here, in the array_record repo, or in the tensorflow_datasets repo. My goal is ultimately to use Grain in my project, because I really like the idea of deterministic data loading and easy checkpointing of the iterator state, shuffling, etc., and I'm obviously using JAX.

The problem is that I can't seem to load my data quickly from ArrayRecords with Grain. Using TFRecords with TFDS seems to be a lot faster, which isn't what I'd expect. I suspect this might be related to my dataset consisting of large arrays.

Data

My dataset has around 50,000 samples, each a NumPy array of shape (100, 500, 99) with dtype float32. Currently the dataset is stored as 50,000 .npy files. I'm testing with a subset of 5,000 of them.
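For scale, the sizes implied by these numbers can be worked out directly. This is a quick sketch that counts only the raw float32 payload; serialisation may add some overhead on top.

```python
import math

# Dataset facts from the issue: float32 arrays of shape (100, 500, 99).
SAMPLE_SHAPE = (100, 500, 99)
BYTES_PER_FLOAT32 = 4
ARRAYS_PER_SHARD = 50
SUBSET_SAMPLES = 5_000
FULL_SAMPLES = 50_000

# Raw payload only (no serialisation framing).
bytes_per_sample = math.prod(SAMPLE_SHAPE) * BYTES_PER_FLOAT32
bytes_per_shard = bytes_per_sample * ARRAYS_PER_SHARD
subset_shards = SUBSET_SAMPLES // ARRAYS_PER_SHARD
full_dataset_bytes = bytes_per_sample * FULL_SAMPLES

print(f"{bytes_per_sample / 1e6:.1f} MB per sample")               # 19.8 MB per sample
print(f"{bytes_per_shard / 1e9:.2f} GB per shard")                  # 0.99 GB per shard
print(f"{subset_shards} shards for the 5,000-sample subset")        # 100 shards for the 5,000-sample subset
print(f"{full_dataset_bytes / 1e9:.0f} GB for all 50,000 samples")  # 990 GB for all 50,000 samples
```

So 50 arrays per shard does land at roughly 1 GB, and each example is about 19.8 MB: 140 examples per second corresponds to roughly 2.8 GB/s of payload, while 3 examples per second is about 59 MB/s.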

Conversion to ArrayRecord

```python
import multiprocessing

import numpy as np
import tensorflow_datasets as tfds
from array_record.python import array_record_module as array_record
from tqdm.contrib.concurrent import process_map

...

# Arbitrarily chose 50 arrays per ArrayRecord file because I read online that ~1 GB is a reasonable shard size.
num_arrays_shard = 50
filenames = np.array(sorted(DATA_DIR.iterdir()))  # .npy filenames, sorted so shard contents are deterministic
num_shards = len(filenames) // num_arrays_shard  # 100 shards for my 5,000-file subset
group_size = 1

features = tfds.features.FeaturesDict({
    "arr": tfds.features.Tensor(shape=(100, 500, 99), dtype=np.float32)
})

def _write_arrayrecord_shard(shard: int):
    # Conventional "-of-NNNNN" suffix uses the total shard count.
    writer = array_record.ArrayRecordWriter(
        f"{GRAIN_DATA_DIR}/data.array_record-{shard:05d}-of-{num_shards:05d}",
        f"group_size:{group_size}",
    )
    start = shard * num_arrays_shard
    for fname in filenames[start : start + num_arrays_shard]:
        arr = np.load(fname).astype(np.float32)
        tf_example = features.serialize_example({"arr": arr})  # serialise via tfds.features
        writer.write(tf_example)
    writer.close()

_ = process_map(_write_arrayrecord_shard, range(num_shards), max_workers=multiprocessing.cpu_count())
```
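The shard count above uses floor division, which is exact for the 5,000-file subset but would silently skip the last `len(filenames) % num_arrays_shard` files whenever the file count is not a multiple of 50. A small sketch of ceiling-division sharding that always covers every file (`shard_ranges` is a hypothetical helper, not part of array_record or Grain):

```python
from __future__ import annotations


def shard_ranges(num_files: int, per_shard: int) -> list[tuple[int, int]]:
    """Return one (start, stop) slice per shard so that every file is covered.

    The last shard is smaller when num_files is not a multiple of per_shard.
    """
    if per_shard <= 0:
        raise ValueError("per_shard must be positive")
    num_shards = -(-num_files // per_shard)  # ceil(num_files / per_shard)
    return [
        (i * per_shard, min((i + 1) * per_shard, num_files))
        for i in range(num_shards)
    ]


ranges = shard_ranges(5_020, 50)
print(len(ranges), ranges[-1])  # 101 (5000, 5020)
```

Each worker would then write `filenames[start:stop]` for its own range, and the filename suffix would use `len(ranges)` as the total.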

Loading with grain

```python
import dataclasses

import grain.python as grain

# Sort the shard paths: Path.iterdir() yields entries in arbitrary order.
ds = grain.ArrayRecordDataSource([str(f) for f in sorted(GRAIN_DATA_DIR.iterdir())])

@dataclasses.dataclass
class ParseFeatures(grain.MapTransform):
    def map(self, _features):
        # Decode the serialised tfds example back into a dict of NumPy arrays.
        return features.deserialize_example_np(_features)

sampler = grain.SequentialSampler(num_records=len(ds), shard_options=grain.NoSharding())
loader = grain.DataLoader(
    data_source=ds,
    operations=[ParseFeatures(), grain.Batch(5)],
    sampler=sampler,
    worker_buffer_size=1000,
)
```
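`ArrayRecordDataSource` takes the list of shard paths and exposes them as a single indexable sequence, so the sampler only deals in global record indices. The lookup from a global index to a shard and an offset within it can be sketched as follows. This is a conceptual illustration assuming each shard's record count is known, not Grain's actual implementation, and `ShardIndex` is a hypothetical name:

```python
import bisect
from itertools import accumulate


class ShardIndex:
    """Map a global record index to (shard_index, offset_in_shard)."""

    def __init__(self, records_per_shard):
        # Cumulative end positions, e.g. [50, 100, 150] for three 50-record shards.
        self._ends = list(accumulate(records_per_shard))

    def __len__(self):
        return self._ends[-1] if self._ends else 0

    def locate(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        # First shard whose end position is strictly greater than index.
        shard = bisect.bisect_right(self._ends, index)
        start = self._ends[shard - 1] if shard > 0 else 0
        return shard, index - start
```

With 100 shards of 50 records each, a sequential sampler visits indices 0, 1, 2, ... in order, so the first batch of 5 comes from shard 0 at offsets 0 to 4.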

The problem

I benchmark the resulting loader with tfds.benchmark(loader, batch_size=5) and get about 3 examples per second, which seems really slow. Manually looping through the DataLoader and timing it is no better, so I don't think this is a problem with the benchmark.
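For the manual timing loop, a sketch like the following keeps the first batch (which includes start-up costs such as opening files) out of the steady-state measurement. `measure_throughput` is a hypothetical helper; it accepts any iterable of batches, so it can time a Grain `DataLoader` and a `tf.data` NumPy iterator the same way.

```python
import itertools
import time
from typing import Iterable


def measure_throughput(batches: Iterable, batch_size: int, num_batches: int, warmup: int = 1) -> float:
    """Return examples per second over up to `num_batches` batches after `warmup` batches."""
    it = iter(batches)
    # Consume warm-up batches without timing them.
    for _ in itertools.islice(it, warmup):
        pass
    start = time.perf_counter()
    timed = sum(1 for _ in itertools.islice(it, num_batches))
    elapsed = time.perf_counter() - start
    if timed == 0:
        raise ValueError("iterable was exhausted before any timed batch")
    return timed * batch_size / elapsed
```

For example, `measure_throughput(loader, batch_size=5, num_batches=100)`; multiplying the result by 19.8 MB gives the corresponding read bandwidth.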

Reading each individual .npy file from the filesystem with numpy.load yields about 140 examples per second.

In an otherwise identical setup where I use tf.io.TFRecordWriter in the conversion step, load everything as a tf.data Dataset, and then benchmark it as follows:

```python
ds = ds.batch(5, num_parallel_calls=5)
ds = ds.as_numpy_iterator()
tfds.benchmark(ds, num_iter=990, batch_size=5)
```

then I get roughly 130 examples per second, which isn't great, but it's at least close to the naive approach of reading directly from disk.

Without the conversion to NumPy (i.e. without deserialisation) it's faster, but still not as fast as I'd expect: I get around 53 examples per second without the ParseFeatures() operation. I also tried setting worker_count= in the DataLoader, but I get the error "Processing Failed. Shutting down.", though that is probably worth its own issue.

TLDR

I'm trying to load a few thousand large arrays (each float32, shape (100, 500, 99)) from ArrayRecord files with Grain, but it's slow: slower than TFRecords with tf.data, and slower than loading the .npy files from disk directly.

Reproduction notebook here

Am I missing the point of Grain, or is it just not a good fit for my use case? Or are some of my settings wrong (shard size, buffer size, serialisation strategy)?

I'm using grain_nightly==0.0.6 and array_record==0.5.0 on Linux, with a 1 TB NVMe SSD, a Ryzen 9 7950X CPU, and 64 GB of DDR5 RAM.


rainx0r commented Mar 31, 2024

Doing a bit more testing, it seems my bottleneck might be the tfds.features module used for serialisation. Writing ArrayRecords with a simple

```python
writer.write(_arr.tobytes())
```

and deserialising with

```python
@dataclasses.dataclass
class ParseFeatures(grain.MapTransform):
    def map(self, record):
        # Reinterpret the raw bytes as float32 and restore the original shape.
        return np.reshape(np.frombuffer(record, dtype=np.float32), (100, 500, 99))
```

gives me around 130-140 examples per second on a single worker.

Is there a recommended way to serialise / deserialise data for use with grain?
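One option that stays close to raw bytes while making each record self-describing is to store the array in NumPy's `.npy` format in memory: its header records dtype and shape, so the reader no longer hardcodes `(100, 500, 99)`. A minimal sketch assuming only NumPy; `encode_array` and `decode_array` are hypothetical names, not a Grain or array_record API.

```python
import io

import numpy as np


def encode_array(arr: np.ndarray) -> bytes:
    """Serialise an array to .npy bytes (header with dtype and shape, then raw data)."""
    buf = io.BytesIO()
    np.save(buf, arr, allow_pickle=False)  # reject object arrays, which would need pickle
    return buf.getvalue()


def decode_array(record: bytes) -> np.ndarray:
    """Parse .npy bytes back into an array; dtype and shape come from the header."""
    return np.load(io.BytesIO(record), allow_pickle=False)
```

In this scheme the conversion step would call `writer.write(encode_array(arr))` and `ParseFeatures.map` would return `decode_array(record)`. Unlike `np.frombuffer`, which returns a read-only view over the record, `np.load` copies the payload into a new array, so whether the extra copy of roughly 20 MB per record matters would need measuring.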
