From dce7c09413d2f467007197411264b7ac70b1dfbe Mon Sep 17 00:00:00 2001
From: vorj <40021161+vorj@users.noreply.github.com>
Date: Tue, 15 Oct 2024 13:59:41 -0700
Subject: [PATCH] Add some SVE implementations (#3933)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
related: https://github.com/facebookresearch/faiss/issues/2884

I added some SVE implementations of:
- `code_distance`
- `distance_single_code`
- `distance_four_codes`
- `exhaustive_L2sqr_blas_cmax_sve`
- `fvec_inner_products_ny`
- `fvec_madd`

## Evaluation result

I evaluated search on the SIFT1M dataset on AWS EC2 c7g.large and r8g.large instances. `main` is the current (2e6551ffa3f6fbdb1ba814c2c531fb399b00d4e3) implementation.

### c7g.large (Graviton 3)

![g3_sift1m](https://github.com/user-attachments/assets/9c03cffa-72d1-4c77-9ae8-0ec0a5f5a6a5)
![g3_ivfpq](https://github.com/user-attachments/assets/4a8dfcc8-823c-4c31-ae79-3f4af9be28c8)

On Graviton 3, `IndexIVFPQ` in particular has improved. In the best case (IndexIVFPQ + IndexFlatL2, M: 32), this PR is approx. 2.38-~~2.50~~**2.44**x faster than `main`.
- nprobe: 1, 0.069ms/query → 0.029ms/query
- nprobe: 4, 0.181ms/query → ~~0.074~~**0.075**ms/query
- nprobe: 16, 0.613ms/query → ~~0.245~~**0.251**ms/query

### r8g.large (Graviton 4)

![g4_sift1m](https://github.com/user-attachments/assets/e8510163-49d2-4143-babe-d406e2e40398)
![g4_ivfpq](https://github.com/user-attachments/assets/dc9a3ae0-a6b5-4a07-9898-c6aff372025c)

On Graviton 4, `IndexIVFPQ` with small `nprobe` in particular has improved. In the best case (IndexIVFPQ + IndexFlatL2, M: 8, nprobe: 1), this PR is approx. 1.33x faster than `main` (0.016ms/query → 0.012ms/query).

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3933
Reviewed By: mengdilin
Differential Revision: D64249808
Pulled By: asadoughi
fbshipit-source-id: 8a625f0ab37732d330192599c851f864350885c4
---
 faiss/impl/code_distance/code_distance-sve.h | 440 +++++++++++++++++
 faiss/impl/code_distance/code_distance.h     |  53 ++
 faiss/utils/distances.cpp                    | 189 +++
 faiss/utils/distances_simd.cpp               | 493 +++++++++++++++++++
 4 files changed, 1175 insertions(+)
 create mode 100644 faiss/impl/code_distance/code_distance-sve.h

diff --git a/faiss/impl/code_distance/code_distance-sve.h b/faiss/impl/code_distance/code_distance-sve.h
new file mode 100644
index 0000000000..c15a755d1c
--- /dev/null
+++ b/faiss/impl/code_distance/code_distance-sve.h
@@ -0,0 +1,440 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#pragma once + +#ifdef __ARM_FEATURE_SVE + +#include + +#include +#include + +#include + +namespace faiss { + +template +std::enable_if_t, float> inline distance_single_code_sve( + // the product quantizer + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code) { + // default implementation + return distance_single_code_generic(M, nbits, sim_table, code); +} + +static inline void distance_codes_kernel( + svbool_t pg, + svuint32_t idx1, + svuint32_t offsets_0, + const float* tab, + svfloat32_t& partialSum) { + // add offset + const auto indices_to_read_from = svadd_u32_x(pg, idx1, offsets_0); + + // gather values, similar to some operations of tab[index] + const auto collected = + svld1_gather_u32index_f32(pg, tab, indices_to_read_from); + + // collect partial sum + partialSum = svadd_f32_m(pg, partialSum, collected); +} + +static float distance_single_code_sve_for_small_m( + // the product quantizer + const size_t M, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code) { + constexpr size_t nbits = 8u; + + const size_t ksub = 1 << nbits; + + const auto offsets_0 = svindex_u32(0, static_cast(ksub)); + + // loop + const auto pg = svwhilelt_b32_u64(0, M); + + auto mm1 = svld1ub_u32(pg, code); + mm1 = svadd_u32_x(pg, mm1, offsets_0); + const auto collected0 = svld1_gather_u32index_f32(pg, sim_table, mm1); + return svaddv_f32(pg, collected0); +} + +template +std::enable_if_t, float> inline distance_single_code_sve( + // the product quantizer + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code) { + if (M <= svcntw()) + return distance_single_code_sve_for_small_m(M, sim_table, code); + + const float* tab = sim_table; + + const size_t ksub = 1 << nbits; + + const auto offsets_0 = svindex_u32(0, static_cast(ksub)); + + // accumulators of partial sums + auto partialSum = svdup_n_f32(0.f); + + const auto lanes = svcntb(); + const auto quad_lanes = lanes / 4; + + // loop + for (std::size_t m = 0; m < M;) { + const auto pg = svwhilelt_b8_u64(m, M); + + const auto mm1 = svld1_u8(pg, code + m); + { + const auto mm1lo = svunpklo_u16(mm1); + const auto pglo = svunpklo_b(pg); + + { + // convert uint8 values to uint32 values + const auto idx1 = svunpklo_u32(mm1lo); + const auto pglolo = svunpklo_b(pglo); + + distance_codes_kernel(pglolo, idx1, offsets_0, tab, partialSum); + tab += ksub * quad_lanes; + } + + m += quad_lanes; + if (m >= M) + break; + + { + // convert uint8 values to uint32 values + const auto idx1 = svunpkhi_u32(mm1lo); + const auto pglohi = svunpkhi_b(pglo); + + distance_codes_kernel(pglohi, idx1, offsets_0, tab, partialSum); + tab += ksub * quad_lanes; + } + + m += quad_lanes; + if (m >= M) + break; + } + + { + const auto mm1hi = svunpkhi_u16(mm1); + const auto pghi = svunpkhi_b(pg); + + { + // convert uint8 values to uint32 values + const auto idx1 = svunpklo_u32(mm1hi); + const auto pghilo = svunpklo_b(pghi); + + distance_codes_kernel(pghilo, idx1, offsets_0, tab, partialSum); + tab += ksub * quad_lanes; + } + + m += quad_lanes; + if (m >= M) + break; + + { + // convert uint8 values to uint32 values + const auto idx1 = svunpkhi_u32(mm1hi); + const auto pghihi = svunpkhi_b(pghi); + + distance_codes_kernel(pghihi, idx1, offsets_0, tab, partialSum); + tab += ksub * quad_lanes; + } + 
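+            // at this point all four quarters of the loaded byte vector have
+            // been gathered and accumulated; account for the last quarter of
+            // sub-quantizers before the loop condition is re-checked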
+ m += quad_lanes; + } + } + + return svaddv_f32(svptrue_b32(), partialSum); +} + +template +std::enable_if_t, void> +distance_four_codes_sve( + // the product quantizer + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + distance_four_codes_generic( + M, + nbits, + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); +} + +static void distance_four_codes_sve_for_small_m( + // the product quantizer + const size_t M, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + constexpr size_t nbits = 8u; + + const size_t ksub = 1 << nbits; + + const auto offsets_0 = svindex_u32(0, static_cast(ksub)); + + const auto quad_lanes = svcntw(); + + // loop + const auto pg = svwhilelt_b32_u64(0, M); + + auto mm10 = svld1ub_u32(pg, code0); + auto mm11 = svld1ub_u32(pg, code1); + auto mm12 = svld1ub_u32(pg, code2); + auto mm13 = svld1ub_u32(pg, code3); + mm10 = svadd_u32_x(pg, mm10, offsets_0); + mm11 = svadd_u32_x(pg, mm11, offsets_0); + mm12 = svadd_u32_x(pg, mm12, offsets_0); + mm13 = svadd_u32_x(pg, mm13, offsets_0); + const auto collected0 = svld1_gather_u32index_f32(pg, sim_table, mm10); + const auto collected1 = svld1_gather_u32index_f32(pg, sim_table, mm11); + const auto collected2 = svld1_gather_u32index_f32(pg, sim_table, mm12); + const auto collected3 = svld1_gather_u32index_f32(pg, sim_table, mm13); + result0 = svaddv_f32(pg, collected0); + result1 = svaddv_f32(pg, collected1); + result2 = svaddv_f32(pg, collected2); + result3 = svaddv_f32(pg, collected3); +} + +// Combines 4 operations of distance_single_code() +template +std::enable_if_t, void> +distance_four_codes_sve( + // the product quantizer + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + if (M <= svcntw()) { + distance_four_codes_sve_for_small_m( + M, + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); + return; + } + + const float* tab = sim_table; + + const size_t ksub = 1 << nbits; + + const auto offsets_0 = svindex_u32(0, static_cast(ksub)); + + // accumulators of partial sums + auto partialSum0 = svdup_n_f32(0.f); + auto partialSum1 = svdup_n_f32(0.f); + auto partialSum2 = svdup_n_f32(0.f); + auto partialSum3 = svdup_n_f32(0.f); + + const auto lanes = svcntb(); + const auto quad_lanes = lanes / 4; + + // loop + for (std::size_t m = 0; m < M;) { + const auto pg = svwhilelt_b8_u64(m, M); + + const auto mm10 = svld1_u8(pg, code0 + m); + const auto mm11 = svld1_u8(pg, code1 + m); + const auto mm12 = svld1_u8(pg, code2 + m); + const auto mm13 = svld1_u8(pg, code3 + m); + { + const auto mm10lo = 
svunpklo_u16(mm10); + const auto mm11lo = svunpklo_u16(mm11); + const auto mm12lo = svunpklo_u16(mm12); + const auto mm13lo = svunpklo_u16(mm13); + const auto pglo = svunpklo_b(pg); + + { + const auto pglolo = svunpklo_b(pglo); + { + const auto idx1 = svunpklo_u32(mm10lo); + distance_codes_kernel( + pglolo, idx1, offsets_0, tab, partialSum0); + } + { + const auto idx1 = svunpklo_u32(mm11lo); + distance_codes_kernel( + pglolo, idx1, offsets_0, tab, partialSum1); + } + { + const auto idx1 = svunpklo_u32(mm12lo); + distance_codes_kernel( + pglolo, idx1, offsets_0, tab, partialSum2); + } + { + const auto idx1 = svunpklo_u32(mm13lo); + distance_codes_kernel( + pglolo, idx1, offsets_0, tab, partialSum3); + } + tab += ksub * quad_lanes; + } + + m += quad_lanes; + if (m >= M) + break; + + { + const auto pglohi = svunpkhi_b(pglo); + { + const auto idx1 = svunpkhi_u32(mm10lo); + distance_codes_kernel( + pglohi, idx1, offsets_0, tab, partialSum0); + } + { + const auto idx1 = svunpkhi_u32(mm11lo); + distance_codes_kernel( + pglohi, idx1, offsets_0, tab, partialSum1); + } + { + const auto idx1 = svunpkhi_u32(mm12lo); + distance_codes_kernel( + pglohi, idx1, offsets_0, tab, partialSum2); + } + { + const auto idx1 = svunpkhi_u32(mm13lo); + distance_codes_kernel( + pglohi, idx1, offsets_0, tab, partialSum3); + } + tab += ksub * quad_lanes; + } + + m += quad_lanes; + if (m >= M) + break; + } + + { + const auto mm10hi = svunpkhi_u16(mm10); + const auto mm11hi = svunpkhi_u16(mm11); + const auto mm12hi = svunpkhi_u16(mm12); + const auto mm13hi = svunpkhi_u16(mm13); + const auto pghi = svunpkhi_b(pg); + + { + const auto pghilo = svunpklo_b(pghi); + { + const auto idx1 = svunpklo_u32(mm10hi); + distance_codes_kernel( + pghilo, idx1, offsets_0, tab, partialSum0); + } + { + const auto idx1 = svunpklo_u32(mm11hi); + distance_codes_kernel( + pghilo, idx1, offsets_0, tab, partialSum1); + } + { + const auto idx1 = svunpklo_u32(mm12hi); + distance_codes_kernel( + pghilo, idx1, offsets_0, tab, partialSum2); + } + { + const auto idx1 = svunpklo_u32(mm13hi); + distance_codes_kernel( + pghilo, idx1, offsets_0, tab, partialSum3); + } + tab += ksub * quad_lanes; + } + + m += quad_lanes; + if (m >= M) + break; + + { + const auto pghihi = svunpkhi_b(pghi); + { + const auto idx1 = svunpkhi_u32(mm10hi); + distance_codes_kernel( + pghihi, idx1, offsets_0, tab, partialSum0); + } + { + const auto idx1 = svunpkhi_u32(mm11hi); + distance_codes_kernel( + pghihi, idx1, offsets_0, tab, partialSum1); + } + { + const auto idx1 = svunpkhi_u32(mm12hi); + distance_codes_kernel( + pghihi, idx1, offsets_0, tab, partialSum2); + } + { + const auto idx1 = svunpkhi_u32(mm13hi); + distance_codes_kernel( + pghihi, idx1, offsets_0, tab, partialSum3); + } + tab += ksub * quad_lanes; + } + + m += quad_lanes; + } + } + + result0 = svaddv_f32(svptrue_b32(), partialSum0); + result1 = svaddv_f32(svptrue_b32(), partialSum1); + result2 = svaddv_f32(svptrue_b32(), partialSum2); + result3 = svaddv_f32(svptrue_b32(), partialSum3); +} + +} // namespace faiss + +#endif diff --git a/faiss/impl/code_distance/code_distance.h b/faiss/impl/code_distance/code_distance.h index 7cdf932f50..155e19a6d8 100644 --- a/faiss/impl/code_distance/code_distance.h +++ b/faiss/impl/code_distance/code_distance.h @@ -77,6 +77,59 @@ inline void distance_four_codes( } // namespace faiss +#elif defined(__ARM_FEATURE_SVE) + +#include + +namespace faiss { + +template +inline float distance_single_code( + // the product quantizer + const size_t M, + // number of bits per quantization index 
+ const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // the code + const uint8_t* code) { + return distance_single_code_sve(M, nbits, sim_table, code); +} + +template +inline void distance_four_codes( + // the product quantizer + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + distance_four_codes_sve( + M, + nbits, + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); +} + +} // namespace faiss + #else #include diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 1506bee5cf..e698037aa1 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -18,6 +18,8 @@ #ifdef __AVX2__ #include +#elif defined(__ARM_FEATURE_SVE) +#include #endif #include @@ -557,6 +559,183 @@ void exhaustive_L2sqr_blas_cmax_avx2( InterruptCallback::check(); } } +#elif defined(__ARM_FEATURE_SVE) +void exhaustive_L2sqr_blas_cmax_sve( + const float* x, + const float* y, + size_t d, + size_t nx, + size_t ny, + Top1BlockResultHandler>& res, + const float* y_norms) { + // BLAS does not like empty matrices + if (nx == 0 || ny == 0) + return; + + /* block sizes */ + const size_t bs_x = distance_compute_blas_query_bs; + const size_t bs_y = distance_compute_blas_database_bs; + // const size_t bs_x = 16, bs_y = 16; + std::unique_ptr ip_block(new float[bs_x * bs_y]); + std::unique_ptr x_norms(new float[nx]); + std::unique_ptr del2; + + fvec_norms_L2sqr(x_norms.get(), x, d, nx); + + const size_t lanes = svcntw(); + + if (!y_norms) { + float* y_norms2 = new float[ny]; + del2.reset(y_norms2); + fvec_norms_L2sqr(y_norms2, y, d, ny); + y_norms = y_norms2; + } + + for (size_t i0 = 0; i0 < nx; i0 += bs_x) { + size_t i1 = i0 + bs_x; + if (i1 > nx) + i1 = nx; + + res.begin_multiple(i0, i1); + + for (size_t j0 = 0; j0 < ny; j0 += bs_y) { + size_t j1 = j0 + bs_y; + if (j1 > ny) + j1 = ny; + /* compute the actual dot products */ + { + float one = 1, zero = 0; + FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d; + sgemm_("Transpose", + "Not transpose", + &nyi, + &nxi, + &di, + &one, + y + j0 * d, + &di, + x + i0 * d, + &di, + &zero, + ip_block.get(), + &nyi); + } +#pragma omp parallel for + for (int64_t i = i0; i < i1; i++) { + const size_t count = j1 - j0; + float* ip_line = ip_block.get() + (i - i0) * count; + + svprfw(svwhilelt_b32_u64(0, count), ip_line, SV_PLDL1KEEP); + svprfw(svwhilelt_b32_u64(lanes, count), + ip_line + lanes, + SV_PLDL1KEEP); + + // Track lanes min distances + lanes min indices. + // All the distances tracked do not take x_norms[i] + // into account in order to get rid of extra + // vaddq_f32(x_norms[i], ...) instructions + // is distance computations. + auto min_distances = svdup_n_f32(res.dis_tab[i] - x_norms[i]); + + // these indices are local and are relative to j0. + // so, value 0 means j0. 
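+                // Note on the distances tracked below:
+                //   ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 * <x_i, y_j>
+                // so the loop accumulates only ||y_j||^2 - 2 * <x_i, y_j>;
+                // the constant ||x_i||^2 is added back after the scan over
+                // this block of j.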
+ auto min_indices = svdup_n_u32(0u); + + auto current_indices = svindex_u32(0u, 1u); + + // process lanes * 2 elements per loop + for (size_t idx_j = 0; idx_j < count; + idx_j += lanes * 2, ip_line += lanes * 2) { + svprfw(svwhilelt_b32_u64(idx_j + lanes * 2, count), + ip_line + lanes * 2, + SV_PLDL1KEEP); + svprfw(svwhilelt_b32_u64(idx_j + lanes * 3, count), + ip_line + lanes * 3, + SV_PLDL1KEEP); + + // mask + const auto mask_0 = svwhilelt_b32_u64(idx_j, count); + const auto mask_1 = svwhilelt_b32_u64(idx_j + lanes, count); + + // load values for norms + const auto y_norm_0 = + svld1_f32(mask_0, y_norms + idx_j + j0 + 0); + const auto y_norm_1 = + svld1_f32(mask_1, y_norms + idx_j + j0 + lanes); + + // load values for dot products + const auto ip_0 = svld1_f32(mask_0, ip_line + 0); + const auto ip_1 = svld1_f32(mask_1, ip_line + lanes); + + // compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]). + // x_norm[i] was dropped off because it is a constant for a + // given i. We'll deal with it later. + const auto distances_0 = + svmla_n_f32_z(mask_0, y_norm_0, ip_0, -2.f); + const auto distances_1 = + svmla_n_f32_z(mask_1, y_norm_1, ip_1, -2.f); + + // compare the new distances to the min distances + // for each of the first group of 4 ARM SIMD components. + auto comparison = + svcmpgt_f32(mask_0, min_distances, distances_0); + + // update min distances and indices with closest vectors if + // needed. + min_distances = + svsel_f32(comparison, distances_0, min_distances); + min_indices = + svsel_u32(comparison, current_indices, min_indices); + current_indices = svadd_n_u32_x( + mask_0, + current_indices, + static_cast(lanes)); + + // compare the new distances to the min distances + // for each of the second group of 4 ARM SIMD components. + comparison = + svcmpgt_f32(mask_1, min_distances, distances_1); + + // update min distances and indices with closest vectors if + // needed. + min_distances = + svsel_f32(comparison, distances_1, min_distances); + min_indices = + svsel_u32(comparison, current_indices, min_indices); + current_indices = svadd_n_u32_x( + mask_1, + current_indices, + static_cast(lanes)); + } + + // add missing x_norms[i] + // negative values can occur for identical vectors + // due to roundoff errors. + auto mask = svwhilelt_b32_u64(0, count); + min_distances = svadd_n_f32_z( + svcmpge_n_f32(mask, min_distances, -x_norms[i]), + min_distances, + x_norms[i]); + min_indices = svadd_n_u32_x( + mask, min_indices, static_cast(j0)); + mask = svcmple_n_f32(mask, min_distances, res.dis_tab[i]); + if (svcntp_b32(svptrue_b32(), mask) == 0) + res.add_result(i, res.dis_tab[i], res.ids_tab[i]); + else { + const auto min_distance = svminv_f32(mask, min_distances); + const auto min_index = svminv_u32( + svcmpeq_n_f32(mask, min_distances, min_distance), + min_indices); + res.add_result(i, min_distance, min_index); + } + } + } + // Does nothing for SingleBestResultHandler, but + // keeping the call for the consistency. + res.end_multiple(); + InterruptCallback::check(); + } +} #endif // an override if only a single closest point is needed @@ -579,6 +758,16 @@ void exhaustive_L2sqr_blas>>( // run the specialized AVX2 implementation exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms); +#elif defined(__ARM_FEATURE_SVE) + // use a faster fused kernel if available + if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) { + // the kernel is available and it is complete, we're done. 
+ return; + } + + // run the specialized SVE implementation + exhaustive_L2sqr_blas_cmax_sve(x, y, d, nx, ny, res, y_norms); + #elif defined(__aarch64__) // use a faster fused kernel if available if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) { diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index 7cebd2ae33..7cabfc0a25 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -29,6 +29,10 @@ #include #endif +#ifdef __ARM_FEATURE_SVE +#include +#endif + #ifdef __aarch64__ #include #endif @@ -2673,6 +2677,441 @@ float fvec_Linf(const float* x, const float* y, size_t d) { return fvec_Linf_ref(x, y, d); } +#elif defined(__ARM_FEATURE_SVE) + +struct ElementOpIP { + static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) { + return svmul_f32_x(pg, x, y); + } + static svfloat32_t merge( + svbool_t pg, + svfloat32_t z, + svfloat32_t x, + svfloat32_t y) { + return svmla_f32_x(pg, z, x, y); + } +}; + +template +void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + size_t i = 0; + for (; i + lanes4 < ny; i += lanes4) { + svfloat32_t y0 = svld1_f32(pg, y); + svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y0 = ElementOp::op(pg, x0, y0); + y1 = ElementOp::op(pg, x0, y1); + y2 = ElementOp::op(pg, x0, y2); + y3 = ElementOp::op(pg, x0, y3); + svst1_f32(pg, dis, y0); + svst1_f32(pg, dis + lanes, y1); + svst1_f32(pg, dis + lanes2, y2); + svst1_f32(pg, dis + lanes3, y3); + y += lanes4; + dis += lanes4; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); + const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny); + const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny); + svfloat32_t y0 = svld1_f32(pg0, y); + svfloat32_t y1 = svld1_f32(pg1, y + lanes); + svfloat32_t y2 = svld1_f32(pg2, y + lanes2); + svfloat32_t y3 = svld1_f32(pg3, y + lanes3); + y0 = ElementOp::op(pg0, x0, y0); + y1 = ElementOp::op(pg1, x0, y1); + y2 = ElementOp::op(pg2, x0, y2); + y3 = ElementOp::op(pg3, x0, y3); + svst1_f32(pg0, dis, y0); + svst1_f32(pg1, dis + lanes, y1); + svst1_f32(pg2, dis + lanes2, y2); + svst1_f32(pg3, dis + lanes3, y3); +} + +template +void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + size_t i = 0; + for (; i + lanes2 < ny; i += lanes2) { + const svfloat32x2_t y0 = svld2_f32(pg, y); + const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2); + svfloat32_t y00 = svget2_f32(y0, 0); + const svfloat32_t y01 = svget2_f32(y0, 1); + svfloat32_t y10 = svget2_f32(y1, 0); + const svfloat32_t y11 = svget2_f32(y1, 1); + y00 = ElementOp::op(pg, x0, y00); + y10 = ElementOp::op(pg, x0, y10); + y00 = ElementOp::merge(pg, y00, x1, y01); + y10 = ElementOp::merge(pg, y10, x1, y11); + svst1_f32(pg, dis, y00); + svst1_f32(pg, dis + lanes, y10); + y += lanes4; + dis += lanes2; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); + const svfloat32x2_t y0 = 
svld2_f32(pg0, y); + const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2); + svfloat32_t y00 = svget2_f32(y0, 0); + const svfloat32_t y01 = svget2_f32(y0, 1); + svfloat32_t y10 = svget2_f32(y1, 0); + const svfloat32_t y11 = svget2_f32(y1, 1); + y00 = ElementOp::op(pg0, x0, y00); + y10 = ElementOp::op(pg1, x0, y10); + y00 = ElementOp::merge(pg0, y00, x1, y01); + y10 = ElementOp::merge(pg1, y10, x1, y11); + svst1_f32(pg0, dis, y00); + svst1_f32(pg1, dis + lanes, y10); +} + +template +void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + const svfloat32_t x2 = svdup_n_f32(x[2]); + const svfloat32_t x3 = svdup_n_f32(x[3]); + size_t i = 0; + for (; i + lanes < ny; i += lanes) { + const svfloat32x4_t y0 = svld4_f32(pg, y); + svfloat32_t y00 = svget4_f32(y0, 0); + const svfloat32_t y01 = svget4_f32(y0, 1); + svfloat32_t y02 = svget4_f32(y0, 2); + const svfloat32_t y03 = svget4_f32(y0, 3); + y00 = ElementOp::op(pg, x0, y00); + y02 = ElementOp::op(pg, x2, y02); + y00 = ElementOp::merge(pg, y00, x1, y01); + y02 = ElementOp::merge(pg, y02, x3, y03); + y00 = svadd_f32_x(pg, y00, y02); + svst1_f32(pg, dis, y00); + y += lanes4; + dis += lanes; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svfloat32x4_t y0 = svld4_f32(pg0, y); + svfloat32_t y00 = svget4_f32(y0, 0); + const svfloat32_t y01 = svget4_f32(y0, 1); + svfloat32_t y02 = svget4_f32(y0, 2); + const svfloat32_t y03 = svget4_f32(y0, 3); + y00 = ElementOp::op(pg0, x0, y00); + y02 = ElementOp::op(pg0, x2, y02); + y00 = ElementOp::merge(pg0, y00, x1, y01); + y02 = ElementOp::merge(pg0, y02, x3, y03); + y00 = svadd_f32_x(pg0, y00, y02); + svst1_f32(pg0, dis, y00); +} + +template +void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes4 = lanes * 4; + const size_t lanes8 = lanes * 8; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + const svfloat32_t x2 = svdup_n_f32(x[2]); + const svfloat32_t x3 = svdup_n_f32(x[3]); + const svfloat32_t x4 = svdup_n_f32(x[4]); + const svfloat32_t x5 = svdup_n_f32(x[5]); + const svfloat32_t x6 = svdup_n_f32(x[6]); + const svfloat32_t x7 = svdup_n_f32(x[7]); + size_t i = 0; + for (; i + lanes < ny; i += lanes) { + const svfloat32x4_t ya = svld4_f32(pg, y); + const svfloat32x4_t yb = svld4_f32(pg, y + lanes4); + const svfloat32_t ya0 = svget4_f32(ya, 0); + const svfloat32_t ya1 = svget4_f32(ya, 1); + const svfloat32_t ya2 = svget4_f32(ya, 2); + const svfloat32_t ya3 = svget4_f32(ya, 3); + const svfloat32_t yb0 = svget4_f32(yb, 0); + const svfloat32_t yb1 = svget4_f32(yb, 1); + const svfloat32_t yb2 = svget4_f32(yb, 2); + const svfloat32_t yb3 = svget4_f32(yb, 3); + svfloat32_t y0 = svuzp1(ya0, yb0); + const svfloat32_t y1 = svuzp1(ya1, yb1); + svfloat32_t y2 = svuzp1(ya2, yb2); + const svfloat32_t y3 = svuzp1(ya3, yb3); + svfloat32_t y4 = svuzp2(ya0, yb0); + const svfloat32_t y5 = svuzp2(ya1, yb1); + svfloat32_t y6 = svuzp2(ya2, yb2); + const svfloat32_t y7 = svuzp2(ya3, yb3); + y0 = ElementOp::op(pg, x0, y0); + y2 = ElementOp::op(pg, x2, y2); + y4 = ElementOp::op(pg, x4, y4); + y6 = ElementOp::op(pg, x6, y6); + y0 = ElementOp::merge(pg, y0, x1, y1); + y2 = ElementOp::merge(pg, y2, x3, y3); + y4 = ElementOp::merge(pg, y4, x5, y5); + y6 = 
ElementOp::merge(pg, y6, x7, y7); + y0 = svadd_f32_x(pg, y0, y2); + y4 = svadd_f32_x(pg, y4, y6); + y0 = svadd_f32_x(pg, y0, y4); + svst1_f32(pg, dis, y0); + y += lanes8; + dis += lanes; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2); + const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2); + const svfloat32x4_t ya = svld4_f32(pga, y); + const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4); + const svfloat32_t ya0 = svget4_f32(ya, 0); + const svfloat32_t ya1 = svget4_f32(ya, 1); + const svfloat32_t ya2 = svget4_f32(ya, 2); + const svfloat32_t ya3 = svget4_f32(ya, 3); + const svfloat32_t yb0 = svget4_f32(yb, 0); + const svfloat32_t yb1 = svget4_f32(yb, 1); + const svfloat32_t yb2 = svget4_f32(yb, 2); + const svfloat32_t yb3 = svget4_f32(yb, 3); + svfloat32_t y0 = svuzp1(ya0, yb0); + const svfloat32_t y1 = svuzp1(ya1, yb1); + svfloat32_t y2 = svuzp1(ya2, yb2); + const svfloat32_t y3 = svuzp1(ya3, yb3); + svfloat32_t y4 = svuzp2(ya0, yb0); + const svfloat32_t y5 = svuzp2(ya1, yb1); + svfloat32_t y6 = svuzp2(ya2, yb2); + const svfloat32_t y7 = svuzp2(ya3, yb3); + y0 = ElementOp::op(pg0, x0, y0); + y2 = ElementOp::op(pg0, x2, y2); + y4 = ElementOp::op(pg0, x4, y4); + y6 = ElementOp::op(pg0, x6, y6); + y0 = ElementOp::merge(pg0, y0, x1, y1); + y2 = ElementOp::merge(pg0, y2, x3, y3); + y4 = ElementOp::merge(pg0, y4, x5, y5); + y6 = ElementOp::merge(pg0, y6, x7, y7); + y0 = svadd_f32_x(pg0, y0, y2); + y4 = svadd_f32_x(pg0, y4, y6); + y0 = svadd_f32_x(pg0, y0, y4); + svst1_f32(pg0, dis, y0); + y += lanes8; + dis += lanes; +} + +template +void fvec_op_ny_sve_lanes1( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + size_t i = 0; + for (; i + 3 < ny; i += 4) { + svfloat32_t y0 = svld1_f32(pg, y); + svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y += lanes4; + y0 = ElementOp::op(pg, x0, y0); + y1 = ElementOp::op(pg, x0, y1); + y2 = ElementOp::op(pg, x0, y2); + y3 = ElementOp::op(pg, x0, y3); + dis[i] = svaddv_f32(pg, y0); + dis[i + 1] = svaddv_f32(pg, y1); + dis[i + 2] = svaddv_f32(pg, y2); + dis[i + 3] = svaddv_f32(pg, y3); + } + for (; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + y += lanes; + y0 = ElementOp::op(pg, x0, y0); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + size_t i = 0; + for (; i + 1 < ny; i += 2) { + svfloat32_t y00 = svld1_f32(pg, y); + const svfloat32_t y01 = svld1_f32(pg, y + lanes); + svfloat32_t y10 = svld1_f32(pg, y + lanes2); + const svfloat32_t y11 = svld1_f32(pg, y + lanes3); + y += lanes4; + y00 = ElementOp::op(pg, x0, y00); + y10 = ElementOp::op(pg, x0, y10); + y00 = ElementOp::merge(pg, y00, x1, y01); + y10 = ElementOp::merge(pg, y10, x1, y11); + dis[i] = svaddv_f32(pg, y00); + dis[i + 1] = svaddv_f32(pg, y10); + } + if (i < ny) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + 
lanes); + y0 = ElementOp::op(pg, x0, y0); + y0 = ElementOp::merge(pg, y0, x1, y1); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes3( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + const svfloat32_t x2 = svld1_f32(pg, x + lanes2); + for (size_t i = 0; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + y += lanes3; + y0 = ElementOp::op(pg, x0, y0); + y0 = ElementOp::merge(pg, y0, x1, y1); + y0 = ElementOp::merge(pg, y0, x2, y2); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + const svfloat32_t x2 = svld1_f32(pg, x + lanes2); + const svfloat32_t x3 = svld1_f32(pg, x + lanes3); + for (size_t i = 0; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + const svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y += lanes4; + y0 = ElementOp::op(pg, x0, y0); + y2 = ElementOp::op(pg, x2, y2); + y0 = ElementOp::merge(pg, y0, x1, y1); + y2 = ElementOp::merge(pg, y2, x3, y3); + y0 = svadd_f32_x(pg, y0, y2); + dis[i] = svaddv_f32(pg, y0); + } +} + +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_ref(dis, x, y, d, ny); +} + +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); +} + +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); +} + +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return fvec_L2sqr_ny_nearest_y_transposed_ref( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +float fvec_L1(const float* x, const float* y, size_t d) { + return fvec_L1_ref(x, y, d); +} + +float fvec_Linf(const float* x, const float* y, size_t d) { + return fvec_Linf_ref(x, y, d); +} + +void fvec_inner_products_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + const size_t lanes = svcntw(); + switch (d) { + case 1: + fvec_op_ny_sve_d1(dis, x, y, ny); + break; + case 2: + fvec_op_ny_sve_d2(dis, x, y, ny); + break; + case 4: + fvec_op_ny_sve_d4(dis, x, y, ny); + break; + case 8: + fvec_op_ny_sve_d8(dis, x, y, ny); + break; + default: + if (d == lanes) + fvec_op_ny_sve_lanes1(dis, x, y, ny); + else if (d == lanes * 2) + fvec_op_ny_sve_lanes2(dis, x, y, ny); + else if (d == lanes * 3) + fvec_op_ny_sve_lanes3(dis, x, y, ny); + else if (d == lanes * 4) + fvec_op_ny_sve_lanes4(dis, x, y, ny); + else + fvec_inner_products_ny_ref(dis, x, y, d, ny); 
+ break; + } +} + #elif defined(__aarch64__) // not optimized for ARM @@ -2934,6 +3373,60 @@ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { #endif } +#elif defined(__ARM_FEATURE_SVE) + +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t lanes = static_cast(svcntw()); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + size_t i = 0; + for (; i + lanes4 < n; i += lanes4) { + const auto mask = svptrue_b32(); + const auto ai0 = svld1_f32(mask, a + i); + const auto ai1 = svld1_f32(mask, a + i + lanes); + const auto ai2 = svld1_f32(mask, a + i + lanes2); + const auto ai3 = svld1_f32(mask, a + i + lanes3); + const auto bi0 = svld1_f32(mask, b + i); + const auto bi1 = svld1_f32(mask, b + i + lanes); + const auto bi2 = svld1_f32(mask, b + i + lanes2); + const auto bi3 = svld1_f32(mask, b + i + lanes3); + const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf); + const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf); + const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf); + const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf); + svst1_f32(mask, c + i, ci0); + svst1_f32(mask, c + i + lanes, ci1); + svst1_f32(mask, c + i + lanes2, ci2); + svst1_f32(mask, c + i + lanes3, ci3); + } + const auto mask0 = svwhilelt_b32_u64(i, n); + const auto mask1 = svwhilelt_b32_u64(i + lanes, n); + const auto mask2 = svwhilelt_b32_u64(i + lanes2, n); + const auto mask3 = svwhilelt_b32_u64(i + lanes3, n); + const auto ai0 = svld1_f32(mask0, a + i); + const auto ai1 = svld1_f32(mask1, a + i + lanes); + const auto ai2 = svld1_f32(mask2, a + i + lanes2); + const auto ai3 = svld1_f32(mask3, a + i + lanes3); + const auto bi0 = svld1_f32(mask0, b + i); + const auto bi1 = svld1_f32(mask1, b + i + lanes); + const auto bi2 = svld1_f32(mask2, b + i + lanes2); + const auto bi3 = svld1_f32(mask3, b + i + lanes3); + const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf); + const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf); + const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf); + const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf); + svst1_f32(mask0, c + i, ci0); + svst1_f32(mask1, c + i + lanes, ci1); + svst1_f32(mask2, c + i + lanes2, ci2); + svst1_f32(mask3, c + i + lanes3, ci3); +} + #elif defined(__aarch64__) void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {