Skip to content

Commit

Permalink
implement ST_norm_from_LUT for the ResidualQuantizer
Browse files Browse the repository at this point in the history
Summary:
The norm computation ST_norm_from_LUT was not implemented in Faiss. See issue

facebookresearch#3882

This diff adds an implementation for it. It is probably not very quick. A few precomputed tables for AdditiveQuantizer were moved form ResidualQuantizer.

Differential Revision: D63975689
  • Loading branch information
mdouze authored and facebook-github-bot committed Oct 7, 2024
1 parent d2692b8 commit eec426d
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 58 deletions.
1 change: 1 addition & 0 deletions faiss/IndexAdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ void IndexAdditiveQuantizer::search(
DISPATCH(ST_norm_qint8)
DISPATCH(ST_norm_qint4)
DISPATCH(ST_norm_cqint4)
DISPATCH(ST_norm_from_LUT)
case AdditiveQuantizer::ST_norm_cqint8:
case AdditiveQuantizer::ST_norm_lsq2x4:
case AdditiveQuantizer::ST_norm_rq2x4:
Expand Down
2 changes: 1 addition & 1 deletion faiss/IndexIVFAdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner(
return new AQInvertedListScannerLUT<false, AdditiveQuantizer::st>( \
*this, store_pairs);
A(ST_LUT_nonorm)
// A(ST_norm_from_LUT)
A(ST_norm_from_LUT)
A(ST_norm_float)
A(ST_norm_qint8)
A(ST_norm_qint4)
Expand Down
85 changes: 77 additions & 8 deletions faiss/impl/AdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,40 @@ void AdditiveQuantizer::train_norm(size_t n, const float* norms) {
}
}

void AdditiveQuantizer::compute_codebook_tables() {
centroid_norms.resize(total_codebook_size);
fvec_norms_L2sqr(
centroid_norms.data(), codebooks.data(), d, total_codebook_size);
size_t cross_table_size = 0;
for (int m = 0; m < M; m++) {
size_t K = (size_t)1 << nbits[m];
cross_table_size += K * codebook_offsets[m];
}
codebook_cross_products.resize(cross_table_size);
size_t ofs = 0;
for (int m = 1; m < M; m++) {
FINTEGER ki = (size_t)1 << nbits[m];
FINTEGER kk = codebook_offsets[m];
FINTEGER di = d;
float zero = 0, one = 1;
assert(ofs + ki * kk <= cross_table_size);
sgemm_("Transposed",
"Not transposed",
&ki,
&kk,
&di,
&one,
codebooks.data() + d * kk,
&di,
codebooks.data(),
&di,
&zero,
codebook_cross_products.data() + ofs,
&ki);
ofs += ki * kk;
}
}

namespace {

// TODO
Expand Down Expand Up @@ -471,7 +505,6 @@ namespace {
float accumulate_IPs(
const AdditiveQuantizer& aq,
BitstringReader& bs,
const uint8_t* codes,
const float* LUT) {
float accu = 0;
for (int m = 0; m < aq.M; m++) {
Expand All @@ -483,6 +516,29 @@ float accumulate_IPs(
return accu;
}

float compute_norm_from_LUT(const AdditiveQuantizer& aq, BitstringReader& bs) {
float accu = 0;
std::vector<int> idx(aq.M);
const float* c = aq.codebook_cross_products.data();
for (int m = 0; m < aq.M; m++) {
size_t nbit = aq.nbits[m];
int i = bs.read(nbit);
size_t K = 1 << nbit;
idx[m] = i;

accu += aq.centroid_norms[aq.codebook_offsets[m] + i];

for (int l = 0; l < m; l++) {
int j = idx[l];
accu += 2 * c[j * K + i];
c += (1 << aq.nbits[l]) * K;
}
}
// FAISS_THROW_IF_NOT(c == aq.codebook_cross_products.data() +
// aq.codebook_cross_products.size());
return accu;
}

} // anonymous namespace

template <>
Expand All @@ -491,7 +547,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
return accumulate_IPs(*this, bs, codes, LUT);
return accumulate_IPs(*this, bs, LUT);
}

template <>
Expand All @@ -500,7 +556,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
return -accumulate_IPs(*this, bs, codes, LUT);
return -accumulate_IPs(*this, bs, LUT);
}

template <>
Expand All @@ -509,7 +565,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(32);
float norm2;
memcpy(&norm2, &norm_i, 4);
Expand All @@ -522,7 +578,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(8);
float norm2 = decode_qcint(norm_i);
return norm2 - 2 * accu;
Expand All @@ -534,7 +590,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(4);
float norm2 = decode_qcint(norm_i);
return norm2 - 2 * accu;
Expand All @@ -546,7 +602,7 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(8);
float norm2 = decode_qint8(norm_i, norm_min, norm_max);
return norm2 - 2 * accu;
Expand All @@ -558,10 +614,23 @@ float AdditiveQuantizer::
const uint8_t* codes,
const float* LUT) const {
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, codes, LUT);
float accu = accumulate_IPs(*this, bs, LUT);
uint32_t norm_i = bs.read(4);
float norm2 = decode_qint4(norm_i, norm_min, norm_max);
return norm2 - 2 * accu;
}

