Skip to content

Commit

Permalink
Merge pull request #200 from mattjala/vlen_fix
Browse files: browse the repository at this point in the history
Fix variable length writes from non-standard shapes
  • Loading branch information
mattjala authored May 16, 2024
2 parents d670f8a + dfa7711 commit 5d98e31
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 75 deletions.
25 changes: 15 additions & 10 deletions h5pyd/_hl/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1370,25 +1370,29 @@ def __setitem__(self, args, val):
# Generally we try to avoid converting the arrays on the Python
# side. However, for compound literals this is unavoidable.
# For h5pyd, do extra check and convert type on client side for efficiency
vlen = check_dtype(vlen=self.dtype)

if not isinstance(val, numpy.ndarray) and vlen is not None and vlen not in (bytes, str):
vlen_base_class = check_dtype(vlen=self.dtype)
if vlen_base_class is not None and vlen_base_class not in (bytes, str):
try:
val = numpy.asarray(val, dtype=vlen)
# Attempt to directly convert the input array of vlen data to its base class
val = numpy.asarray(val, dtype=vlen_base_class)

except ValueError as ve:
# Failed to convert input array to vlen base class directly, instead create a new array where
# each element is an array of the Dataset's dtype
self.log.debug(f"asarray ValueError: {ve}")
try:
val = numpy.array(
[numpy.array(x, dtype=self.dtype) for x in val],
dtype=self.dtype,
)
# Force output shape
tmp = numpy.empty(shape=val.shape, dtype=self.dtype)
tmp[:] = [numpy.array(x, dtype=self.dtype) for x in val]
val = tmp
except ValueError as e:
msg = f"ValueError converting value element by element: {e}"
self.log.debug(msg)

if vlen == val.dtype:
if vlen_base_class == val.dtype:
if val.ndim > 1:
# Reshape array to 2D, where first dim = product of all dims except last, and second dim = last dim
# Then flatten it to 1D
tmp = numpy.empty(shape=val.shape[:-1], dtype=self.dtype)
tmp.ravel()[:] = [
i
Expand Down Expand Up @@ -1435,6 +1439,7 @@ def __setitem__(self, args, val):
else:
dtype = self.dtype
cast_compound = False

val = numpy.asarray(val, dtype=dtype, order="C")
if cast_compound:
val = val.astype(numpy.dtype([(names[0], dtype)]))
Expand Down Expand Up @@ -1521,7 +1526,7 @@ def __setitem__(self, args, val):
if self.id.uuid.startswith("d-"):
# server is HSDS, use binary data, use param values for selection
format = "binary"
body = arrayToBytes(val, vlen=vlen)
body = arrayToBytes(val, vlen=vlen_base_class)
self.log.debug(f"writing binary data, {len(body)}")
else:
# h5serv, base64 encode, body json for selection
Expand Down
Loading

0 comments on commit 5d98e31

Please sign in to comment.