{-# LANGUAGE ForeignFunctionInterface #-}

-- | Minimal bindings to the PicoSAT SAT solver.
--
-- A formula is a list of clauses; each clause is a list of literals where a
-- positive number @v@ stands for variable @v@ and @-v@ for its negation.
-- Every call to 'solve' uses its own freshly created picosat instance.
module Picosat
  ( solve
  , solveST
  , unsafeSolve
  , Solution(..)
  ) where

import Control.Exception (ErrorCall (..), bracket, throwIO)
import Control.Monad
import Control.Monad.ST (ST)
import Control.Monad.ST.Unsafe (unsafeIOToST)
import Foreign.C.Types
import Foreign.Ptr
import System.IO.Unsafe (unsafePerformIO)

default (Int)

-- Raw picosat C API. The solver handle is treated as an opaque pointer.

foreign import ccall unsafe "picosat_init"
  picosat_init :: IO (Ptr a)

foreign import ccall unsafe "picosat_reset"
  picosat_reset :: Ptr a -> IO ()

foreign import ccall unsafe "picosat_add"
  picosat_add :: Ptr a -> CInt -> IO CInt

foreign import ccall unsafe "picosat_variables"
  picosat_variables :: Ptr a -> IO CInt

-- Imported 'safe' (unlike the short accessor calls above and below) because
-- solving can take arbitrarily long: an 'unsafe' foreign call cannot be
-- interrupted and holds up garbage collection (and, in the non-threaded RTS,
-- every other Haskell thread) until it returns.
foreign import ccall safe "picosat_sat"
  picosat_sat :: Ptr a -> CInt -> IO CInt

foreign import ccall unsafe "picosat_deref"
  picosat_deref :: Ptr a -> CInt -> IO CInt

-- | Result codes returned by @picosat_sat@.
unknown, satisfiable, unsatisfiable :: CInt
unknown = 0
satisfiable = 10
unsatisfiable = 20

-- | Outcome of a solver run.
data Solution
  -- | The formula is satisfiable. For each variable @i@ in @1..n@ (where @n@
  -- is what @picosat_variables@ reports) the list holds
  -- @i * picosat_deref i@: per the picosat.h contract that is @i@ when the
  -- variable is true, @-i@ when false, and @0@ when it is left unassigned.
  = Solution [Int]
  -- | The formula has no satisfying assignment.
  | Unsatisfiable
  -- | picosat gave up without deciding.
  | Unknown
  deriving (Show, Eq)

-- | Add one clause to the solver, followed by the @0@ that terminates it.
--
-- NOTE(review): picosat reads any @0@ literal as a clause terminator, so a
-- @0@ inside @cl@ would silently split the clause in two. Callers must not
-- pass @0@ as a literal.
addClause :: Ptr a -> [CInt] -> IO ()
addClause pico cl = mapM_ (picosat_add pico) (cl ++ [0])

-- | Add every clause of a formula to the solver, in order.
addClauses :: Ptr a -> [[CInt]] -> IO ()
addClauses pico = mapM_ (addClause pico)

-- | Read the model out of the solver after a satisfiable result.
--
-- All @picosat_deref@ calls are performed here, so the returned list does
-- not refer back to the solver and stays valid after @picosat_reset@.
getSolution :: Ptr a -> IO Solution
getSolution pico = do
  vars <- picosat_variables pico
  lits <- forM [1 .. vars] $ \i -> (i *) <$> picosat_deref pico i
  return $ Solution $ map fromIntegral lits

-- | Run the solver and translate its result code.
--
-- @-1@ is passed as the decision limit (presumably "no limit" per
-- picosat.h; confirm against the linked picosat version). Any result code
-- other than the three known ones raises an 'ErrorCall' exception.
solution :: Ptr a -> IO Solution
solution pico = do
  res <- picosat_sat pico (-1)
  case res of
    a | a == unknown -> return Unknown
      | a == unsatisfiable -> return Unsatisfiable
      | a == satisfiable -> getSolution pico
      | otherwise -> throwIO (ErrorCall "Picosat error.")

-- | Convert a formula's literals to 'CInt'.
--
-- NOTE(review): 'fromIntegral' wraps silently, so literals outside the
-- 'CInt' range end up referring to unrelated variables.
toCIntegers :: Integral a => [[a]] -> [[CInt]]
toCIntegers = map (map fromIntegral)

-- | Solve a CNF formula using a fresh picosat instance.
--
-- The instance is always released with @picosat_reset@, including when
-- solving throws (for example on an unexpected picosat result code or when
-- forcing the caller's clauses fails). The original code skipped the reset
-- in those cases and leaked the instance.
solve :: Integral a => [[a]] -> IO Solution
solve cls =
  bracket picosat_init picosat_reset $ \pico -> do
    addClauses pico (toCIntegers cls)
    solution pico

-- | 'solve' lifted into 'ST'.
--
-- Intended to be observably pure: every call creates, uses and releases its
-- own solver instance, so (assuming picosat keeps no hidden global state
-- between instances) the result depends only on the clauses.
solveST :: Integral a => [[a]] -> ST t Solution
solveST = unsafeIOToST . solve
{-# NOINLINE solveST #-}

-- | Pure wrapper around 'solve' via 'unsafePerformIO'.
--
-- Relies on the same self-contained-instance argument as 'solveST'.
-- @NOINLINE@ keeps GHC from duplicating or floating the underlying IO.
unsafeSolve :: Integral a => [[a]] -> Solution
unsafeSolve = unsafePerformIO . solve
{-# NOINLINE unsafeSolve #-}