template <>
float AdditiveQuantizer::
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_from_LUT>(
const uint8_t* codes,
const float* LUT) const {
FAISS_THROW_IF_NOT(codebook_cross_products.size() > 0);
BitstringReader bs(codes, code_size);
float accu = accumulate_IPs(*this, bs, LUT);
BitstringReader bs2(codes, code_size);
float norm2 = compute_norm_from_LUT(*this, bs2);
return norm2 - 2 * accu;
}

} // namespace faiss
18 changes: 15 additions & 3 deletions faiss/impl/AdditiveQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ struct AdditiveQuantizer : Quantizer {
std::vector<float> codebooks; ///< codebooks

// derived values
/// codebook #1 is stored in rows codebook_offsets[i]:codebook_offsets[i+1]
/// in the codebooks table of size total_codebook_size by d
std::vector<uint64_t> codebook_offsets;
size_t tot_bits = 0; ///< total number of bits (indexes + norms)
size_t norm_bits = 0; ///< bits allocated for the norms
Expand All @@ -38,9 +40,19 @@ struct AdditiveQuantizer : Quantizer {
bool verbose = false; ///< verbose during training?
bool is_trained = false; ///< is trained or not

IndexFlat1D qnorm; ///< store and search norms
std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
///< fastscan search
/// auxiliary data for ST_norm_lsq2x4 and ST_norm_rq2x4
/// store norms of codebook entries for 4-bit fastscan
std::vector<float> norm_tabs;
IndexFlat1D qnorm; ///< store and search norms

void compute_codebook_tables();

/// norms of all codebook entries (size total_codebook_size)
std::vector<float> centroid_norms;

/// dot products of all codebook entries with the previous codebooks
/// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1)
std::vector<float> codebook_cross_products;

/// norms and distance matrixes with beam search can get large, so use this
/// to control for the amount of memory that can be allocated
Expand Down
34 changes: 0 additions & 34 deletions faiss/impl/ResidualQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,40 +492,6 @@ void ResidualQuantizer::refine_beam(
* Functions using the dot products between codebook entries
*******************************************************************/

void ResidualQuantizer::compute_codebook_tables() {
cent_norms.resize(total_codebook_size);
fvec_norms_L2sqr(
cent_norms.data(), codebooks.data(), d, total_codebook_size);
size_t cross_table_size = 0;
for (int m = 0; m < M; m++) {
size_t K = (size_t)1 << nbits[m];
cross_table_size += K * codebook_offsets[m];
}
codebook_cross_products.resize(cross_table_size);
size_t ofs = 0;
for (int m = 1; m < M; m++) {
FINTEGER ki = (size_t)1 << nbits[m];
FINTEGER kk = codebook_offsets[m];
FINTEGER di = d;
float zero = 0, one = 1;
assert(ofs + ki * kk <= cross_table_size);
sgemm_("Transposed",
"Not transposed",
&ki,
&kk,
&di,
&one,
codebooks.data() + d * kk,
&di,
codebooks.data(),
&di,
&zero,
codebook_cross_products.data() + ofs,
&ki);
ofs += ki * kk;
}
}

void ResidualQuantizer::refine_beam_LUT(
size_t n,
const float* query_norms, // size n
Expand Down
10 changes: 0 additions & 10 deletions faiss/impl/ResidualQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,6 @@ struct ResidualQuantizer : AdditiveQuantizer {
* @param beam_size if != -1, override the beam size
*/
size_t memory_per_point(int beam_size = -1) const;

/** Cross products used in codebook tables used for beam_LUT = 1
*/
void compute_codebook_tables();

/// dot products of all codebook entries with the previous codebooks
/// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1)
std::vector<float> codebook_cross_products;
/// norms of all codebook entries (size total_codebook_size)
std::vector<float> cent_norms;
};

} // namespace faiss
2 changes: 1 addition & 1 deletion faiss/impl/residual_quantizer_encode_steps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ void refine_beam_LUT_mp(
rq.codebook_offsets.data(),
query_cp + rq.codebook_offsets[m],
rq.total_codebook_size,
rq.cent_norms.data() + rq.codebook_offsets[m],
rq.centroid_norms.data() + rq.codebook_offsets[m],
m,
codes_ptr,
distances_ptr,
Expand Down
36 changes: 35 additions & 1 deletion tests/test_residual_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,37 @@ def test_search_decompress(self):
# recalls are {1: 0.05, 10: 0.37, 100: 0.37}
self.assertGreater(recalls[10], 0.35)

def do_exact_search_equiv(self, norm_type):
""" searching with this normalization should yield
exactly the same results as decompression (because the
norms are exact) """
ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

# decompresses by default
ir = faiss.IndexResidualQuantizer(ds.d, 3, 6)
ir.rq.train_type = faiss.ResidualQuantizer.Train_default
ir.train(ds.get_train())
ir.add(ds.get_database())
Dref, Iref = ir.search(ds.get_queries(), 10)

ir2 = faiss.IndexResidualQuantizer(
ds.d, 3, 6, faiss.METRIC_L2, norm_type)

# assumes training is reproducible
ir2.rq.train_type = faiss.ResidualQuantizer.Train_default
ir2.train(ds.get_train())
ir2.add(ds.get_database())
D, I = ir2.search(ds.get_queries(), 10)

np.testing.assert_allclose(D, Dref, atol=1e-5)
np.testing.assert_array_equal(I, Iref)

def test_exact_equiv_norm_float(self):
self.do_exact_search_equiv(faiss.AdditiveQuantizer.ST_norm_float)

def test_exact_equiv_norm_from_LUT(self):
self.do_exact_search_equiv(faiss.AdditiveQuantizer.ST_norm_from_LUT)

def test_reestimate_codebook(self):
ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

Expand Down Expand Up @@ -858,6 +889,9 @@ def test_norm_cqint(self):
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint8)
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint4)

def test_norm_from_LUT(self):
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_from_LUT)

def test_factory(self):
index = faiss.index_factory(12, "IVF1024,RQ8x8_Nfloat")
self.assertEqual(index.nlist, 1024)
Expand Down Expand Up @@ -1105,7 +1139,7 @@ def test_precomp(self):
ofs += kk * K
np.testing.assert_allclose(py_table, cpp_table, atol=1e-5)

cent_norms = faiss.vector_to_array(rq.cent_norms)
cent_norms = faiss.vector_to_array(rq.centroid_norms)
np.testing.assert_array_almost_equal(
np.hstack(cent_norms_ref), cent_norms, decimal=5)

Expand Down

0 comments on commit eec426d

Please sign in to comment.