-- |This is an implementation of the nimbers. Technically the nimbers form a
-- field whose elements are all the ordinals, but here they are restricted to
-- the non-negative integers. Note that division by n is speedy for n < 16,
-- takes about one second for n < 256, about a minute for n < 65536, and is
-- probably very, very, very slow for n >= 65536.
module Data.Nimber (
               Nimber(fromNimber),
               toNimber, nimRecip
)
where

import Data.Bits
import Data.List
import Data.Maybe
import Data.Ratio
import Control.Monad
import qualified Data.MemoCombinators as Memo
import qualified Data.Set as S
-- | A nimber, stored as its underlying integer value.
-- 'toNimber' is the checked way to build one (it rejects negative integers);
-- 'fromNimber' recovers the underlying 'Integer'.
newtype Nimber = Nimber { fromNimber :: Integer }
    deriving (Eq, Ord)

-- | Memoise a function of a nimber by reusing the 'Memo.integral' memo
-- table over 'Integer' keys: a nimber is looked up via its underlying
-- integer ('fromNimber'), and table keys are turned back into nimbers with
-- 'toNimber' before the wrapped function is applied.
memoNimber :: (Nimber -> r) -> Nimber -> r
memoNimber = Memo.wrap toNimber fromNimber Memo.integral

-- | Wrap a non-negative 'Integer' as a 'Nimber'.
-- Negative input is rejected with a call to 'error', since only
-- non-negative nimbers are represented here.
toNimber :: Integer -> Nimber
toNimber n =
    if n >= 0
        then Nimber n
        else error "negative nimbers not defined"


-- | Nimbers are shown in the conventional star notation, e.g. @*5@.
instance Show Nimber where
    show n = "*" ++ show (fromNimber n)
-- | Enumeration follows the underlying non-negative integers.
-- The list-producing methods work on 'Integer' directly, so ranges are exact
-- for nimbers of any size, and descending ranges stop at *0 instead of
-- producing negative (invalid) nimbers. 'fromEnum' goes through 'Int' and so
-- is only meaningful for nimbers that fit in an 'Int'.
instance Enum Nimber where
    succ (Nimber x) = Nimber (x + 1)
    -- *0 is the smallest nimber; there is nothing below it.
    pred (Nimber 0) = error "Data.Nimber.pred: *0 has no predecessor"
    pred (Nimber x) = Nimber (x - 1)
    -- Route through toNimber so negative Ints are rejected, not wrapped.
    toEnum = toNimber . toInteger
    fromEnum = fromEnum . fromNimber
    enumFrom (Nimber x) = map Nimber [x ..]
    enumFromTo (Nimber x) (Nimber y) = map Nimber [x .. y]
    enumFromThen (Nimber x) (Nimber y)
        | y >= x = map Nimber [x, y ..]
        -- A descending progression must not go below *0.
        | otherwise = map Nimber [x, y .. 0]
    enumFromThenTo (Nimber x) (Nimber y) (Nimber z) = map Nimber [x, y .. z]
