add a wrapper for search_preassigned in torch_utils #3909

Open
mdouze opened this issue Oct 4, 2024 · 3 comments

mdouze commented Oct 4, 2024

torch_utils has wrappers that let search and search_and_reconstruct accept torch tensors, but there is none for search_preassigned, which was added later.

It would be useful to implement it here:

https://github.com/facebookresearch/faiss/blob/main/contrib/torch_utils.py#L220

mdouze commented Oct 4, 2024

The following code works:

    def torch_replacement_search_preassigned(self, x, k, Iq, Dq, *, D=None, I=None):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.search_preassigned_numpy(x, k, Iq, Dq, D=D, I=I)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if D is None:
            D = torch.empty(n, k, device=x.device, dtype=torch.float32)
        else:
            assert type(D) is torch.Tensor
            assert D.shape == (n, k)
        D_ptr = swig_ptr_from_FloatTensor(D)

        if I is None:
            I = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(I) is torch.Tensor
            assert I.shape == (n, k)
        I_ptr = swig_ptr_from_IndicesTensor(I)

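        # Iq holds the preassigned coarse list ids, one row of nprobe entries per query
        assert Iq.shape == (n, self.nprobe)
        Iq = Iq.contiguous()
        Iq_ptr = swig_ptr_from_IndicesTensor(Iq)

        # Dq (the matching coarse distances) may be None
        if Dq is not None:
            assert Dq.shape == Iq.shape
            Dq = Dq.contiguous()
            Dq_ptr = swig_ptr_from_FloatTensor(Dq)
        else:
            Dq_ptr = None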

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)
        else:
            # CPU torch
            self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)

        return D, I

    torch_replace_method(the_class, 'search_preassigned', torch_replacement_search_preassigned)
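
For readers unfamiliar with the pattern, here is a dependency-free sketch of what the final line relies on: the original method is kept under a `_numpy` suffix and the new method dispatches on the input type. This is not the faiss code. `replace_method`, `ToyIndex` and `FakeTensor` are hypothetical stand-ins, a plain list plays the role of the numpy array, and the real `torch_replace_method` may do additional checks.

```python
def replace_method(the_class, name, replacement):
    """Install `replacement` as the_class.<name>, keeping the original as <name>_numpy."""
    orig_method = getattr(the_class, name)
    setattr(the_class, name + "_numpy", orig_method)
    setattr(the_class, name, replacement)


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, so the sketch needs no dependencies."""

    def __init__(self, rows):
        self.rows = rows


class ToyIndex:
    """Hypothetical stand-in for a faiss index; `search` only reports what it received."""

    def search(self, x, k):
        return ("list path", len(x), k)


def replacement_search(self, x, k):
    if type(x) is list:
        # Forward to the original method, saved under the _numpy suffix
        return self.search_numpy(x, k)
    assert type(x) is FakeTensor
    # The real wrapper would extract raw pointers and call the C++ method here
    return ("tensor path", len(x.rows), k)


replace_method(ToyIndex, "search", replacement_search)

index = ToyIndex()
list_result = index.search([[0.0, 1.0], [2.0, 3.0]], 5)
tensor_result = index.search(FakeTensor([[0.0, 1.0]]), 5)
print(list_result)
print(tensor_result)
```

Because the original is kept under the suffixed name, existing callers that pass numpy arrays keep their behavior, and only tensor inputs take the new path.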

mdouze commented Oct 7, 2024

Implemented in #3916.
