{-# language MultiParamTypeClasses #-}

-- | Binary encoding of non-negative integers as lists of 'Boolean's
-- (least significant bit first), with addition, multiplication and
-- equality expressed as SAT constraints.
module Satchmo.Binary
    ( Number, width, number, fixed
    , add, times
    , equals
    )
where

import Prelude hiding ( and, or, not )

import qualified Satchmo.Code as C

import Satchmo.Boolean
import Satchmo.Counting

type Booleans = [ Boolean ]

-- | A number is stored as its bits together with a decoder that
-- reconstructs the 'Integer' value from those bits.
data Number = Number
    { encode :: Booleans -- ^ lsb first
    , decode :: C.Decoder Integer
    }

instance C.Decode Number Integer where
    -- The right-hand side is the record field 'decode' of 'Number';
    -- the class method is only in scope qualified (as C.decode),
    -- so this is not a recursive definition.
    decode = decode

-- | bit width of the encoding
width :: Number -> Int
width n = length $ encode n

-- | declare a number variable (bit width)
number :: Int -> SAT Number
number w = do
    xs <- sequence $ replicate w boolean
    return $ make xs

-- | wrap a list of bits (lsb first) as a 'Number',
-- attaching the matching decoder
make :: [ Boolean ] -> Number
make xs = Number
    { encode = xs
    , decode = do
        ys <- mapM C.decode xs
        return $ fromBinary ys
    }

-- | value of a bit list (lsb first)
fromBinary :: [ Bool ] -> Integer
fromBinary xs = foldr ( \ x y -> 2*y + if x then 1 else 0 ) 0 xs

-- | bits (lsb first) of a non-negative value, using exactly the given
-- number of bits.
--
-- Calls 'error' with a descriptive message if the width is negative,
-- the value is negative, or the value does not fit into the width
-- (previously these cases ended in an anonymous pattern match failure).
toBinary :: Int -> Integer -> [ Bool ]
toBinary b n
    | b < 0 = error $
        "Satchmo.Binary.toBinary: negative bit width " ++ show b
    | n < 0 = error $
        "Satchmo.Binary.toBinary: negative value " ++ show n
    | n >= 2 ^ b = error $
        "Satchmo.Binary.toBinary: value " ++ show n
        ++ " does not fit into " ++ show b ++ " bits"
    | otherwise = go b n
  where
    -- the range check above guarantees that the remaining value
    -- is zero once all k bits have been produced
    go :: Int -> Integer -> [ Bool ]
    go 0 _ = []
    go k m =
        let (d, r) = divMod m 2
        in  ( r == 1 ) : go (k-1) d

-- | declare a number constant (bit width, value)
fixed :: Int -> Integer -> SAT Number
fixed b n = do
    xs <- mapM constant $ toBinary b n
    return $ make xs

-- | result width is 1 + largest argument width
add :: Number -> Number -> SAT Number
add ( Number { encode = xs } ) ( Number { encode = ys } ) = do
    false <- constant False
    ( zs, carry ) <- add_with_carry false xs ys
    return $ make $ zs ++ [carry]

-- | result width is largest argument width
-- if overflow, then unsatisfiable
restricted_add :: Number -> Number -> SAT Number
restricted_add a b = do
    c <- add a b
    restricted ( max (width a) (width b)) c

-- | give only lower k bits, upper bits must be zero,
-- (else unsatisfiable)
restricted :: Int -> Number -> SAT Number
restricted w ( Number { encode = xs } ) = do
    let ( low, high ) = splitAt w xs
    -- one unit clause per dropped bit; the results are not needed
    mapM_ ( \ x -> assert [ not x ] ) high
    return $ make low

-- | result has max length of both inputs
add_with_carry :: Boolean -> Booleans -> Booleans
               -> SAT ( Booleans, Boolean )
add_with_carry cin [] [] = return ( [], cin )
add_with_carry cin (x:xs) [] = do
    -- second summand exhausted: half adder
    z <- xor [ cin, x ]
    c <- and [ cin, x ]
    ( zs, cout ) <- add_with_carry c xs []
    return ( z : zs, cout )
add_with_carry cin [] (y:ys) = do
    -- symmetric case, handled by the clause above
    add_with_carry cin (y:ys) []
add_with_carry cin (x:xs) (y:ys) = do
    -- full adder
    z <- xor [ cin, x, y ]
    c <- atleast 2 [ cin, x, y ]
    ( zs, cout ) <- add_with_carry c xs ys
    return ( z : zs, cout )

-- | product of two numbers, computed by shift-and-add over the bits
-- of the first argument. No bits are dropped: every step goes through
-- 'add', which keeps the carry, so the result is wider than the inputs.
--
-- A zero-width first argument denotes 0; the result is then the
-- zero-width number (previously this case was a pattern match failure).
times :: Number -> Number -> SAT Number
times ( Number { encode = [] } ) _ = return $ make []
times ( Number { encode = [x] } ) ys = times1 x ys
times ( Number { encode = x:xs } ) ys = do
    xys  <- times1 x ys
    xsys <- times (make xs) ys
    zs   <- shift xsys
    add xys zs

-- | multiply by 2
shift :: Number -> SAT Number
shift ( Number { encode = xs } ) = do
    false <- constant False
    return $ make $ false : xs

-- | multiply a number by a single bit (same width as the number)
times1 :: Boolean -> Number -> SAT Number
times1 x ( Number { encode = ys } ) = do
    zs <- mapM ( \ y -> and [x,y] ) ys
    return $ make zs

-- | a 'Boolean' that expresses equality of the two numbers;
-- arguments of different widths are compared as if the shorter one
-- were padded with zero bits
equals :: Number -> Number -> SAT Boolean
equals ( Number { encode = xs } ) ( Number { encode = ys } ) = do
    equals' xs ys

equals' :: Booleans -> Booleans -> SAT Boolean
equals' [] [] = constant True
equals' (x:xs) (y:ys) = do
    z <- xor [x, y]
    rest <- equals' xs ys
    and [ not z, rest ]
-- one side exhausted: all remaining bits of the other must be zero
equals' xs [] = and $ map not xs
equals' [] ys = and $ map not ys