{-# language MultiParamTypeClasses #-}

module Satchmo.Binary 

( Number, width, number, fixed
, add, times
, equals
)

where

import Prelude hiding ( and, or, not )

import qualified Satchmo.Code as C
import Satchmo.Boolean
import Satchmo.Counting

-- | A list of SAT 'Boolean' variables. Bit lists in this module are
-- stored least significant bit first.
type Booleans = [ Boolean ]

-- | An unsigned integer encoded as a list of SAT 'Boolean' bits, paired
-- with a 'C.Decoder' that yields its integer value. In this module both
-- fields are always built together by 'make', which decodes the bits
-- with 'fromBinary'.
data Number = Number 
            { encode :: Booleans -- ^ the bits, least significant first
            , decode :: C.Decoder Integer -- ^ yields the integer value of the bits
            }

-- | Decoding a 'Number' runs the decoder stored in its record. The
-- right-hand side is the record field of this module, not the class
-- method 'C.decode'; it is written qualified to make that explicit.
instance C.Decode Number Integer where
    decode = Satchmo.Binary.decode

-- | Number of bits in the encoding of a 'Number'.
width :: Number -> Int
width = length . encode

-- | Declare a number variable of the given bit width; every bit is a
-- fresh 'boolean' variable.
number :: Int -> SAT Number
number w = sequence ( replicate w boolean ) >>= return . make

-- | Wrap a bit list (least significant bit first) as a 'Number'. Its
-- decoder decodes every bit and interprets the result via 'fromBinary'.
make :: [ Boolean ] -> Number
make bits =
    Number { encode = bits
           , decode = mapM C.decode bits >>= return . fromBinary
           }

-- | Interpret a bit list, least significant bit first, as a
-- non-negative integer. The empty list denotes 0.
fromBinary :: [ Bool ] -> Integer
fromBinary = foldr push 0
  where
    -- the current bit is the lowest digit; everything after it is
    -- worth twice as much
    push bit acc = 2 * acc + toInteger ( fromEnum bit )

-- | Fixed-width binary representation, least significant bit first.
--
-- @toBinary b n@ returns exactly @b@ bits. It fails with a descriptive
-- 'error' if @b@ is negative, @n@ is negative, or @n@ does not fit in
-- @b@ bits (@n >= 2^b@).
toBinary :: Int -> Integer -> [ Bool ]
toBinary b n
    | b < 0      = err $ "negative bit width " ++ show b
    | n < 0      = err $ "negative value " ++ show n
    | n >= 2 ^ b = err $ "value " ++ show n ++ " does not fit in "
                         ++ show b ++ " bits"
    | otherwise  = go b n
  where
    err msg = error $ "Satchmo.Binary.toBinary: " ++ msg
    -- peel off the lowest bit k times; n < 2^b guarantees m is 0 at the end
    go 0 _ = []
    go k m = let ( d, r ) = divMod m 2
             in  ( r == 1 ) : go ( k - 1 ) d

-- | Declare a number constant (bit width, value). Every bit is a
-- 'constant'; the value must fit in the width, otherwise 'toBinary'
-- fails.
fixed :: Int -> Integer -> SAT Number
fixed b n = mapM constant ( toBinary b n ) >>= return . make

-- | Sum of two numbers. The result is one bit wider than the wider
-- argument: the final carry becomes the most significant bit, so the
-- sum can never overflow.
add :: Number -> Number -> SAT Number
add ( Number as_ _ ) ( Number bs _ ) = do
    zero <- constant False
    ( sums, overflow ) <- add_with_carry zero as_ bs
    return ( make ( sums ++ [ overflow ] ) )

-- | Sum whose width is that of the wider argument. If the true sum does
-- not fit (the carry bit would be set), the formula is unsatisfiable.
restricted_add :: Number -> Number -> SAT Number 
restricted_add a b = add a b >>= restricted ( width a `max` width b )

-- | Keep only the lowest @w@ bits of a number and assert that every
-- higher bit is false. If the value does not fit in @w@ bits, the
-- formula becomes unsatisfiable.
restricted :: Int -> Number -> SAT Number
restricted w ( Number { encode = xs } ) = do
    let ( low, high ) = splitAt w xs
    -- each discarded high bit must be false
    mapM_ ( \ x -> assert [ not x ] ) high
    return $ make low

-- | Ripple-carry addition of two bit lists (least significant first)
-- with an incoming carry. Returns one sum bit per position of the longer
-- input, plus the outgoing carry. Where only one input still has bits,
-- a half adder (xor / and) is used; otherwise a full adder (xor /
-- at-least-two-of-three).
add_with_carry :: Boolean 
               -> Booleans -> Booleans
               -> SAT ( Booleans, Boolean )
add_with_carry cin xs ys = case ( xs, ys ) of
    ( [], [] ) -> return ( [], cin )
    -- only the second list is non-empty: swap so it is handled below
    ( [], _ ) -> add_with_carry cin ys []
    ( p : ps, [] ) -> do
        s <- xor [ cin, p ]
        c <- and [ cin, p ]
        prepend s c ps []
    ( p : ps, q : qs ) -> do
        s <- xor [ cin, p, q ]
        c <- atleast 2 [ cin, p, q ]
        prepend s c ps qs
  where
    -- add the remaining bits with carry c, then put sum bit s in front
    prepend s c rest1 rest2 = do
        ( ss, cout ) <- add_with_carry c rest1 rest2
        return ( s : ss, cout )

-- | Product of two numbers, computed by shift-and-add over the bits of
-- the first argument (least significant first).
--
-- For a first argument of width @a >= 1@ and a second of width @b@, the
-- result has width @b + 2*(a-1)@. A zero-width first argument encodes 0,
-- so the product is the zero-width number 0.
times :: Number -> Number -> SAT Number
times ( Number { encode = [] } ) _ = return $ make []
times ( Number { encode = [x] } ) ys = times1 x ys
times ( Number { encode = x:xs } ) ys = do
    -- x * ys + 2 * (xs * ys)
    xys  <- times1 x ys
    xsys <- times (make xs) ys
    zs <- shift xsys
    add xys zs

-- | Multiply by 2 by prepending a constant-false least significant bit.
-- The result is one bit wider than the input.
shift :: Number -> SAT Number
shift ( Number bits _ ) =
    constant False >>= \ zero -> return ( make ( zero : bits ) )

-- | Multiply a number by a single bit. Each result bit is the
-- conjunction of @x@ with the corresponding bit of the number, so the
-- width is unchanged.
times1 :: Boolean -> Number -> SAT Number
times1 x ( Number bits _ ) = mapM gate bits >>= return . make
  where
    gate y = and [ x, y ]

-- | A 'Boolean' that holds exactly when both numbers have the same
-- value. Numbers of different widths are compared as if the shorter one
-- were padded with zero high bits.
equals :: Number -> Number -> SAT Boolean
equals ( Number xs _ ) ( Number ys _ ) = equals' xs ys

-- | Bitwise equality of two bit lists (least significant first). The
-- result holds when the common prefix agrees pairwise and every bit
-- beyond the shorter list is false.
equals' :: Booleans -> Booleans -> SAT Boolean
equals' xs ys = case ( xs, ys ) of
    ( [], [] ) -> constant True
    ( p : ps, q : qs ) -> do
        differ <- xor [ p, q ]
        same <- equals' ps qs
        and [ not differ, same ]
    -- leftover high bits of the longer list must all be zero
    ( ps, [] ) -> and ( map not ps )
    ( [], qs ) -> and ( map not qs )