#ifndef MIN_VERSION_integer_gmp
#define MIN_VERSION_integer_gmp(a,b,c) 0
#endif
#if MIN_VERSION_integer_gmp(0,5,1)
#endif
#ifdef VERSION_integer_gmp
#endif
module Crypto.Number.Basic
    ( sqrti
    , gcde
    , gcde_binary
    , areEven
    , log2
    ) where
#if MIN_VERSION_integer_gmp(0,5,1)
import GHC.Integer.GMP.Internals
#else
import Data.Bits
#endif
#ifdef VERSION_integer_gmp
import GHC.Exts
import GHC.Integer.Logarithms (integerLog2#)
#endif
-- | Integer square root bounds.
--
-- @sqrti i@ returns @(l, u)@ with @l * l <= i <= u * u@ and @u - l <= 1@.
-- When @i@ is a perfect square both bounds equal the exact root.
-- Calls 'error' for a negative argument.
--
-- The search starts from a guess with as many decimal digits as the root,
-- doubles or halves it until the root is bracketed, then bisects.
sqrti :: Integer -> (Integer, Integer)
sqrti i
    | i < 0     = error "cannot compute negative square root"
    | i == 0    = (0,0)
    | i == 1    = (1,1)
    | i == 2    = (1,2)
    | otherwise = loop x0
        where
            -- starting guess: ceil(nbdigits / 2) decimal digits, i.e. the
            -- same number of digits as the square root of i
            nbdigits = length $ show i
            x0n = (if even nbdigits then nbdigits - 2 else nbdigits - 1) `div` 2
            x0  = if even nbdigits then 2 * 10 ^ x0n else 6 * 10 ^ x0n
            loop x = case compare (sq x) i of
                LT -> iterUp x
                EQ -> (x, x)
                GT -> iterDown x
            -- lb is too small: double it until its square reaches i
            iterUp lb = if sq ub >= i then iter lb ub else iterUp ub
                where ub = lb * 2
            -- ub is too large: halve it until its square drops below i
            iterDown ub = if sq lb >= i then iterDown lb else iter lb ub
                where lb = ub `div` 2
            -- bisection; invariant: sq lb < i <= sq ub
            iter lb ub
                | lb == ub     = (lb, ub)
                -- bracket is tight: ub is the exact root iff its square is i
                | lb + 1 == ub = if sq ub == i then (ub, ub) else (lb, ub)
                | otherwise    =
                    let d = (ub - lb) `div` 2 in
                    -- ub - d >= lb + d, so shrinking ub keeps the invariant
                    if sq (lb + d) >= i
                        then iter lb (ub - d)
                        else iter (lb + d) ub
            sq a = a * a
-- | Extended Euclidean algorithm.
--
-- @gcde a b@ returns @(x, y, g)@ such that @a * x + b * y == g@, where @g@
-- is the greatest common divisor of @a@ and @b@.
gcde :: Integer -> Integer -> (Integer, Integer, Integer)
#if MIN_VERSION_integer_gmp(0,5,1)
gcde a b = (s, t, g)
  where (# g, s #) = gcdExtInteger a b
        -- gcdExtInteger only yields the coefficient of a; recover the one for
        -- b from g == a*s + b*t.  With b == 0 the b term vanishes, so any t
        -- works: use 0 rather than dividing by zero.
        t | b == 0    = 0
          | otherwise = (g - s * a) `div` b
#else
gcde a b = if d < 0 then (-x, -y, -d) else (x, y, d) where
    -- flip all signs when needed so the returned gcd is non-negative
    (d, x, y)                     = f (a,1,0) (b,0,1)
    -- every triple (r, s, t) here satisfies a*s + b*t == r
    f t              (0, _, _)    = t
    f (a', sa, ta) t@(b', sb, tb) =
        let (q, r) = a' `divMod` b' in
        f t (r, sa - (q * sb), ta - (q * tb))
#endif
-- | Extended GCD computed with the binary (shift-and-subtract) method.
--
-- Returns @(x, y, g)@ with @a' * x + b' * y == g@.  When integer-gmp
-- >= 0.5.1 is available this is simply 'gcde'.
gcde_binary :: Integer -> Integer -> (Integer, Integer, Integer)
#if MIN_VERSION_integer_gmp(0,5,1)
gcde_binary = gcde
#else
gcde_binary a' b'
    | b' == 0   = (1,0,a')
    -- The binary loop below only handles strictly positive operands: with a
    -- zero operand halfLoop keeps halving 0 forever, and negative operands
    -- can cycle.  Use the Euclidean 'gcde' for those inputs instead.
    | a' <= 0 || b' < 0 = gcde a' b'
    | a' >= b'  = compute a' b'
    | otherwise = (\(x,y,d) -> (y,x,d)) $ compute b' a'
    where
        -- Strip the common power of two: starting from g = 1, returns
        -- (x, y, g) with the inputs equal to g * x and g * y, and x, y not
        -- both even.
        getEvenMultiplier !g !x !y
            | areEven [x,y] = getEvenMultiplier (g `shiftL` 1) (x `shiftR` 1) (y `shiftR` 1)
            | otherwise     = (x,y,g)
        -- Halve u until it is odd while preserving u == i * x + j * y.
        halfLoop !x !y !u !i !j
            | areEven [u,i,j] = halfLoop x y (u `shiftR` 1) (i `shiftR` 1) (j `shiftR` 1)
            | even u          = halfLoop x y (u `shiftR` 1) ((i + y) `shiftR` 1) ((j - x) `shiftR` 1)
            | otherwise       = (u, i, j)
        compute a b =
            let (x,y,g) = getEvenMultiplier 1 a b in
            loop g x y x y 1 0 0 1
        -- Invariants: u == a*x + b*y and v == c*x + d*y.  Once u reaches 0
        -- the result (c, d, g * v) satisfies c * (g*x) + d * (g*y) == g * v.
        loop g _ _ 0  !v _  _  !c !d = (c, d, g * v)
        loop g x y !u !v !a !b !c !d =
            let (u2,a2,b2) = halfLoop x y u a b
                (v2,c2,d2) = halfLoop x y v c d
             in if u2 >= v2
                then loop g x y (u2 - v2) v2 (a2 - c2) (b2 - d2) c2 d2
                else loop g x y u2 (v2 - u2) a2 b2 (c2 - a2) (d2 - b2)
#endif
-- | True when every integer in the list is even; vacuously True for @[]@.
areEven :: [Integer] -> Bool
areEven = all even
-- | Integer base-2 logarithm: the largest @k@ with @2 ^ k <= x@ for positive
-- @x@.  Non-positive arguments yield 0.
log2 :: Integer -> Int
#ifdef VERSION_integer_gmp
log2 x
    -- integerLog2# is documented to require a strictly positive argument and
    -- does not check it; answer 0 for x <= 0, as the portable branch does.
    | x <= 0    = 0
    | otherwise = I# (integerLog2# x)
#else
log2 = imLog 2
  where
    -- imLog b x: largest k with b ^ k <= x, or 0 when x < b.  A coarse
    -- estimate comes from recursing with the squared base, and doDiv then
    -- refines it one division by b at a time.
    imLog b x = if x < b then 0 else (x `div` b^l) `doDiv` l
      where
        l = 2 * imLog (b * b) x
        doDiv x' l' = if x' < b then l' else (x' `div` b) `doDiv` (l' + 1)
#endif