Skip to content

Commit 1c60d3a

Browse files
committed
WIP: Schoenhage-Strassen debugging
1 parent bc309ff commit 1c60d3a

File tree

1 file changed

+131
-64
lines changed

1 file changed

+131
-64
lines changed

src/solvers/flattening/bv_utils.cpp

Lines changed: 131 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,11 +1000,12 @@ bvt bv_utilst::comba_column_wise(const std::vector<bvt> &pps)
10001000
// #define RADIX_MULTIPLIER 8
10011001
// #define USE_KARATSUBA
10021002
// #define USE_TOOM_COOK
1003-
// #define USE_SCHOENHAGE_STRASSEN
1003+
#define USE_SCHOENHAGE_STRASSEN
10041004
#ifdef RADIX_MULTIPLIER
1005+
//# define COMBA
10051006
# define DADDA_TREE
10061007
#endif
1007-
#define COMBA
1008+
// #define COMBA
10081009

10091010
#ifdef RADIX_MULTIPLIER
10101011
static bvt unsigned_multiply_by_3(propt &prop, const bvt &op)
@@ -2183,12 +2184,29 @@ bvt bv_utilst::unsigned_toom_cook_multiplier(const bvt &_op0, const bvt &_op1)
21832184
}
21842185

21852186
bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
2186-
const bvt &a,
2187-
const bvt &b)
2187+
const bvt &_a,
2188+
const bvt &_b)
21882189
{
2190+
// http://malte-leip.net/beschreibung_ssa.pdf,
2191+
// https://de.wikipedia.org/wiki/Sch%C3%B6nhage-Strassen-Algorithmus
2192+
// Isabelle proof: https://mediatum.ub.tum.de/doc/1717658/1717658.pdf
2193+
bvt a = _a;
2194+
#if 1
2195+
// bvt b = a;
2196+
bvt b = _b;
2197+
#else
2198+
bvt b = _b;
2199+
a.resize(14);
2200+
b.resize(14);
2201+
#endif
2202+
21892203
PRECONDITION(a.size() == b.size());
2204+
const std::size_t op_size = a.size();
2205+
// delta.size() <= result_size doesn't hold for op_size <= 3
2206+
if(op_size <= 3)
2207+
return unsigned_multiplier(a, b);
21902208

2191-
// Running examples: we want to multiple 213 by 15 as 8- or 9-bit integers.
2209+
// Running example: we want to multiply 213 by 15 as 8- or 9-bit integers.
21922210
// That is, we seek to multiply 11010101 (011010101) by 00001111 (000001111).
21932211
// ^bit 7 ^bit 0
21942212
// The expected result is 123 as both an 8-bit and 9-bit result (001111011).
@@ -2202,7 +2220,7 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22022220
// m >= log_2(op_size) + 1.
22032221
// For our examples m will be 4 and 5, respectively, with Fermat numbers
22042222
// 2^16 + 1 and 2^32 + 1.
2205-
const std::size_t m = address_bits(a.size()) + 1;
2223+
const std::size_t m = address_bits(op_size) + 1;
22062224
std::cerr << "m: " << m << std::endl;
22072225

22082226
// Extend bit width to 2^(m + 1) = op_size (rounded to next power of 2) * 4
@@ -2243,8 +2261,11 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22432261
{
22442262
a_rho.emplace_back(
22452263
a_ext.begin() + i * chunk_size, a_ext.begin() + (i + 1) * chunk_size);
2264+
std::cerr << "a_rho[" << i << "]: " << beautify(a_rho.back()) << std::endl;
22462265
b_sigma.emplace_back(
22472266
b_ext.begin() + i * chunk_size, b_ext.begin() + (i + 1) * chunk_size);
2267+
std::cerr << "b_sigma[" << i << "]: " << beautify(b_sigma.back())
2268+
<< std::endl;
22482269
}
22492270
// For our example we now have
22502271
// a_rho = [ 0101, 1101, 0000, ..., 0000 ]
@@ -2266,11 +2287,16 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22662287
++rho)
22672288
{
22682289
const std::size_t sigma = tau - rho;
2269-
gamma_tau[tau] = add(
2270-
gamma_tau[tau],
2290+
std::cerr << "Inner multiplication a_" << rho << " * b_" << sigma;
2291+
auto inner_product = zero_extension(
22712292
unsigned_multiplier(
2272-
zero_extension(a_rho[rho], 3 * n + 5),
2273-
zero_extension(b_sigma[sigma], 3 * n + 5)));
2293+
zero_extension(a_rho[rho], chunk_size * 2),
2294+
zero_extension(b_sigma[sigma], chunk_size * 2)),
2295+
3 * n + 5);
2296+
std::cerr << " = " << beautify(inner_product) << std::endl;
2297+
gamma_tau[tau] = add(gamma_tau[tau], inner_product);
2298+
std::cerr << "gamma_tau[" << tau << "] = " << beautify(gamma_tau[tau])
2299+
<< std::endl;
22742300
}
22752301
}
22762302
// For our example we obtain
@@ -2282,6 +2308,7 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22822308
c_tau.reserve(num_chunks);
22832309
for(std::size_t tau = 0; tau < num_chunks; ++tau)
22842310
{
2311+
// std::cerr << "gamma_tau[" << tau << "]: " << beautify(gamma_tau[tau]) << std::endl;
22852312
c_tau.push_back(add(gamma_tau[tau], gamma_tau[tau + num_chunks]));
22862313
CHECK_RETURN(c_tau.back().size() >= address_bits(num_chunks) + 1);
22872314
c_tau.back().resize(address_bits(num_chunks) + 1);
@@ -2295,7 +2322,11 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
22952322
std::vector<bvt> z_j;
22962323
z_j.reserve(num_chunks / 2);
22972324
for(std::size_t j = 0; j < num_chunks / 2; ++j)
2325+
{
22982326
z_j.push_back(sub(c_tau[j], c_tau[j + num_chunks / 2]));
2327+
// z_j.back().resize(address_bits(num_chunks) + 1);
2328+
std::cerr << "z_" << j << " = " << beautify(z_j.back()) << std::endl;
2329+
}
22992330
// For our example we have z_j = c_tau as all elements beyond the second one
23002331
// are zeros.
23012332

@@ -2325,14 +2356,14 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
23252356
// inverse NTT.
23262357

23272358
// Addition mod F_n with overflow
2328-
auto cyclic_add = [this](const bvt &x, const bvt &y)
2329-
{
2359+
auto cyclic_add = [this](const bvt &x, const bvt &y) {
23302360
PRECONDITION(x.size() == y.size());
23312361

23322362
auto result_with_overflow = adder(x, y, const_literal(false));
23332363
if(result_with_overflow.second.is_false())
23342364
return result_with_overflow.first;
23352365

2366+
std::cerr << "OVERFLOW" << std::endl;
23362367
return add(
23372368
result_with_overflow.first,
23382369
zero_extension(bvt{1, result_with_overflow.second}, x.size()));
@@ -2365,14 +2396,16 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
23652396
j <<= 1; // the initial shift has no effect
23662397
j |= (k & (1 << nu)) >> nu;
23672398
}
2399+
std::cerr << "k=" << k << " yields j=" << j << std::endl;
23682400
Aa.push_back(a_j[j]);
23692401
Ab.push_back(b_j[j]);
2402+
std::cerr << "Aa[0](" << k << "): " << beautify(Aa.back()) << std::endl;
23702403
}
2371-
for(std::size_t nu = 1; nu <= address_bits(num_chunks); ++nu)
2404+
for(std::size_t nu = 0; nu < address_bits(num_chunks); ++nu)
23722405
{
2373-
const std::size_t bit_nu = (std::size_t)1 << (nu - 1);
2406+
const std::size_t bit_nu = (std::size_t)1 << nu;
23742407
std::size_t bits_up_to_nu = 0;
2375-
for(std::size_t i = 0; i < nu - 1; ++i)
2408+
for(std::size_t i = 0; i < nu; ++i)
23762409
bits_up_to_nu |= 1 << i;
23772410

23782411
// we only need odd ones
@@ -2384,21 +2417,26 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
23842417
bvt Aa_nu_bit_is_zero = Aa[k & ~bit_nu];
23852418
bvt Ab_nu_bit_is_zero = Ab[k & ~bit_nu];
23862419

2420+
std::cerr << "Round " << (nu + 1) << ", k=" << k
2421+
<< ", k & ~bit_nu=" << (k & ~bit_nu) << std::endl;
23872422
const std::size_t chi = (k & bits_up_to_nu)
2388-
<< (address_bits(num_chunks) - 1 - (nu - 1));
2423+
<< (address_bits(num_chunks) - 1 - nu);
2424+
std::cerr << "k & bits_up_to_nu=" << (k & bits_up_to_nu)
2425+
<< ", chi=" << chi << std::endl;
23892426
const std::size_t omega = m % 2 == 1 ? 2 : 4;
23902427
const std::size_t shift_dist = chi * omega / 2;
23912428

2392-
if(nu > 1) // no need to update even indices
2429+
if(nu > 0) // no need to update even indices
23932430
{
23942431
Aa[k & ~bit_nu] = cyclic_add(
23952432
Aa_nu_bit_is_zero, shift(Aa[k], shiftt::ROTATE_LEFT, shift_dist));
23962433
Ab[k & ~bit_nu] = cyclic_add(
23972434
Ab_nu_bit_is_zero, shift(Ab[k], shiftt::ROTATE_LEFT, shift_dist));
2398-
std::cerr << "Aa[" << nu << "](" << (k & ~bit_nu)
2435+
std::cerr << "shift_dist: " << shift_dist << std::endl;
2436+
std::cerr << "Aa[" << (nu + 1) << "](" << (k & ~bit_nu)
23992437
<< "): " << beautify(Aa[k & ~bit_nu]) << std::endl;
24002438
#if 0
2401-
std::cerr << "Ab[" << nu << "](" << (k & ~bit_nu)
2439+
std::cerr << "Ab[" << (nu + 1) << "](" << (k & ~bit_nu)
24022440
<< "): " << beautify(Ab[k & ~bit_nu]) << std::endl;
24032441
#endif
24042442
}
@@ -2412,19 +2450,20 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
24122450
Ab[k] = cyclic_add(
24132451
Ab_nu_bit_is_zero,
24142452
shift(Ab[k], shiftt::ROTATE_LEFT, shift_dist_for_sub));
2415-
std::cerr << "Aa[" << nu << "](" << k << "): " << beautify(Aa[k])
2453+
std::cerr << "shift_dist_for_sub: " << shift_dist_for_sub << std::endl;
2454+
std::cerr << "Aa[" << (nu + 1) << "](" << k << "): " << beautify(Aa[k])
24162455
<< std::endl;
24172456
#if 0
2418-
std::cerr << "Ab[" << nu << "](" << k << "): " << beautify(Ab[k])
2457+
std::cerr << "Ab[" << (nu + 1) << "](" << k << "): " << beautify(Ab[k])
24192458
<< std::endl;
24202459
#endif
24212460
}
24222461
}
24232462

