From 6b5142a7a9b31922d9c7ef505b27c33d551f5016 Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Sun, 2 Jul 2023 21:22:02 +0100 Subject: [PATCH] Move mul_root3 out into misc.c and generalise it. I'm going to want to reuse it for sqrt(5) as well as sqrt(3) soon. --- grid.c | 129 +----------------------------------------------------- misc.c | 124 +++++++++++++++++++++++++++++++++++++++++++++++++++ puzzles.h | 1 + 3 files changed, 127 insertions(+), 127 deletions(-) diff --git a/grid.c b/grid.c index 904a9f6..29432bf 100644 --- a/grid.c +++ b/grid.c @@ -3669,131 +3669,6 @@ struct spectrecontext { tree234 *points; }; -/* - * Calculate the nearest integer to n*sqrt(3), via a bitwise algorithm - * that avoids floating point. - * - * (It would probably be OK in practice to use floating point, but I - * felt like overengineering it for fun. With FP, there's at least a - * theoretical risk of rounding the wrong way, due to the three - * successive roundings involved - rounding sqrt(3), rounding its - * product with n, and then rounding to the nearest integer. This - * approach avoids that: it's exact.) - */ -static int mul_root3(int n_signed) -{ - unsigned x, r, m; - int sign = n_signed < 0 ? -1 : +1; - unsigned n = n_signed * sign; - unsigned bitpos; - - /* - * Method: - * - * We transform m gradually from zero into n, by multiplying it by - * 2 in each step and optionally adding 1, so that it's always - * floor(n/2^something). - * - * At the start of each step, x is the largest integer less than - * or equal to m*sqrt(3). We transform m to 2m+bit, and therefore - * we must transform x to 2x+something to match. The 'something' - * we add to 2x is at most 3. (Worst case is if m sqrt(3) was - * equal to x + 1-eps for some tiny eps, and then the incoming bit - * of m is 1, so that (2m+1)sqrt(3) = 2x+2+2eps+sqrt(3), i.e. - * about 2x + 3.732...) - * - * To compute this, we also track the residual value r such that - * x^2+r = 3m^2. - * - * The algorithm below is very similar to the usual approach for - * taking the square root of an integer in binary. The wrinkle is - * that we have an integer multiplier, i.e. we're computing - * P*sqrt(Q) (with P=n and Q=3 in this case) rather than just - * sqrt(Q). Of course in principle we could just take sqrt(P^2Q), - * but we'd need an integer twice the width to hold P^2. Pulling - * out P and treating it specially makes overflow less likely. - */ - - x = r = m = 0; - - for (bitpos = UINT_MAX & ~(UINT_MAX >> 1); bitpos; bitpos >>= 1) { - unsigned a, b = (n & bitpos) ? 1 : 0; - - /* - * Check invariants. We expect that x^2 + r = 3m^2 (i.e. our - * residual term is correct), and also that r < 2x+1 (because - * if not, then we could replace x with x+1 and still get a - * value that made r non-negative, i.e. x would not be the - * _largest_ integer less than m sqrt(3)). - */ - assert(x*x + r == 3*m*m); - assert(r < 2*x+1); - - /* - * We're going to replace m with 2m+b, and x with 2x+a for - * some a we haven't decided on yet. - * - * The new value of the residual will therefore be - * - * 3 (2m+b)^2 - (2x+a)^2 - * = (12m^2 + 12mb + 3b^2) - (4x^2 + 4xa + a^2) - * = 4 (3m^2 - x^2) + 12mb + 3b^2 - 4xa - a^2 - * = 4r + 12mb + 3b^2 - 4xa - a^2 (because r = 3m^2 - x^2) - * = 4r + (12m + 3)b - 4xa - a^2 (b is 0 or 1, so b = b^2) - */ - for (a = 0; a < 4; a++) { - /* If we made this routine handle square roots of numbers - * other than 3 then it would be sensible to make this a - * binary search. Here, it hardly seems important. */ - unsigned pos = 4*r + b*(12*m + 3); - unsigned neg = 4*a*x + a*a; - if (pos < neg) - break; /* this value of a is too big */ - } - - /* The above loop will have terminated with a one too big, - * whether that's because we hit the break statement or fell - * off the end with a=4. So now decrementing a will give us - * the right value to add. */ - a--; - - r = 4*r + b*(12*m + 3) - (4*a*x + a*a); - m = 2*m+b; - x = 2*x+a; - } - - /* - * Finally, round to the nearest integer. At present, x is the - * largest integer that is _at most_ m sqrt(3). But we want the - * _nearest_ integer, whether that's rounded up or down. So check - * whether (x + 1/2) is still less than m sqrt(3), i.e. whether - * (x + 1/2)^2 < 3m^2; if it is, then we increment x. - * - * We have 3m^2 - (x + 1/2)^2 = 3m^2 - x^2 - x - 1/4 - * = r - x - 1/4 - * - * and since r and x are integers, this is greater than 0 if and - * only if r > x. - * - * (There's no need to worry about tie-breaking exact halfway - * rounding cases. sqrt(3) is irrational, so none such exist.) - */ - if (r > x) - x++; - - /* - * Put the sign back on, and convert back from unsigned to int. - */ - if (sign == +1) { - return x; - } else { - /* Be a little careful to avoid compilers deciding I've just - * perpetrated signed-integer overflow. This should optimise - * down to no actual code. */ - return INT_MIN + (int)(-x - (unsigned)INT_MIN); - } -} - static void grid_spectres_callback(void *vctx, const int *coords) { struct spectrecontext *ctx = (struct spectrecontext *)vctx; @@ -3804,9 +3679,9 @@ static void grid_spectres_callback(void *vctx, const int *coords) grid_dot *d = grid_get_dot( ctx->g, ctx->points, (coords[4*i+0] * SPECTRE_UNIT + - mul_root3(coords[4*i+1] * SPECTRE_UNIT)), + n_times_root_k(coords[4*i+1] * SPECTRE_UNIT, 3)), (coords[4*i+2] * SPECTRE_UNIT + - mul_root3(coords[4*i+3] * SPECTRE_UNIT))); + n_times_root_k(coords[4*i+3] * SPECTRE_UNIT, 3))); grid_face_set_dot(ctx->g, d, i); } } diff --git a/misc.c b/misc.c index 9a757d4..334e08d 100644 --- a/misc.c +++ b/misc.c @@ -536,4 +536,128 @@ char *make_prefs_path(const char *dir, const char *sep, return path; } +/* + * Calculate the nearest integer to n*sqrt(k), via a bitwise algorithm + * that avoids floating point. + * + * (It would probably be OK in practice to use floating point, but I + * felt like overengineering it for fun. With FP, there's at least a + * theoretical risk of rounding the wrong way, due to the three + * successive roundings involved - rounding sqrt(k), rounding its + * product with n, and then rounding to the nearest integer. This + * approach avoids that: it's exact.) + */ +int n_times_root_k(int n_signed, int k) +{ + unsigned x, r, m; + int sign = n_signed < 0 ? -1 : +1; + unsigned n = n_signed * sign; + unsigned bitpos; + + /* + * Method: + * + * We transform m gradually from zero into n, by multiplying it by + * 2 in each step and optionally adding 1, so that it's always + * floor(n/2^something). + * + * At the start of each step, x is the largest integer less than + * or equal to m*sqrt(k). We transform m to 2m+bit, and therefore + * we must transform x to 2x+something to match. The 'something' + * we add to 2x is at most floor(sqrt(k))+2. (Worst case is if m + * sqrt(k) was equal to x + 1-eps for some tiny eps, and then the + * incoming bit of m is 1, so that (2m+1)sqrt(k) = + * 2x+2+sqrt(k)-2eps.) + * + * To compute this, we also track the residual value r such that + * x^2+r = km^2. + * + * The algorithm below is very similar to the usual approach for + * taking the square root of an integer in binary. The wrinkle is + * that we have an integer multiplier, i.e. we're computing + * n*sqrt(k) rather than just sqrt(k). Of course in principle we + * could just take sqrt(n^2k), but we'd need an integer twice the + * width to hold n^2. Pulling out n and treating it specially + * makes overflow less likely. + */ + + x = r = m = 0; + + for (bitpos = UINT_MAX & ~(UINT_MAX >> 1); bitpos; bitpos >>= 1) { + unsigned a, b = (n & bitpos) ? 1 : 0; + + /* + * Check invariants. We expect that x^2 + r = km^2 (i.e. our + * residual term is correct), and also that r < 2x+1 (because + * if not, then we could replace x with x+1 and still get a + * value that made r non-negative, i.e. x would not be the + * _largest_ integer less than m sqrt(k)). + */ + assert(x*x + r == k*m*m); + assert(r < 2*x+1); + + /* + * We're going to replace m with 2m+b, and x with 2x+a for + * some a we haven't decided on yet. + * + * The new value of the residual will therefore be + * + * k (2m+b)^2 - (2x+a)^2 + * = (4km^2 + 4kmb + kb^2) - (4x^2 + 4xa + a^2) + * = 4 (km^2 - x^2) + 4kmb + kb^2 - 4xa - a^2 + * = 4r + 4kmb + kb^2 - 4xa - a^2 (because r = km^2 - x^2) + * = 4r + (4m + 1)kb - 4xa - a^2 (b is 0 or 1, so b = b^2) + */ + for (a = 0;; a++) { + /* If we made this routine handle square roots of numbers + * significantly bigger than 3 or 5 then it would be + * sensible to make this a binary search. Here, it hardly + * seems important. */ + unsigned pos = 4*r + k*b*(4*m + 1); + unsigned neg = 4*a*x + a*a; + if (pos < neg) + break; /* this value of a is too big */ + } + + /* The above loop will have terminated with a one too big. So + * now decrementing a will give us the right value to add. */ + a--; + + r = 4*r + b*k*(4*m + 1) - (4*a*x + a*a); + m = 2*m+b; + x = 2*x+a; + } + + /* + * Finally, round to the nearest integer. At present, x is the + * largest integer that is _at most_ m sqrt(k). But we want the + * _nearest_ integer, whether that's rounded up or down. So check + * whether (x + 1/2) is still less than m sqrt(k), i.e. whether + * (x + 1/2)^2 < km^2; if it is, then we increment x. + * + * We have km^2 - (x + 1/2)^2 = km^2 - x^2 - x - 1/4 + * = r - x - 1/4 + * + * and since r and x are integers, this is greater than 0 if and + * only if r > x. + * + * (There's no need to worry about tie-breaking exact halfway + * rounding cases. sqrt(k) is irrational, so none such exist.) + */ + if (r > x) + x++; + + /* + * Put the sign back on, and convert back from unsigned to int. + */ + if (sign == +1) { + return x; + } else { + /* Be a little careful to avoid compilers deciding I've just + * perpetrated signed-integer overflow. This should optimise + * down to no actual code. */ + return INT_MIN + (int)(-x - (unsigned)INT_MIN); + } +} + /* vim: set shiftwidth=4 tabstop=8: */ diff --git a/puzzles.h b/puzzles.h index f057c21..b4d33de 100644 --- a/puzzles.h +++ b/puzzles.h @@ -391,6 +391,7 @@ void obfuscate_bitmap(unsigned char *bmp, int bits, bool decode); char *fgetline(FILE *fp); char *make_prefs_path(const char *dir, const char *sep, const game *game, const char *suffix); +int n_times_root_k(int n, int k); /* allocates output each time. len is always in bytes of binary data. * May assert (or just go wrong) if lengths are unchecked. */