Support search_preassigned in torch (#3916)
Summary:
Pull Request resolved: #3916

Add the missing search_preassigned wrapper to the torch wrappers in Faiss and test it.

Also factor out a bit of code shared between the search functions.
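
For reference, a minimal usage sketch of the new torch codepath (mirroring the test added below; the factory string, dimensions, and random data are purely illustrative):

import numpy as np
import torch
import faiss
import faiss.contrib.torch_utils  # importing this registers the torch wrappers

d = 32
index = faiss.index_factory(d, "IVF20,PQ4np")
index.train(np.random.rand(1000, d).astype('float32'))
index.add(np.random.rand(100, d).astype('float32'))
index.nprobe = 4

xq = torch.rand(10, d)  # float32 torch queries
# assign each query to nprobe inverted lists with the coarse quantizer...
Dq, Iq = index.quantizer.search(xq, index.nprobe)
# ...then search only those lists, skipping the internal coarse quantization
D, I = index.search_preassigned(xq, 10, Iq, Dq)

The same call works on a GPU index with CUDA tensors; per the diff below, the wrapper handles stream ordering via using_stream.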

Reviewed By: algoriddle

Differential Revision: D63974821

fbshipit-source-id: a0415a57a763e2d1896956c503e503615c167860
mdouze authored and facebook-github-bot committed Oct 8, 2024
1 parent be4fc8e commit 2e6551f
Showing 2 changed files with 77 additions and 21 deletions.
72 changes: 51 additions & 21 deletions contrib/torch_utils.py
@@ -217,14 +217,8 @@ def torch_replacement_train(self, x):
         # CPU torch
         self.train_c(n, x_ptr)

-    def torch_replacement_search(self, x, k, D=None, I=None):
-        if type(x) is np.ndarray:
-            # forward to faiss __init__.py base method
-            return self.search_numpy(x, k, D=D, I=I)
-
-        assert type(x) is torch.Tensor
+    def search_methods_common(x, k, D, I):
         n, d = x.shape
-        assert d == self.d
         x_ptr = swig_ptr_from_FloatTensor(x)

         if D is None:
@@ -241,6 +235,19 @@ def torch_replacement_search(self, x, k, D=None, I=None):
             assert I.shape == (n, k)
         I_ptr = swig_ptr_from_IndicesTensor(I)

+        return x_ptr, D_ptr, I_ptr, D, I
+
+    def torch_replacement_search(self, x, k, D=None, I=None):
+        if type(x) is np.ndarray:
+            # forward to faiss __init__.py base method
+            return self.search_numpy(x, k, D=D, I=I)
+
+        assert type(x) is torch.Tensor
+        n, d = x.shape
+        assert d == self.d
+
+        x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
+
         if x.is_cuda:
             assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

@@ -261,21 +268,8 @@ def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None):
         assert type(x) is torch.Tensor
         n, d = x.shape
         assert d == self.d
-        x_ptr = swig_ptr_from_FloatTensor(x)

-        if D is None:
-            D = torch.empty(n, k, device=x.device, dtype=torch.float32)
-        else:
-            assert type(D) is torch.Tensor
-            assert D.shape == (n, k)
-        D_ptr = swig_ptr_from_FloatTensor(D)
-
-        if I is None:
-            I = torch.empty(n, k, device=x.device, dtype=torch.int64)
-        else:
-            assert type(I) is torch.Tensor
-            assert I.shape == (n, k)
-        I_ptr = swig_ptr_from_IndicesTensor(I)
+        x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)

         if R is None:
             R = torch.empty(n, k, d, device=x.device, dtype=torch.float32)
@@ -296,6 +290,40 @@ def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None):

         return D, I, R

+    def torch_replacement_search_preassigned(self, x, k, Iq, Dq, *, D=None, I=None):
+        if type(x) is np.ndarray:
+            # forward to faiss __init__.py base method
+            return self.search_preassigned_numpy(x, k, Iq, Dq, D=D, I=I)
+
+        assert type(x) is torch.Tensor
+        n, d = x.shape
+        assert d == self.d
+
+        x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
+
+        assert Iq.shape == (n, self.nprobe)
+        Iq = Iq.contiguous()
+        Iq_ptr = swig_ptr_from_IndicesTensor(Iq)
+
+        if Dq is not None:
+            Dq = Dq.contiguous()
+            assert Dq.shape == Iq.shape
+            Dq_ptr = swig_ptr_from_FloatTensor(Dq)
+        else:
+            Dq_ptr = None
+
+        if x.is_cuda:
+            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
+
+            # On the GPU, use proper stream ordering
+            with using_stream(self.getResources()):
+                self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)
+        else:
+            # CPU torch
+            self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)
+
+        return D, I
+
     def torch_replacement_remove_ids(self, x):
         # Not yet implemented
         assert type(x) is not torch.Tensor, 'remove_ids not yet implemented for torch'
@@ -495,6 +523,8 @@ def torch_replacement_sa_decode(self, codes, x=None):
                          ignore_missing=True)
     torch_replace_method(the_class, 'search_and_reconstruct',
                          torch_replacement_search_and_reconstruct, ignore_missing=True)
+    torch_replace_method(the_class, 'search_preassigned',
+                         torch_replacement_search_preassigned, ignore_missing=True)
     torch_replace_method(the_class, 'sa_encode', torch_replacement_sa_encode)
     torch_replace_method(the_class, 'sa_decode', torch_replacement_sa_decode)

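Aside: the registration above relies on the contrib module's monkey-patching helper. Roughly, and as a sketch only (the real torch_replace_method in contrib/torch_utils.py carries a few extra guards), it stashes the original numpy implementation under a _numpy suffix, which is exactly what the wrappers forward to for np.ndarray inputs:

def torch_replace_method(the_class, name, replacement, ignore_missing=False):
    orig = getattr(the_class, name, None)
    if orig is None:
        # some index classes lack the method; skip only if the caller allows it
        assert ignore_missing
        return
    setattr(the_class, name + '_numpy', orig)  # e.g. search_preassigned_numpy
    setattr(the_class, name, replacement)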
26 changes: 26 additions & 0 deletions tests/torch_test_contrib.py
@@ -291,6 +291,32 @@ def test_search_and_reconstruct(self):
         self.assertTrue(torch.equal(I, I_input))
         self.assertTrue(torch.equal(R, R_input))

+    def test_search_preassigned(self):
+        ds = datasets.SyntheticDataset(32, 1000, 100, 10)
+        index = faiss.index_factory(32, "IVF20,PQ4np")
+        index.train(ds.get_train())
+        index.add(ds.get_database())
+        index.nprobe = 4
+        Dref, Iref = index.search(ds.get_queries(), 10)
+        quantizer = faiss.clone_index(index.quantizer)
+
+        # mutilate the index's quantizer
+        index.quantizer.reset()
+        index.quantizer.add(np.zeros((20, 32), dtype='float32'))
+
+        # test numpy codepath
+        Dq, Iq = quantizer.search(ds.get_queries(), 4)
+        Dref2, Iref2 = index.search_preassigned(ds.get_queries(), 10, Iq, Dq)
+        np.testing.assert_array_equal(Iref, Iref2)
+        np.testing.assert_array_equal(Dref, Dref2)
+
+        # test torch codepath
+        xq = torch.from_numpy(ds.get_queries())
+        Dq, Iq = quantizer.search(xq, 4)
+        Dref2, Iref2 = index.search_preassigned(xq, 10, Iq, Dq)
+        np.testing.assert_array_equal(Iref, Iref2.numpy())
+        np.testing.assert_array_equal(Dref, Dref2.numpy())
+
     # tests sa_encode, sa_decode
     def test_sa_encode_decode(self):
         d = 16
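A note on the test strategy: after recording the reference results, the test destroys the index's own quantizer (resetting it and refilling it with zero centroids), so matching Dref/Iref afterwards is only possible if search_preassigned really uses the externally supplied assignment (Iq, Dq) computed on the cloned quantizer, rather than re-running coarse quantization internally.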