24242463
// Either compute u - v (if u > v), else u - v + 2^2^n + 1
2425-
auto reduce_to_mod_F_n = [this](const bvt &x)
2426-
{
2464+
auto reduce_to_mod_F_n = [this](const bvt &x, std::size_t n) {
24272465
const std::size_t two_to_power_of_n = x.size() / 2;
2466+
PRECONDITION(two_to_power_of_n == (std::size_t)1 << n);
24282467
// std::cerr << "two_to_power_of_n: " << two_to_power_of_n << std::endl;
24292468
const bvt u =
24302469
zero_extension(bvt{x.begin(), x.begin() + two_to_power_of_n}, x.size());
@@ -2445,54 +2484,64 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
24452484

24462485
std::vector<bvt> a_hat_k{num_chunks, bvt{}}, b_hat_k{num_chunks, bvt{}};
24472486
// Reduce by F_n
2448-
for(std::size_t j = 1; j < num_chunks; j += 2)
2487+
for(std::size_t k = 1; k < num_chunks; k += 2)
24492488
{
2450-
a_hat_k[j] = reduce_to_mod_F_n(Aa[j]);
2451-
std::cerr << "a_hat_k[" << j << "]: " << beautify(a_hat_k[j]) << std::endl;
2452-
b_hat_k[j] = reduce_to_mod_F_n(Ab[j]);
2453-
std::cerr << "b_hat_k[" << j << "]: " << beautify(b_hat_k[j]) << std::endl;
2489+
a_hat_k[k] = reduce_to_mod_F_n(Aa[k], n);
2490+
std::cerr << "a_hat_k[" << k << "]: " << beautify(a_hat_k[k]) << std::endl;
2491+
b_hat_k[k] = reduce_to_mod_F_n(Ab[k], n);
2492+
std::cerr << "b_hat_k[" << k << "]: " << beautify(b_hat_k[k]) << std::endl;
24542493
}
24552494

24562495
// Compute point-wise multiplication
24572496
std::vector<bvt> c_hat_k{num_chunks, bvt{}};
2458-
for(std::size_t j = 1; j < num_chunks; j += 2)
2459-
{
2460-
c_hat_k[j] = unsigned_multiplier(a_hat_k[j], b_hat_k[j]);
2461-
std::cerr << "c_hat_k[" << j << "]: " << beautify(c_hat_k[j]) << std::endl;
2497+
for(std::size_t k = 1; k < num_chunks; k += 2)
2498+
{
2499+
// If at least one of a_hat_k[k] or b_hat_k[k] is 2^2^n (i.e., F_n - 1) then
2500+
// multiplication would overflow, so handle those cases separately:
2501+
// x * 2^2^n = x * -1 (mod F_n) which can be computed by rotating x by 2^n
2502+
const std::size_t power_2_n = (std::size_t)1 << n;
2503+
c_hat_k[k] = select(
2504+
a_hat_k[k][power_2_n],
2505+
shift(b_hat_k[k], shiftt::ROTATE_LEFT, power_2_n),
2506+
select(
2507+
b_hat_k[k][power_2_n],
2508+
shift(a_hat_k[k], shiftt::ROTATE_LEFT, power_2_n),
2509+
unsigned_multiplier(a_hat_k[k], b_hat_k[k])));
2510+
std::cerr << "c_hat_k[" << k << "]: " << beautify(c_hat_k[k]) << std::endl;
24622511
}
24632512

24642513
// Apply inverse NTT
24652514
for(std::size_t nu = address_bits(num_chunks) - 1; nu > 0; --nu)
24662515
{
2467-
const std::size_t bit_nu_plus_1 = (std::size_t)1 << nu;
2468-
std::size_t bits_up_to_nu_plus_1 = 0;
2516+
const std::size_t bit_nu = (std::size_t)1 << nu;
2517+
std::size_t bits_up_to_nu = 0;
24692518
for(std::size_t i = 0; i < nu; ++i)
2470-
bits_up_to_nu_plus_1 |= 1 << i;
2519+
bits_up_to_nu |= 1 << i;
24712520

24722521
// we only need odd ones
24732522
for(std::size_t k = 1; k < num_chunks; k += 2)
24742523
{
2475-
if((k & bit_nu_plus_1) == 0)
2524+
if((k & bit_nu) == 0)
24762525
continue;
24772526

2478-
bvt c_hat_k_nu_plus_1_bit_is_zero = c_hat_k[k & ~bit_nu_plus_1];
2527+
bvt c_hat_k_nu_bit_is_zero = c_hat_k[k & ~bit_nu];
24792528

2480-
c_hat_k[k & ~bit_nu_plus_1] = shift(
2481-
cyclic_add(c_hat_k_nu_plus_1_bit_is_zero, c_hat_k[k]),
2529+
c_hat_k[k & ~bit_nu] = shift(
2530+
cyclic_add(c_hat_k_nu_bit_is_zero, c_hat_k[k]),
24822531
shiftt::ROTATE_RIGHT,
24832532
1);
2484-
std::cerr << "c_hat_k[" << nu << "](" << (k & ~bit_nu_plus_1)
2485-
<< "): " << beautify(c_hat_k[k & ~bit_nu_plus_1]) << std::endl;
2533+
std::cerr << "c_hat_k[" << nu << "](" << (k & ~bit_nu)
2534+
<< "): " << beautify(c_hat_k[k & ~bit_nu]) << std::endl;
24862535

2487-
const std::size_t chi = (k & bits_up_to_nu_plus_1)
2536+
const std::size_t chi = (k & bits_up_to_nu)
24882537
<< (address_bits(num_chunks) - 1 - nu);
24892538
const std::size_t omega = m % 2 == 1 ? 2 : 4;
24902539
const std::size_t shift_dist = chi * omega / 2 + 1;
24912540
std::cerr << "SHIFT: " << shift_dist << std::endl;
24922541

24932542
c_hat_k[k] = shift(
24942543
cyclic_add(
2495-
c_hat_k_nu_plus_1_bit_is_zero,
2544+
c_hat_k_nu_bit_is_zero,
24962545
shift(c_hat_k[k], shiftt::ROTATE_LEFT, (std::size_t)1 << n)),
24972546
shiftt::ROTATE_RIGHT,
24982547
shift_dist);
@@ -2514,45 +2563,63 @@ bvt bv_utilst::unsigned_schoenhage_strassen_multiplier(
25142563
}
25152564
k |= 1;
25162565
std::cerr << "j " << j << " maps to " << k << std::endl;
2517-
z_j_mod_F_n.push_back(reduce_to_mod_F_n(c_hat_k[k]));
2566+
// TODO: we could add a capability to reduce_to_mod_F_n to restrict the
2567+
// result to op_size bits so that this call (but not the ones above) could
2568+
// use it.
2569+
z_j_mod_F_n.push_back(reduce_to_mod_F_n(c_hat_k[k], n));
25182570
std::cerr << "z_j_mod_F_n[" << j << "]: " << beautify(z_j_mod_F_n[j])
25192571
<< std::endl;
25202572
}
25212573

2522-
// Compute final coefficients as eta + delta * F_n where delta = eta - xi for
2523-
// eta z_j and xi c_hat_k.
2574+
// Compute final coefficients as xi + delta * F_n where delta = eta - xi (mod
2575+
// 2^(n + 2) for odd m and 2^(n + 1) for even m) for eta being z_j and xi
2576+
// being z_j_mod_F_n.
2577+
#if 0
2578+
// To compute the full-width result
2579+
const std::size_t result_size = two_to_m_plus_1;
2580+
#else
2581+
const std::size_t result_size = op_size;
2582+
#endif
25242583
for(std::size_t j = 0; j < num_chunks / 2; ++j)
25252584
{
2526-
bvt eta = z_j_mod_F_n[j];
2585+
const bvt &eta = z_j[j];
25272586
std::cerr << "eta[" << j << "]: " << beautify(eta) << std::endl;
2528-
bvt xi = z_j[j];
2587+
const bvt &xi = z_j_mod_F_n[j];
25292588
std::cerr << "xi[" << j << "]: " << beautify(xi) << std::endl;
2530-
// TODO: couldn't we do this over just xi.size() bits instead?
2531-
bvt delta = sub(eta, zero_extension(xi, eta.size()));
2532-
CHECK_RETURN(delta.size() >= xi.size());
2533-
delta.resize(xi.size());
2589+
PRECONDITION(eta.size() == address_bits(num_chunks) + 1);
2590+
bvt xi_2n_2{xi.begin(), xi.begin() + eta.size()};
2591+
bvt delta = sub(eta, xi_2n_2);
2592+
PRECONDITION(delta.size() <= result_size);
25342593
std::cerr << "delta[" << j << "]: " << beautify(delta) << std::endl;
2535-
z_j[j] = add(
2536-
zero_extension(eta, two_to_m_plus_1),
2537-
add(
2538-
shift(
2539-
zero_extension(delta, two_to_m_plus_1),
2540-
shiftt::SHIFT_LEFT,
2541-
(std::size_t)1 << n),
2542-
zero_extension(delta, two_to_m_plus_1)));
2594+
bvt delta_times_F_n = add(
2595+
shift(
2596+
zero_extension(delta, result_size),
2597+
shiftt::SHIFT_LEFT,
2598+
(std::size_t)1 << n),
2599+
zero_extension(delta, result_size));
2600+
std::cerr << "delta * F_n: " << beautify(delta_times_F_n) << std::endl;
2601+
if(xi.size() > result_size)
2602+
{
2603+
bvt xi_result_size{xi.begin(), xi.begin() + result_size};
2604+
z_j[j] = add(xi_result_size, delta_times_F_n);
2605+
}
2606+
else
2607+
{
2608+
z_j[j] = add(zero_extension(xi, result_size), delta_times_F_n);
2609+
}
25432610
std::cerr << "z_j[" << j << "]: " << beautify(z_j[j]) << std::endl;
25442611
}
25452612

2546-
bvt result = zeros(two_to_m_plus_1);
2613+
bvt result = zeros(result_size);
25472614
for(std::size_t j = 0; j < num_chunks / 2; ++j)
25482615
{
2549-
if(chunk_size * j >= a.size())
2616+
if(chunk_size * j >= result_size)
25502617
break;
25512618
result = add(result, shift(z_j[j], shiftt::SHIFT_LEFT, chunk_size * j));
25522619
}
25532620
std::cerr << "result: " << beautify(result) << std::endl;
2554-
CHECK_RETURN(result.size() >= a.size());
2555-
result.resize(a.size());
2621+
CHECK_RETURN(result.size() >= op_size);
2622+
result.resize(op_size);
25562623
std::cerr << "result resized: " << beautify(result) << std::endl;
25572624

25582625
return result;

0 commit comments

Comments
 (0)