Skip to content

Commit

Permalink
fixes for assigning vlen arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
jreadey committed May 6, 2024
1 parent e7b5ef7 commit 69640c1
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 48 deletions.
4 changes: 3 additions & 1 deletion h5pyd/_apps/utillib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,7 @@ def write_dataset(src, tgt, ctx):
logging.debug(f"src dtype: {src.dtype}")
logging.debug(f"des dtype: {tgt.dtype}")

empty_arr = None
for src_s in it:
logging.debug(f"src selection: {src_s}")
if rank == 1 and isinstance(src_s, slice):
Expand Down Expand Up @@ -1377,7 +1378,8 @@ def write_dataset(src, tgt, ctx):

arr = src[src_s]
# don't write arr if it's all zeros (or the fillvalue if defined)
empty_arr = np.zeros(arr.shape, dtype=arr.dtype)
if empty_arr is None or empty_arr.shape != arr.shape:
empty_arr = np.zeros(arr.shape, dtype=arr.dtype)
if fillvalue:
empty_arr.fill(fillvalue)
if np.array_equal(arr, empty_arr):
Expand Down
46 changes: 22 additions & 24 deletions h5pyd/_hl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def getElementSize(e, dt):
"""
Get number of bytes needed for given element as a bytestream
"""
# print(f"getElementSize - e: {e} dt: {dt} itemsize: {dt.itemsize}")

if len(dt) > 1:
count = 0
for name in dt.names:
Expand All @@ -339,12 +339,8 @@ def getElementSize(e, dt):
else:
# variable length element
vlen = dt.metadata["vlen"]
if isinstance(e, int):
if e == 0:
count = 4 # non-initialized element
else:
raise ValueError(f"Unexpected value: {e}")
elif isinstance(e, bytes):

if isinstance(e, bytes):
count = len(e) + 4
elif isinstance(e, str):
count = len(e.encode('utf-8')) + 4
Expand All @@ -353,7 +349,7 @@ def getElementSize(e, dt):
if e.dtype.kind != 'O':
count = e.dtype.itemsize * nElements
else:
count = nElements * vlen.itemsize # tbd - special case for strings?
count = nElements * vlen.itemsize
count += 4 # byte count
elif isinstance(e, list) or isinstance(e, tuple):
if not e:
Expand All @@ -362,8 +358,12 @@ def getElementSize(e, dt):
else:
count = len(e) * vlen.itemsize + 4 # +4 for byte count
else:
# uninitialized element
if e and not np.isnan(e):
raise ValueError(f"Unexpected value: {e}")
else:
count = 4 # non-initialized element

raise TypeError(f"unexpected type: {type(e)}")
return count


Expand All @@ -372,7 +372,7 @@ def getByteArraySize(arr):
Get number of bytes needed to store given numpy array as a bytestream
"""
if not isVlen(arr.dtype) and arr.dtype.kind != 'O':
print("not isVlen")
# not vlen just return itemsize * number of elements
return arr.itemsize * np.prod(arr.shape)
nElements = int(np.prod(arr.shape))
# reshape to 1d for easier iteration
Expand All @@ -381,6 +381,7 @@ def getByteArraySize(arr):
count = 0
for e in arr1d:
count += getElementSize(e, dt)

return count


Expand Down Expand Up @@ -418,13 +419,7 @@ def copyElement(e, dt, buffer, offset, vlen=None):
offset = copyBuffer(e_buf, buffer, offset)
else:
# variable length element
if isinstance(e, int):
if e == 0:
# write 4-byte integer 0 to buffer
offset = copyBuffer(b'\x00\x00\x00\x00', buffer, offset)
else:
raise ValueError("Unexpected value: {}".format(e))
elif isinstance(e, bytes):
if isinstance(e, bytes):
count = np.int32(len(e))
offset = copyBuffer(count.tobytes(), buffer, offset)
offset = copyBuffer(e, buffer, offset)
Expand All @@ -451,9 +446,6 @@ def copyElement(e, dt, buffer, offset, vlen=None):
arr = np.asarray(arr1d, dtype=vlen)
offset = copyBuffer(arr.tobytes(), buffer, offset)

# for item in arr1d:
# offset = copyElement(item, dt, buffer, offset)

elif isinstance(e, list) or isinstance(e, tuple):
count = np.int32(len(e) * vlen.itemsize)
offset = copyBuffer(count.tobytes(), buffer, offset)
Expand All @@ -464,7 +456,12 @@ def copyElement(e, dt, buffer, offset, vlen=None):
offset = copyBuffer(arr.tobytes(), buffer, offset)

else:
raise TypeError("unexpected type: {}".format(type(e)))
# uninitialized variable length element
if e and not np.isnan(e):
raise ValueError(f"Unexpected value: {e}")
else:
# write 4-byte integer 0 to buffer
offset = copyBuffer(b'\x00\x00\x00\x00', buffer, offset)
# print("buffer: {}".format(buffer))
return offset

Expand Down Expand Up @@ -551,11 +548,12 @@ def arrayToBytes(arr, vlen=None):
# can just return normal numpy bytestream
return arr.tobytes()

nSize = getByteArraySize(arr)
buffer = bytearray(nSize)
offset = 0
nElements = int(np.prod(arr.shape))
arr1d = arr.reshape((nElements,))
nSize = getByteArraySize(arr1d)
buffer = bytearray(nSize)
offset = 0

for e in arr1d:
offset = copyElement(e, arr1d.dtype, buffer, offset, vlen=vlen)
return buffer
Expand Down
28 changes: 14 additions & 14 deletions h5pyd/_hl/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,15 +1332,15 @@ def __setitem__(self, args, val):
match.
"""
self.log.info(f"Dataset __setitem__, args: {args}")

