RFC: add "mutable arrays" to capabilities #845

Open
jakevdp opened this issue Sep 23, 2024 · 14 comments
Labels
API extension Adds new functions or objects to the API. Needs Discussion Needs further discussion. RFC Request for comments. Feature requests and proposed changes. topic: Inspection Array API inspection.
Milestone

Comments

@jakevdp

jakevdp commented Sep 23, 2024

Several parts of the Array API standard assume that array objects are mutable. Some array API implementations (notably JAX) do not support mutating array objects. As a result, the array API implementations currently being developed in scipy and sklearn are entirely unusable with JAX.

Given this, downstream implementations have a few choices:

  1. Use mutability semantics, excluding libraries like JAX.
  2. Avoid mutability semantics to support libraries like JAX.
  3. Explicitly special-case arrays of type jax.numpy.Array, changing the implementation logic for that case.

(1) is a bad choice, because it means JAX will not be supported. (2) is a bad choice, because for libraries like NumPy, it leads to excessive copying of buffers, worsening performance. (3) is a bad choice because it hard-codes the presence of specific implementations in a context that is supposed to be implementation-agnostic.

One way the Array API standard could address this is by adding "mutable arrays" or something similar to the existing capabilities dict. Then downstream implementations could use strategy (3) without special-casing particular implementations.
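A minimal sketch of how downstream code might consult such a capability. The "mutable arrays" key is hypothetical (it is not in the standard's capabilities dict today), and treating a missing key as "mutable" is an assumption made here for backwards compatibility. The fake namespaces stand in for real libraries; only the inspection call is modeled.

```python
from types import SimpleNamespace


def supports_mutation(xp, default=True):
    """Return the hypothetical "mutable arrays" capability of namespace xp.

    Assumption: current implementations do not report this key, so fall
    back to `default` instead of raising KeyError.
    """
    info = xp.__array_namespace_info__()
    return info.capabilities().get("mutable arrays", default)


def _fake_namespace(capabilities):
    # Stand-in for a real array namespace; only the inspection API is modeled.
    info = SimpleNamespace(capabilities=lambda: dict(capabilities))
    return SimpleNamespace(__array_namespace_info__=lambda: info)


numpy_like = _fake_namespace({"boolean indexing": True, "mutable arrays": True})
jax_like = _fake_namespace({"boolean indexing": True, "mutable arrays": False})
legacy = _fake_namespace({"boolean indexing": True})  # key not reported at all

print(supports_mutation(numpy_like), supports_mutation(jax_like), supports_mutation(legacy))
```

The default matters: a library that has not yet adopted the key is treated like NumPy, so existing mutating code paths keep working unchanged.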

@jakevdp
Author

jakevdp commented Sep 23, 2024

(to anticipate one response: no, it's not possible to make JAX arrays support mutation: central to JAX are transformations like jit, vmap, grad, etc. that rely on immutability assumptions in their program tracing)

@lucascolley
Contributor

lucascolley commented Sep 23, 2024

For (3), could you prototype what it would look like in the case of gh-609? For capabilities["mutable arrays"] == True, we use the NumPy syntax x[i] += y. For capabilities["mutable arrays"] == False, we use ...? This would require standardising a way to do this for immutable arrays, right? Or can we just use xp.where?

@leofang
Contributor

leofang commented Sep 23, 2024

> Several parts of the Array API standard assume that array objects are mutable.

This is very surprising. It would be nice if we can have a list of such occurrences here, because this was not supposed to happen as per our design guideline https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html

@jakevdp
Author

jakevdp commented Sep 23, 2024

> This is very surprising. It would be nice if we can have a list of such occurrences here,

The main example is __setitem__, which as far as I can tell is supported in the standard.

> For (3), could you prototype what it would look like in the case of gh-609?

For example, it could look something like this in the specific case of updating an array with a mask and a scalar:

info = xp.__array_namespace_info__().capabilities()

if info['mutable arrays']:
  x[xp.isnan(x)] = 0
else:
  x = xp.where(xp.isnan(x), 0, x)

That would certainly not cover all cases, but it would be enough to fix a large number of the incompatibilities being currently introduced into scipy and scikit-learn.

But in general, yes, it would also be beneficial if the array API standard could add some syntax for out-of-place array updates similar to what's being discussed in #609.
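To make "out-of-place array updates" concrete, here is a minimal pure-Python sketch of the `x.at[idx].set(value)` style that JAX uses. `FrozenVector` and its helper classes are hypothetical illustrations, not JAX's implementation and not a proposed standard API: each update copies the data and returns a new object, leaving the original untouched.

```python
class _AtIndexer:
    """Returned by FrozenVector.at; indexing it records the target index."""

    def __init__(self, vec):
        self._vec = vec

    def __getitem__(self, index):
        return _IndexUpdate(self._vec, index)


class _IndexUpdate:
    def __init__(self, vec, index):
        self._vec = vec
        self._index = index

    def set(self, value):
        # Copy the data, modify the copy, and wrap it in a new object.
        data = list(self._vec.data)
        data[self._index] = value
        return FrozenVector(data)


class FrozenVector:
    """Hypothetical immutable 1-D array: no __setitem__, updates via .at[...]."""

    def __init__(self, values):
        self.data = tuple(values)

    @property
    def at(self):
        return _AtIndexer(self)

    def __repr__(self):
        return f"FrozenVector({list(self.data)})"


x = FrozenVector([0.0, 0.0])
y = x.at[0].set(1.0)
print(x, y)  # FrozenVector([0.0, 0.0]) FrozenVector([1.0, 0.0])
```

Because the result is an ordinary return value, the caller must rebind it (`x = x.at[0].set(1.0)`), which is exactly the pattern that does not fit behind `__setitem__`.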

@kgryte kgryte added this to the v2024 milestone Sep 24, 2024
@kgryte kgryte changed the title Proposal: add "mutable arrays" to capabilities RFC: add "mutable arrays" to capabilities Sep 24, 2024
@kgryte kgryte added RFC Request for comments. Feature requests and proposed changes. API extension Adds new functions or objects to the API. Needs Discussion Needs further discussion. topic: Inspection Array API inspection. labels Sep 24, 2024
@pearu

pearu commented Oct 1, 2024

Adding "mutable arrays" to library-level capabilities is suboptimal for libraries that support both mutable and immutable arrays. For example, numpy arrays have a flags.writeable attribute bit that signals whether an array should be considered mutable or immutable.
What about adding array-object-level flags to the Array API, similar to numpy.ndarray.flags?

@rgommers
Member

rgommers commented Oct 2, 2024

Agreed with @pearu's comment. There are multiple other issues here though, for example:

(1) What does it mean to be a "mutable array"? To stay with the numpy.ndarray.flags example:

>>> import numpy as np
>>> x = np.arange(5)
>>> y = x[:3]
>>> y.flags.writeable = False
>>> y += 1
...
ValueError: output array is read-only

>>> y[0]
np.int64(0)
>>> x += 1
>>> y[0]
np.int64(1)

So is y mutable? I guess you'd have said no - but its values can still easily change. So there's no right answer here for numpy.ndarray right now.
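The same ambiguity can be reproduced without NumPy. In this analogy (standard library only, not an Array API object), a read-only `memoryview` rejects writes made through it, yet its contents still change when the underlying `bytearray` is modified.

```python
buf = bytearray(5)                   # writable buffer, all zeros
ro = memoryview(buf).toreadonly()    # read-only view of the same memory

try:
    ro[0] = 1                        # writing through the view is rejected
except TypeError as exc:
    print("rejected:", exc)

print(ro[0])                         # 0
buf[0] = 7                           # mutate the owner instead
print(ro[0])                         # 7: the "read-only" view sees the change
```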

(2) JAX, you'd argue, is immutable, I'm sure. However, as we saw in the example above, numpy read-only arrays reject in-place operators like += while JAX doesn't:

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> x[0]
Array(0, dtype=int32)
>>> x += 1
>>> x[0]
Array(1, dtype=int32)

So I'd say "is a mutable array" is quite ambiguous.

> This is very surprising. It would be nice if we can have a list of such occurrences here,

> The main example is __setitem__, which as far as I can tell is supported in the standard.

__setitem__ is, I believe, the one and only painful example for JAX here. However, it is not the case that it cannot be implemented in JAX, nor that it is incompatible with immutability. It's a complex topic, but https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html covers it. The key point is that there's no semantic difference between updating values in-place or out-of-place as long as the update modifies only a single array. The reason JAX never implemented slice/item assignment (with discussions going all the way back to gh-24) is that doing so would be confusing given the mismatch in semantics with NumPy. But it's NumPy's semantics that are undefined behavior as soon as views play a role; the JAX design is perfectly fine.


I think that we should not add .flags unless there's a real value-add. At the moment, the motivating code example is best written for one or more specific libraries like:

if is_jax(x):
  x[xp.isnan(x)] = 0
else:
  x = xp.where(xp.isnan(x), 0, x)

@pearu

pearu commented Oct 2, 2024

> (1) ... So is y mutable? I guess you'd have said no

Yes, y is not mutable because you cannot mutate it via operations on y.

> but its values can still easily change.

In general, there are always ways to mutate the data of immutable objects. One can even mutate JAX arrays easily via the DLPack or array interface protocols.

In this specific case, the example demonstrates a common practice of viewing data as read-only while the data could still be modified at some other level or time. For instance, one can open a file in read-only mode, and in this context the file descriptor would represent an immutable object, while some other process may open the same file in writable mode, which would enable mutations.

> (2) JAX, you'd argue, is immutable, I'm sure. However, as we saw in the example above, numpy read-only arrays reject in-place operators like += while JAX doesn't:

This means that numpy and JAX implement different semantics for in-place operations: for numpy, an in-place operation mutates the array, while for JAX, the in-place operator is syntactic sugar for an out-of-place transformation, x = op(x, y).
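Python's built-in types show the same split. In this analogy (builtins standing in for NumPy and JAX arrays), `+=` on a `list` mutates the object in place, so every reference sees the change, while `tuple` has no `__iadd__`, so `+=` falls back to `t = t + other` and rebinds only the local name.

```python
mutable = [0, 1]
alias_m = mutable
mutable += [2]          # list.__iadd__ extends the existing object
print(alias_m)          # [0, 1, 2]: the alias sees the update

immutable = (0, 1)
alias_i = immutable
immutable += (2,)       # evaluated as immutable = immutable + (2,)
print(alias_i)          # (0, 1): the alias still refers to the old object
print(immutable)        # (0, 1, 2)
```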

> So I'd say "is a mutable array" is quite ambiguous.

I disagree. By definition, an "array" is a certain view of (contiguous) data whose elements can be accessed via indexing operations.
So, a mutable/immutable array is an array that allows/disallows data mutation via indexing operations. Even if there exist other ways of mutating the underlying data (say, via direct memory access, cosmic rays, etc.), these mutations happen outside the context of mutable/immutable array usage.

> ...

You probably meant to write:

if not is_jax(x):
  x[xp.isnan(x)] = 0
else:
  x = xp.where(xp.isnan(x), 0, x)

I find using "jax" in the name of a utility predicate function suboptimal, because JAX arrays are not the only immutable array objects. So, instead of introducing is_jax, I suggest:

if is_mutable(x):
  x[xp.isnan(x)] = 0
else:
  x = xp.where(xp.isnan(x), 0, x)

so that the same code in scipy/... will not need to be modified when someone invents another Array API-compliant array object that is immutable: it will be sufficient to update only the definition of is_mutable.
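A minimal sketch of such a predicate, assuming that arrays exposing a NumPy-style `flags.writeable` attribute report their own mutability and that everything else falls back to a per-library answer. The registry `_IMMUTABLE_PACKAGES`, its contents, and matching on the type's top-level module name are illustrative assumptions, not an agreed-upon mechanism; `FakeJaxArray` is a stand-in so the sketch runs without JAX installed.

```python
from types import SimpleNamespace

# Hypothetical registry: top-level packages whose array types are assumed immutable.
_IMMUTABLE_PACKAGES = {"jax", "jaxlib"}


def is_mutable(x):
    """Best-effort guess at whether x supports in-place updates via indexing."""
    flags = getattr(x, "flags", None)
    writeable = getattr(flags, "writeable", None)
    if writeable is not None:
        # NumPy-style arrays carry a per-array flag; trust it when present.
        return bool(writeable)
    # Otherwise decide per library, using the top-level package of x's type.
    package = type(x).__module__.split(".")[0]
    return package not in _IMMUTABLE_PACKAGES


class FakeJaxArray:
    __module__ = "jax.numpy"  # pretend this type lives in the jax package


print(is_mutable(SimpleNamespace(flags=SimpleNamespace(writeable=True))))   # True
print(is_mutable(SimpleNamespace(flags=SimpleNamespace(writeable=False))))  # False
print(is_mutable(FakeJaxArray()))                                           # False
```

Checking the per-array flag first is what lets a read-only NumPy array take the `xp.where` branch even though NumPy as a library supports mutation.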

@lucascolley
Contributor

lucascolley commented Oct 2, 2024

> if is_mutable(x):

The problem is that we now need to fetch a capability of the array, rather than of the array namespace, since NumPy has both behaviours. So how could we address this, if not with an info or flags method on arrays (or by simply refusing to handle immutable NumPy arrays correctly)?

@jakevdp
Author

jakevdp commented Oct 2, 2024

> The key point is that there's no semantic difference between updating values in-place or out-of-place as long as the update modifies only a single array.

Thinking about this a bit, I think the language about "views" in https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html is not quite strong enough.

Let's assume "in-place update" is equivalent to x[0] = 1 in NumPy, and "out-of-place update" is equivalent to x = x.copy(); x[0] = 1 in NumPy, or to x = x.at[0].set(1) in JAX.

The kind of example I have in mind is this:

x = xp.zeros(2)
L = [x]
x[0] = 1
print(L)

What is the result here? If x[0] = 1 has in-place update semantics, then this prints [array([1., 0.])]. If x[0] = 1 has out-of-place update semantics, then this prints [array([0., 0.])].

So the equivalence of in-place and out-of-place semantics doesn't just require the absence of array views in the sense of what's tracked by x.flags.owndata, it also requires that the array's Python refcount be exactly equal to 1.
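The divergence can be reproduced with plain Python lists standing in for arrays (an analogy, not Array API code). The two hypothetical helpers implement the in-place and out-of-place definitions given above: both leave `x` holding the same values, but they disagree on what the extra reference in `L` sees. `sys.getrefcount` (a CPython detail) shows that the only trace of that extra reference is a higher refcount.

```python
import sys


def update_in_place(x):
    x[0] = 1.0            # NumPy-style: modify the existing buffer
    return x


def update_out_of_place(x):
    x = list(x)           # copy first, like x.copy() in NumPy or x.at[0].set(1) in JAX
    x[0] = 1.0
    return x


for update in (update_in_place, update_out_of_place):
    x = [0.0, 0.0]
    before = sys.getrefcount(x)
    L = [x]
    assert sys.getrefcount(x) > before   # L now holds a second reference
    x = update(x)
    print(update.__name__, x, L)
# update_in_place [1.0, 0.0] [[1.0, 0.0]]
# update_out_of_place [1.0, 0.0] [[0.0, 0.0]]
```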

@pearu

pearu commented Oct 2, 2024

Recall that Python's __setitem__ cannot meaningfully return a non-None value (the assignment statement discards whatever it returns). That is, the assumption in #845 (comment) that one would be able to implement __setitem__ such that

x[0] = 1

uses JAX's semantics (x = x.at[0].set(1)), is simply impossible in Python (a Python object cannot rebind the references to itself, not even by modifying the parent frame's locals, IIRC). Of course, under jitting everything is possible, but I assume that we want to keep the semantics of Python and jitted functions the same.
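A quick way to see this: in the hypothetical class below, `__setitem__` builds and returns an updated copy, but the assignment statement `x[0] = 1` discards that return value, so `x` is unchanged and the copy is lost. Only an explicit method call gets the copy back.

```python
class CopyingSetitem:
    """Hypothetical array-like whose __setitem__ tries to behave out-of-place."""

    def __init__(self, values):
        self.values = tuple(values)

    def __setitem__(self, index, value):
        updated = list(self.values)
        updated[index] = value
        return CopyingSetitem(updated)  # ignored by `x[index] = value`


x = CopyingSetitem([0, 0])
result = x.__setitem__(0, 1)   # calling the method directly does return the copy
x[0] = 1                       # the statement form has nowhere to put it
print(x.values, result.values) # (0, 0) (1, 0)
```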

If we modify @jakevdp's example as follows:

x = xp.zeros(2)
L = [x]
L[0][0] = 1
print(L)

the expected output would be [array([1., 0.])], and since I cannot see how the out-of-place semantics would be possible to support technically, I think there is no alternative output.

@jakevdp
Author

jakevdp commented Oct 2, 2024

> and since I cannot see how the out-of-place semantics would be possible to support technically, I think there is no alternative output.

Sorry, I think you misunderstood my point. I wasn't arguing that x[0] = 1 should have out-of-place semantics. I was using this as an example of where in-place and out-of-place semantics differ in a way not already identified by the doc I linked to. If it helps, you can replace x[0] = 1 with *** in that code example, and substitute the appropriate code to get either in-place or out-of-place semantics, and to see the difference in their outputs.

@pearu

pearu commented Oct 2, 2024

> If it helps, you can replace x[0] = 1 with *** in that code example

OK, fair enough. A better example would be one that uses, say, an in-place operator (__iadd__, etc.), since those methods do support non-None return values. Consider my comment above a reminder that __setitem__ and the in-place operator methods are not equivalent when discussing in-place and out-of-place semantics.
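By contrast, Python does bind whatever `__iadd__` returns back to the target name, so an in-place operator can have out-of-place semantics. In this hypothetical sketch, `x += 1` produces a new object, and a reference held elsewhere keeps the old value, similar to the JAX `+=` behavior shown earlier in the thread.

```python
class ImmutableBox:
    """Hypothetical immutable value whose += returns a fresh object."""

    def __init__(self, value):
        self.value = value

    def __iadd__(self, other):
        # Returning a new object: Python rebinds the target name to it.
        return ImmutableBox(self.value + other)


x = ImmutableBox(0)
L = [x]
x += 1
print(x.value, L[0].value)   # 1 0
```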

@jakevdp
Author

jakevdp commented Oct 2, 2024

Rather than getting lost in implementation details, let's bring it back to the statement I was responding to:

> The key point is that there's no semantic difference between updating values in-place or out-of-place as long as the update modifies only a single array.

I think this is untrue, unless you also consider Python-level references as well as views when reasoning about whether an operation affects a "single array". And limiting operations to objects with a refcount of 1 is far more intrusive than limiting operations to arrays whose buffer is not shared with any other array objects.

@rgommers
Member

rgommers commented Oct 3, 2024

> I think this is untrue, unless you also consider Python-level references as well as views when reasoning about whether an operation affects a "single array".

Yes, that is a good point, and I agree that that page should be more explicit and grow a section on Python refcount >1. The behavior difference applies not only to __setitem__, but to all in-place operators as well.

> And limiting operations to objects with a refcount of 1 is far more intrusive

Agreed. What I think (but am not sure; I have to give it some more thought) is that this should remain undefined behavior: it's kinda baked into the Python language, and it's already a difference today between JAX and NumPy/PyTorch for += & co.


6 participants