@@ -1000,11 +1000,12 @@ bvt bv_utilst::comba_column_wise(const std::vector<bvt> &pps)
10001000// #define RADIX_MULTIPLIER 8
10011001// #define USE_KARATSUBA
10021002// #define USE_TOOM_COOK
1003- // #define USE_SCHOENHAGE_STRASSEN
1003+ #define USE_SCHOENHAGE_STRASSEN
10041004#ifdef RADIX_MULTIPLIER
1005+ // # define COMBA
10051006# define DADDA_TREE
10061007#endif
1007- #define COMBA
1008+ // #define COMBA
10081009
10091010#ifdef RADIX_MULTIPLIER
10101011static bvt unsigned_multiply_by_3 (propt &prop, const bvt &op)
@@ -2183,12 +2184,29 @@ bvt bv_utilst::unsigned_toom_cook_multiplier(const bvt &_op0, const bvt &_op1)
21832184}
21842185
21852186bvt bv_utilst::unsigned_schoenhage_strassen_multiplier (
2186- const bvt &a ,
2187- const bvt &b )
2187+ const bvt &_a ,
2188+ const bvt &_b )
21882189{
2190+ // http://malte-leip.net/beschreibung_ssa.pdf,
2191+ // https://de.wikipedia.org/wiki/Sch%C3%B6nhage-Strassen-Algorithmus
2192+ // Isabelle proof: https://mediatum.ub.tum.de/doc/1717658/1717658.pdf
2193+ bvt a = _a;
2194+ #if 1
2195+ // bvt b = a;
2196+ bvt b = _b;
2197+ #else
2198+ bvt b = _b;
2199+ a.resize(14);
2200+ b.resize(14);
2201+ #endif
2202+
21892203 PRECONDITION (a.size () == b.size ());
2204+ const std::size_t op_size = a.size ();
2205+ // delta.size() <= result_size doesn't hold for op_size <= 3
2206+ if (op_size <= 3 )
2207+ return unsigned_multiplier (a, b);
21902208
2191- // Running examples : we want to multiple 213 by 15 as 8- or 9-bit integers.
2209+ // Running example : we want to multiply 213 by 15 as 8- or 9-bit integers.
21922210 // That is, we seek to multiply 11010101 (011010101) by 00001111 (000001111).
21932211 // ^bit 7 ^bit 0
21942212 // The expected result is 123 as both an 8-bit and 9-bit result (001111011).
@@ -2202,7 +2220,7 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22022220 // m >= log_2(op_size) + 1.
22032221 // For our examples m will be 4 and 5, respectively, with Fermat numbers
22042222 // 2^16 + 1 and 2^32 + 1.
2205- const std::size_t m = address_bits (a. size () ) + 1 ;
2223+ const std::size_t m = address_bits (op_size ) + 1 ;
22062224 std::cerr << " m: " << m << std::endl;
22072225
22082226 // Extend bit width to 2^(m + 1) = op_size (rounded to next power of 2) * 4
@@ -2243,8 +2261,11 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22432261 {
22442262 a_rho.emplace_back (
22452263 a_ext.begin () + i * chunk_size, a_ext.begin () + (i + 1 ) * chunk_size);
2264+ std::cerr << " a_rho[" << i << " ]: " << beautify (a_rho.back ()) << std::endl;
22462265 b_sigma.emplace_back (
22472266 b_ext.begin () + i * chunk_size, b_ext.begin () + (i + 1 ) * chunk_size);
2267+ std::cerr << " b_sigma[" << i << " ]: " << beautify (b_sigma.back ())
2268+ << std::endl;
22482269 }
22492270 // For our example we now have
22502271 // a_rho = [ 0101, 1101, 0000, ..., 0000 ]
@@ -2266,11 +2287,16 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22662287 ++rho)
22672288 {
22682289 const std::size_t sigma = tau - rho;
2269- gamma_tau[tau] = add (
2270- gamma_tau[tau],
2290+ std::cerr << " Inner multiplication a_ " << rho << " * b_ " << sigma;
2291+ auto inner_product = zero_extension (
22712292 unsigned_multiplier (
2272- zero_extension (a_rho[rho], 3 * n + 5 ),
2273- zero_extension (b_sigma[sigma], 3 * n + 5 )));
2293+ zero_extension (a_rho[rho], chunk_size * 2 ),
2294+ zero_extension (b_sigma[sigma], chunk_size * 2 )),
2295+ 3 * n + 5 );
2296+ std::cerr << " = " << beautify (inner_product) << std::endl;
2297+ gamma_tau[tau] = add (gamma_tau[tau], inner_product);
2298+ std::cerr << " gamma_tau[" << tau << " ] = " << beautify (gamma_tau[tau])
2299+ << std::endl;
22742300 }
22752301 }
22762302 // For our example we obtain
@@ -2282,6 +2308,7 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22822308 c_tau.reserve (num_chunks);
22832309 for (std::size_t tau = 0 ; tau < num_chunks; ++tau)
22842310 {
2311+ // std::cerr << "gamma_tau[" << tau << "]: " << beautify(gamma_tau[tau]) << std::endl;
22852312 c_tau.push_back (add (gamma_tau[tau], gamma_tau[tau + num_chunks]));
22862313 CHECK_RETURN (c_tau.back ().size () >= address_bits (num_chunks) + 1 );
22872314 c_tau.back ().resize (address_bits (num_chunks) + 1 );
@@ -2295,7 +2322,11 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22952322 std::vector<bvt> z_j;
22962323 z_j.reserve (num_chunks / 2 );
22972324 for (std::size_t j = 0 ; j < num_chunks / 2 ; ++j)
2325+ {
22982326 z_j.push_back (sub (c_tau[j], c_tau[j + num_chunks / 2 ]));
2327+ // z_j.back().resize(address_bits(num_chunks) + 1);
2328+ std::cerr << " z_" << j << " = " << beautify (z_j.back ()) << std::endl;
2329+ }
22992330 // For our example we have z_j = c_tau as all elements beyond the second one
23002331 // are zeros.
23012332
@@ -2325,14 +2356,14 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
23252356 // inverse NTT.
23262357
23272358 // Addition mod F_n with overflow
2328- auto cyclic_add = [this ](const bvt &x, const bvt &y)
2329- {
2359+ auto cyclic_add = [this ](const bvt &x, const bvt &y) {
23302360 PRECONDITION (x.size () == y.size ());
23312361
23322362 auto result_with_overflow = adder (x, y, const_literal (false ));
23332363 if (result_with_overflow.second .is_false ())
23342364 return result_with_overflow.first ;
23352365
2366+ std::cerr << " OVERFLOW" << std::endl;
23362367 return add (
23372368 result_with_overflow.first ,
23382369 zero_extension (bvt{1 , result_with_overflow.second }, x.size ()));
@@ -2365,14 +2396,16 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
23652396 j <<= 1 ; // the initial shift has no effect
23662397 j |= (k & (1 << nu)) >> nu;
23672398 }
2399+ std::cerr << " k=" << k << " yields j=" << j << std::endl;
23682400 Aa.push_back (a_j[j]);
23692401 Ab.push_back (b_j[j]);
2402+ std::cerr << " Aa[0](" << k << " ): " << beautify (Aa.back ()) << std::endl;
23702403 }
2371- for (std::size_t nu = 1 ; nu <= address_bits (num_chunks); ++nu)
2404+ for (std::size_t nu = 0 ; nu < address_bits (num_chunks); ++nu)
23722405 {
2373- const std::size_t bit_nu = (std::size_t )1 << (nu - 1 ) ;
2406+ const std::size_t bit_nu = (std::size_t )1 << nu ;
23742407 std::size_t bits_up_to_nu = 0 ;
2375- for (std::size_t i = 0 ; i < nu - 1 ; ++i)
2408+ for (std::size_t i = 0 ; i < nu; ++i)
23762409 bits_up_to_nu |= 1 << i;
23772410
23782411 // we only need odd ones
@@ -2384,21 +2417,26 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
23842417 bvt Aa_nu_bit_is_zero = Aa[k & ~bit_nu];
23852418 bvt Ab_nu_bit_is_zero = Ab[k & ~bit_nu];
23862419
2420+ std::cerr << " Round " << (nu + 1 ) << " , k=" << k
2421+ << " , k & ~bit_nu=" << (k & ~bit_nu) << std::endl;
23872422 const std::size_t chi = (k & bits_up_to_nu)
2388- << (address_bits (num_chunks) - 1 - (nu - 1 ));
2423+ << (address_bits (num_chunks) - 1 - nu);
2424+ std::cerr << " k & bits_up_to_nu=" << (k & bits_up_to_nu)
2425+ << " , chi=" << chi << std::endl;
23892426 const std::size_t omega = m % 2 == 1 ? 2 : 4 ;
23902427 const std::size_t shift_dist = chi * omega / 2 ;
23912428
2392- if (nu > 1 ) // no need to update even indices
2429+ if (nu > 0 ) // no need to update even indices
23932430 {
23942431 Aa[k & ~bit_nu] = cyclic_add (
23952432 Aa_nu_bit_is_zero, shift (Aa[k], shiftt::ROTATE_LEFT, shift_dist));
23962433 Ab[k & ~bit_nu] = cyclic_add (
23972434 Ab_nu_bit_is_zero, shift (Ab[k], shiftt::ROTATE_LEFT, shift_dist));
2398- std::cerr << " Aa[" << nu << " ](" << (k & ~bit_nu)
2435+ std::cerr << " shift_dist: " << shift_dist << std::endl;
2436+ std::cerr << " Aa[" << (nu + 1 ) << " ](" << (k & ~bit_nu)
23992437 << " ): " << beautify (Aa[k & ~bit_nu]) << std::endl;
24002438#if 0
2401- std::cerr << "Ab[" << nu << "](" << (k & ~bit_nu)
2439+ std::cerr << "Ab[" << (nu + 1) << "](" << (k & ~bit_nu)
24022440 << "): " << beautify(Ab[k & ~bit_nu]) << std::endl;
24032441#endif
24042442 }
@@ -2412,19 +2450,20 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
24122450 Ab[k] = cyclic_add (
24132451 Ab_nu_bit_is_zero,
24142452 shift (Ab[k], shiftt::ROTATE_LEFT, shift_dist_for_sub));
2415- std::cerr << " Aa[" << nu << " ](" << k << " ): " << beautify (Aa[k])
2453+ std::cerr << " shift_dist_for_sub: " << shift_dist_for_sub << std::endl;
2454+ std::cerr << " Aa[" << (nu + 1 ) << " ](" << k << " ): " << beautify (Aa[k])
24162455 << std::endl;
24172456#if 0
2418- std::cerr << "Ab[" << nu << "](" << k << "): " << beautify(Ab[k])
2457+ std::cerr << "Ab[" << (nu + 1) << "](" << k << "): " << beautify(Ab[k])
24192458 << std::endl;
24202459#endif
24212460 }
24222461 }
24232462
24242463 // Compute u - v if u > v, else u - v + F_n, where F_n = 2^(2^n) + 1
2425- auto reduce_to_mod_F_n = [this ](const bvt &x)
2426- {
2464+ auto reduce_to_mod_F_n = [this ](const bvt &x, std::size_t n) {
24272465 const std::size_t two_to_power_of_n = x.size () / 2 ;
2466+ PRECONDITION (two_to_power_of_n == (std::size_t )1 << n);
24282467 // std::cerr << "two_to_power_of_n: " << two_to_power_of_n << std::endl;
24292468 const bvt u =
24302469 zero_extension (bvt{x.begin (), x.begin () + two_to_power_of_n}, x.size ());
@@ -2445,54 +2484,64 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
24452484
24462485 std::vector<bvt> a_hat_k{num_chunks, bvt{}}, b_hat_k{num_chunks, bvt{}};
24472486 // Reduce by F_n
2448- for (std::size_t j = 1 ; j < num_chunks; j += 2 )
2487+ for (std::size_t k = 1 ; k < num_chunks; k += 2 )
24492488 {
2450- a_hat_k[j ] = reduce_to_mod_F_n (Aa[j] );
2451- std::cerr << " a_hat_k[" << j << " ]: " << beautify (a_hat_k[j ]) << std::endl;
2452- b_hat_k[j ] = reduce_to_mod_F_n (Ab[j] );
2453- std::cerr << " b_hat_k[" << j << " ]: " << beautify (b_hat_k[j ]) << std::endl;
2489+ a_hat_k[k ] = reduce_to_mod_F_n (Aa[k], n );
2490+ std::cerr << " a_hat_k[" << k << " ]: " << beautify (a_hat_k[k ]) << std::endl;
2491+ b_hat_k[k ] = reduce_to_mod_F_n (Ab[k], n );
2492+ std::cerr << " b_hat_k[" << k << " ]: " << beautify (b_hat_k[k ]) << std::endl;
24542493 }
24552494
24562495 // Compute point-wise multiplication
24572496 std::vector<bvt> c_hat_k{num_chunks, bvt{}};
2458- for (std::size_t j = 1 ; j < num_chunks; j += 2 )
2459- {
2460- c_hat_k[j] = unsigned_multiplier (a_hat_k[j], b_hat_k[j]);
2461- std::cerr << " c_hat_k[" << j << " ]: " << beautify (c_hat_k[j]) << std::endl;
2497+ for (std::size_t k = 1 ; k < num_chunks; k += 2 )
2498+ {
2499+ // If at least one of a_hat_k[k] or b_hat_k[k] is 2^2^n (i.e., F_n - 1) then
2500+ // multiplication would overflow, so handle those cases separately:
2501+ // x * 2^2^n = x * -1 (mod F_n) which can be computed by rotating x by 2^n
2502+ const std::size_t power_2_n = (std::size_t )1 << n;
2503+ c_hat_k[k] = select (
2504+ a_hat_k[k][power_2_n],
2505+ shift (b_hat_k[k], shiftt::ROTATE_LEFT, power_2_n),
2506+ select (
2507+ b_hat_k[k][power_2_n],
2508+ shift (a_hat_k[k], shiftt::ROTATE_LEFT, power_2_n),
2509+ unsigned_multiplier (a_hat_k[k], b_hat_k[k])));
2510+ std::cerr << " c_hat_k[" << k << " ]: " << beautify (c_hat_k[k]) << std::endl;
24622511 }
24632512
24642513 // Apply inverse NTT
24652514 for (std::size_t nu = address_bits (num_chunks) - 1 ; nu > 0 ; --nu)
24662515 {
2467- const std::size_t bit_nu_plus_1 = (std::size_t )1 << nu;
2468- std::size_t bits_up_to_nu_plus_1 = 0 ;
2516+ const std::size_t bit_nu = (std::size_t )1 << nu;
2517+ std::size_t bits_up_to_nu = 0 ;
24692518 for (std::size_t i = 0 ; i < nu; ++i)
2470- bits_up_to_nu_plus_1 |= 1 << i;
2519+ bits_up_to_nu |= 1 << i;
24712520
24722521 // we only need odd ones
24732522 for (std::size_t k = 1 ; k < num_chunks; k += 2 )
24742523 {
2475- if ((k & bit_nu_plus_1 ) == 0 )
2524+ if ((k & bit_nu ) == 0 )
24762525 continue ;
24772526
2478- bvt c_hat_k_nu_plus_1_bit_is_zero = c_hat_k[k & ~bit_nu_plus_1 ];
2527+ bvt c_hat_k_nu_bit_is_zero = c_hat_k[k & ~bit_nu ];
24792528
2480- c_hat_k[k & ~bit_nu_plus_1 ] = shift (
2481- cyclic_add (c_hat_k_nu_plus_1_bit_is_zero , c_hat_k[k]),
2529+ c_hat_k[k & ~bit_nu ] = shift (
2530+ cyclic_add (c_hat_k_nu_bit_is_zero , c_hat_k[k]),
24822531 shiftt::ROTATE_RIGHT,
24832532 1 );
2484- std::cerr << " c_hat_k[" << nu << " ](" << (k & ~bit_nu_plus_1 )
2485- << " ): " << beautify (c_hat_k[k & ~bit_nu_plus_1 ]) << std::endl;
2533+ std::cerr << " c_hat_k[" << nu << " ](" << (k & ~bit_nu )
2534+ << " ): " << beautify (c_hat_k[k & ~bit_nu ]) << std::endl;
24862535
2487- const std::size_t chi = (k & bits_up_to_nu_plus_1 )
2536+ const std::size_t chi = (k & bits_up_to_nu )
24882537 << (address_bits (num_chunks) - 1 - nu);
24892538 const std::size_t omega = m % 2 == 1 ? 2 : 4 ;
24902539 const std::size_t shift_dist = chi * omega / 2 + 1 ;
24912540 std::cerr << " SHIFT: " << shift_dist << std::endl;
24922541
24932542 c_hat_k[k] = shift (
24942543 cyclic_add (
2495- c_hat_k_nu_plus_1_bit_is_zero ,
2544+ c_hat_k_nu_bit_is_zero ,
24962545 shift (c_hat_k[k], shiftt::ROTATE_LEFT, (std::size_t )1 << n)),
24972546 shiftt::ROTATE_RIGHT,
24982547 shift_dist);
@@ -2514,45 +2563,63 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
25142563 }
25152564 k |= 1 ;
25162565 std::cerr << " j " << j << " maps to " << k << std::endl;
2517- z_j_mod_F_n.push_back (reduce_to_mod_F_n (c_hat_k[k]));
2566+ // TODO: we could add a capability to reduce_to_mod_F_n to restrict the
2567+ // result to op_size bits so that this call (but not the ones above) could
2568+ // use it.
2569+ z_j_mod_F_n.push_back (reduce_to_mod_F_n (c_hat_k[k], n));
25182570 std::cerr << " z_j_mod_F_n[" << j << " ]: " << beautify (z_j_mod_F_n[j])
25192571 << std::endl;
25202572 }
25212573
2522- // Compute final coefficients as eta + delta * F_n where delta = eta - xi for
2523- // eta z_j and xi c_hat_k.
2574+ // Compute final coefficients as xi + delta * F_n where delta = eta - xi (mod
2575+ // 2^(n + 2) for odd m and 2^(n + 1) for even m) for eta being z_j and xi
2576+ // being z_j_mod_F_n.
2577+ #if 0
2578+ // To compute the full-width result
2579+ const std::size_t result_size = two_to_m_plus_1;
2580+ #else
2581+ const std::size_t result_size = op_size;
2582+ #endif
25242583 for (std::size_t j = 0 ; j < num_chunks / 2 ; ++j)
25252584 {
2526- bvt eta = z_j_mod_F_n [j];
2585+ const bvt & eta = z_j [j];
25272586 std::cerr << " eta[" << j << " ]: " << beautify (eta) << std::endl;
2528- bvt xi = z_j [j];
2587+ const bvt & xi = z_j_mod_F_n [j];
25292588 std::cerr << " xi[" << j << " ]: " << beautify (xi) << std::endl;
2530- // TODO: couldn't we do this over just xi .size() bits instead?
2531- bvt delta = sub (eta, zero_extension (xi, eta.size ())) ;
2532- CHECK_RETURN ( delta. size () >= xi. size () );
2533- delta.resize (xi. size ());
2589+ PRECONDITION (eta .size () == address_bits (num_chunks) + 1 );
2590+ bvt xi_2n_2{xi. begin (), xi. begin () + eta.size ()} ;
2591+ bvt delta = sub (eta, xi_2n_2 );
2592+ PRECONDITION ( delta.size () <= result_size );
25342593 std::cerr << " delta[" << j << " ]: " << beautify (delta) << std::endl;
2535- z_j[j] = add (
2536- zero_extension (eta, two_to_m_plus_1),
2537- add (
2538- shift (
2539- zero_extension (delta, two_to_m_plus_1),
2540- shiftt::SHIFT_LEFT,
2541- (std::size_t )1 << n),
2542- zero_extension (delta, two_to_m_plus_1)));
2594+ bvt delta_times_F_n = add (
2595+ shift (
2596+ zero_extension (delta, result_size),
2597+ shiftt::SHIFT_LEFT,
2598+ (std::size_t )1 << n),
2599+ zero_extension (delta, result_size));
2600+ std::cerr << " delta * F_n: " << beautify (delta_times_F_n) << std::endl;
2601+ if (xi.size () > result_size)
2602+ {
2603+ bvt xi_result_size{xi.begin (), xi.begin () + result_size};
2604+ z_j[j] = add (xi_result_size, delta_times_F_n);
2605+ }
2606+ else
2607+ {
2608+ z_j[j] = add (zero_extension (xi, result_size), delta_times_F_n);
2609+ }
25432610 std::cerr << " z_j[" << j << " ]: " << beautify (z_j[j]) << std::endl;
25442611 }
25452612
2546- bvt result = zeros (two_to_m_plus_1 );
2613+ bvt result = zeros (result_size );
25472614 for (std::size_t j = 0 ; j < num_chunks / 2 ; ++j)
25482615 {
2549- if (chunk_size * j >= a. size () )
2616+ if (chunk_size * j >= result_size )
25502617 break ;
25512618 result = add (result, shift (z_j[j], shiftt::SHIFT_LEFT, chunk_size * j));
25522619 }
25532620 std::cerr << " result: " << beautify (result) << std::endl;
2554- CHECK_RETURN (result.size () >= a. size () );
2555- result.resize (a. size () );
2621+ CHECK_RETURN (result.size () >= op_size );
2622+ result.resize (op_size );
25562623 std::cerr << " result resized: " << beautify (result) << std::endl;
25572624
25582625 return result;
0 commit comments