use_base64 = True # may need to set this to false below for some types

args = args if isinstance(args, tuple) else (args,)

# get the val dtype if we're passed a numpy array
try:
self.log.debug(
f"val dtype: {val.dtype}, shape: {val.shape} metadata: {val.dtype.metadata}"
)
msg = f"val dtype: {val.dtype}, shape: {val.shape} metadata: {val.dtype.metadata}"
self.log.debug(msg)
if numpy.prod(val.shape) == 0:
self.log.info("no elements in numpy array, skipping write")
except AttributeError:
Expand Down Expand Up @@ -1370,20 +1370,22 @@ def __setitem__(self, args, val):
# side. However, for compound literals this is unavoidable.
# For h5pyd, do extra check and convert type on client side for efficiency
vlen = check_dtype(vlen=self.dtype)
if vlen is not None and vlen not in (bytes, str):

if not isinstance(val, numpy.ndarray) and vlen is not None and vlen not in (bytes, str):
try:
val = numpy.asarray(val, dtype=vlen)

except ValueError:
except ValueError as ve:
self.log.debug(f"asarray ValueError: {ve}")
try:
val = numpy.array(
[numpy.array(x, dtype=self.dtype) for x in val],
dtype=self.dtype,
)
except ValueError as e:
self.log.debug(
f"valueError converting value element by element: {e} "
)
msg = f"ValueError converting value element by element: {e}"
self.log.debug(msg)

if vlen == val.dtype:
if val.ndim > 1:
tmp = numpy.empty(shape=val.shape[:-1], dtype=self.dtype)
Expand All @@ -1405,9 +1407,9 @@ def __setitem__(self, args, val):
isinstance(val, complex) or getattr(getattr(val, "dtype", None), "kind", None) == "c"
):
if self.dtype.kind != "V" or self.dtype.names != ("r", "i"):
raise TypeError(
f"Wrong dataset dtype for complex number values: {self.dtype.fields}"
)
msg = f"Wrong dataset dtype for complex number values: {self.dtype.fields}"
raise TypeError(msg)

if isinstance(val, complex):
val = numpy.asarray(val, dtype=type(val))
tmp = numpy.empty(shape=val.shape, dtype=self.dtype)
Expand Down Expand Up @@ -1447,6 +1449,7 @@ def __setitem__(self, args, val):
tmp[...] = val[...]
val = tmp
else:
self.log.debug(f"asarray for {self.dtype}")
val = numpy.asarray(val, order="C", dtype=self.dtype)

# Check for array dtype compatibility and convert
Expand Down Expand Up @@ -1474,9 +1477,6 @@ def __setitem__(self, args, val):
raise ValueError(f"Illegal slicing argument (fields {mismatch} not in dataset type)")

# Use mtype derived from array (let DatasetID.write figure it out)
else:
mshape = val.shape
# mtype = None

mshape = val.shape
self.log.debug(f"mshape: {mshape}")
Expand Down
92 changes: 83 additions & 9 deletions test/hl/test_vlentype.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_create_vlen_dset(self):
g1_3 = g1.create_group('g1_3')
g1_3.attrs["name"] = 'g1_3'

# create a dataset that is a VLEN int32
# create a dataset that is a VLEN uint16
dtvlen = h5py.special_dtype(vlen=np.dtype('uint16'))

dset1 = f.create_dataset("dset1", shape=(2,), dtype=dtvlen)
Expand Down Expand Up @@ -260,9 +260,7 @@ def test_create_vlen_2d_dset(self):
def test_variable_len_str_attr(self):
filename = self.getFileName("variable_len_str_dset")
print("filename:", filename)
if config.get("use_h5py") and False:
# TBD - skipping as this core dumps in travis for some reason
return

f = h5py.File(filename, "w")