instance Num Nimber where
    -- Nim-addition is bitwise XOR, so every nimber is its own additive
    -- inverse and 'negate' (like 'abs') is the identity.
    abs = id
    negate = id
    (+) (Nimber x) (Nimber y) = toNimber (x `xor` y)
    -- *0 has signum *0; every other nimber has signum *1.
    signum 0 = 0
    signum _ = 1
    fromInteger = toNimber
    -- Nim-multiplication distributes over nim-addition, so a * b is the
    -- nim-sum of 2^m * 2^n over every set bit m of a and set bit n of b.
    -- For example *5 * *6 = (2^2 + 2^0) * (2^2 + 2^1)
    --   = 2^0*2^1 + 2^0*2^2 + 2^2*2^1 + 2^2*2^2 = *2 + *4 + *8 + *6 = *8.
    -- Each 2^m is the product of the Fermat 2-powers 2^(2^i) for the set
    -- bits i of m, so 2^m * 2^n is the product of (2^(2^i))^(m_i + n_i),
    -- which is what pow2mult evaluates from the digit-wise sum of m and n.
    a * b =
        sum
            [ pow2mult (bitProduct (toBits m) (toBits n))
            | m <- setBits (fromNimber a)
            , n <- setBits (fromNimber b)
            ]
      where
        -- toBits expands a non-negative number into its bits, most
        -- significant first: toBits 13 = [1, 1, 0, 1]; toBits 0 = [].
        toBits :: Integer -> [Integer]
        toBits = reverse . unfoldr (\x -> if x == 0 then Nothing else Just (x `rem` 2, x `div` 2))
        -- setBits lists the positions of the set bits of a non-negative
        -- number, least significant first: setBits 13 = [0, 2, 3].
        setBits :: Integer -> [Integer]
        setBits x = [ i | (i, 1) <- zip [0 ..] (reverse (toBits x)) ]
        -- pow2mult [e_k, ..., e_1, e_0] is the nim-product over i of
        -- (2^(2^i))^e_i, with exponents listed most significant first.
        -- For example pow2mult [1, 0, 1] = 2^(2^2) * 2^(2^0) = 16 * 2 = *32,
        -- and pow2mult [2] = *2 * *2 = *3.
        pow2mult :: [Integer] -> Nimber
        pow2mult [] = 1
        pow2mult (0 : es) = pow2mult es
        -- A single factor F = 2^(2^k) times a nimber below F is an ordinary
        -- integer product.
        pow2mult (1 : es) = toNimber $ 2 ^ (2 ^ length es :: Int) * fromNimber (pow2mult es)
        -- Otherwise lower the leading exponent using F * F = F + F/2, where
        -- F/2 is the product of all lower Fermat 2-powers (hence the +1 on
        -- every lower exponent).
        pow2mult (e : es) = pow2mult (e - 1 : es) + pow2mult (e - 2 : map (+ 1) es)
        -- bitProduct adds two most-significant-first digit lists position by
        -- position, left-padding the shorter one with zeros:
        -- bitProduct [1, 0, 2, 1] [1, 3] = [1, 0, 3, 4]
        bitProduct :: [Integer] -> [Integer] -> [Integer]
        bitProduct xs ys = zipWith (+) (pad xs) (pad ys)
          where
            width = max (length xs) (length ys)
            pad zs = replicate (width - length zs) 0 ++ zs
     
instance Fractional Nimber where
    -- recip searches *1, *2, ... for the n with n * a == *1, memoising the
    -- answers. For a /= *0 the search terminates: the nimbers below the
    -- smallest 2^(2^k) exceeding a form a finite field containing a, so the
    -- inverse lies below that bound. *0 has no inverse and is rejected up
    -- front rather than searched for forever.
    -- Warning: this takes a second or two for 16 <= a < 256, a minute or so
    -- for 256 <= a < 65536, and probably several minutes for
    -- 65536 <= a < 4294967296.
    recip = memoNimber recip'
      where
        recip' 0 = error "Data.Nimber.recip: division by zero"
        recip' a = fromJust $ find (\n -> n * a == 1) (map toNimber [1 ..])
    -- A rational p/q is read as the nim-quotient *p / *q; toNimber rejects
    -- negative numerators.
    fromRational r = toNimber (numerator r) / toNimber (denominator r)

{-|
  Find the reciprocal of a nimber directly from its recursive definition:
  1/a is the mex (least excluded nimber) of the smallest set S that contains
  *0 and is closed under b |-> (1 + (a' + a) * b) * nimRecip a' for every
  0 < a' < a and b in S.
  This is much slower than 'recip' and is kept only because the definition
  is pleasing. Calls 'error' on *0, which has no reciprocal.
-}
nimRecip :: Nimber -> Nimber
nimRecip = memoNimber nimRecip'
  where
    -- Without this guard the range of a' below would be empty and *0
    -- would wrongly get the reciprocal *1.
    nimRecip' 0 = error "Data.Nimber.nimRecip: division by zero"
    nimRecip' a = mex (closure (S.singleton 0))
      where
        -- Apply 'enlarge' until the set stops growing.
        closure xs
            | xs' == xs = xs
            | otherwise = closure xs'
          where
            xs' = enlarge xs
        -- Least nimber not in the set.
        mex xs = toNimber (until (\n -> toNimber n `S.notMember` xs) (+ 1) 0)
        -- Add every option generated from the current elements.
        enlarge xs = xs `S.union` S.fromList [ option a' b | a' <- lowers, b <- S.toList xs ]
        lowers = map toNimber [1 .. fromNimber a - 1]
        option a' b = (1 + (a' + a) * b) * nimRecip a'