words = (b"one", b"two", b"three", b"four", b"five", b"six", b"seven", b"eight", b"nine", b"ten")
Expand All @@ -288,9 +286,7 @@ def test_variable_len_str_attr(self):
def test_variable_len_str_dset(self):
filename = self.getFileName("variable_len_str_dset")
print("filename:", filename)
if config.get("use_h5py") and False:
# TBD - skipping as this core dumps in travis for some reason
return

f = h5py.File(filename, "w")

dims = (10,)
Expand Down Expand Up @@ -323,6 +319,84 @@ def test_variable_len_str_dset(self):

f.close()

def test_variable_len_float_dset(self):
    """Round-trip variable-length float data through a 1-D dataset.

    Creates a two-element dataset whose dtype is a vlen of float, checks
    the dataset properties and the initial (empty) element values, writes
    two float64 arrays, reads them back (whole dataset and single element),
    and finally overwrites one element from a plain Python list.
    """
    filename = self.getFileName("variable_len_float_dset")
    print("filename:", filename)

    f = h5py.File(filename, "w")
    try:
        dims = (2,)
        dtvlen = h5py.special_dtype(vlen=float)
        dset = f.create_dataset('variable_len_float_dset', dims, dtype=dtvlen)

        # id is str for HSDS, int for h5py
        is_hsds = isinstance(dset.id.id, str)

        self.assertEqual(dset.name, "/variable_len_float_dset")
        self.assertIsInstance(dset.shape, tuple)
        self.assertEqual(len(dset.shape), 1)
        self.assertEqual(dset.shape[0], 2)
        self.assertEqual(str(dset.dtype), 'object')
        self.assertIsInstance(dset.maxshape, tuple)
        self.assertEqual(len(dset.maxshape), 1)
        self.assertEqual(dset.maxshape[0], 2)
        self.assertFalse(dset.fillvalue)  # will be 0 for HSDS, None for h5py

        # elements that were never written read back as empty arrays
        ret_val = dset[...]
        self.assertIsInstance(ret_val, np.ndarray)
        self.assertEqual(len(ret_val), 2)
        e0 = ret_val[0]
        self.assertIsInstance(e0, np.ndarray)
        self.assertEqual(e0.shape, (0,))

        e0 = np.array([1.1, 2.2, 3.3], dtype=np.float64)
        e1 = np.array([1.9, 2.8, 3.7], dtype=np.float64)

        data = np.array([e0, e1], dtype=dtvlen)
        try:
            # This will fail on HSDS because data is a ndarray of shape (2,3) of floats
            dset[...] = data
            if is_hsds:
                self.fail("expected ValueError writing a (2,3) array to a (2,) vlen dataset")
        except ValueError:
            pass  # expected

        # build the value element by element so it stays a (2,) array of arrays
        data = np.zeros((2,), dtype=dtvlen)
        data[0] = e0
        data[1] = e1

        # write data
        # In this case, data is a ndarray of ndarrays
        if is_hsds:
            dset[...] = data
        else:
            # the whole-array write fails on h5py because h5py tries to
            # broadcast (2,3) to (2,), so write one element at a time
            dset[0] = e0
            dset[1] = e1

        # read back data
        ret_val = dset[...]
        self.assertIsInstance(ret_val, np.ndarray)
        self.assertEqual(len(ret_val), 2)
        self.assertIsInstance(ret_val[0], np.ndarray)
        self.assertEqual(list(ret_val[0]), [1.1, 2.2, 3.3])
        self.assertEqual(ret_val[0].dtype, np.dtype('float'))
        self.assertIsInstance(ret_val[1], np.ndarray)
        self.assertEqual(ret_val[1].dtype, np.dtype('float'))
        self.assertEqual(list(ret_val[1]), [1.9, 2.8, 3.7])

        # Read back just one element
        e0 = dset[0]
        self.assertEqual(len(e0), 3)
        self.assertEqual(list(e0), [1.1, 2.2, 3.3])

        # try writing float lists into dataset
        data = [42.24,]
        dset[0] = data
        ret_val = dset[...]
        self.assertEqual(list(ret_val[0]), [42.24,])
    finally:
        # close the file even when an assertion above fails
        f.close()

def test_variable_len_unicode_dset(self):
filename = self.getFileName("variable_len_unicode_dset")
print("filename:", filename)
Expand Down Expand Up @@ -376,7 +450,7 @@ def test_variable_len_unicode_attr(self):
f.attrs.create('a1', words, shape=dims, dtype=dt)

vals = f.attrs["a1"] # read back
# print("type:", type(vals))

self.assertTrue("vlen" in vals.dtype.metadata)

for i in range(10):
Expand All @@ -387,6 +461,6 @@ def test_variable_len_unicode_attr(self):


if __name__ == '__main__':
loglevel = logging.ERROR
loglevel = logging.DEBUG
logging.basicConfig(format='%(asctime)s %(message)s', level=loglevel)
ut.main()

0 comments on commit 69640c1

Please sign in to comment.