{-# LANGUAGE BangPatterns, GeneralizedNewtypeDeriving #-}
module Jukebox.HighSat where

import MiniSat hiding (neg)
import qualified MiniSat
import qualified Jukebox.Seq as Seq
import Jukebox.Seq(Seq, List)
import Jukebox.Form(Signed(..), neg)
import qualified Jukebox.Map as Map
import Jukebox.Map(Map)
import Control.Monad.State.Strict
import Control.Monad.Reader
import Control.Monad.Trans
import Data.Hashable
import Data.Traversable hiding (mapM, sequence)
import Control.Applicative
import Data.Maybe
import Data.List(partition)

-- | A computation against a single MiniSat solver.  It reads the solver and
-- the watch callback from its environment, and threads the map from atoms
-- to the solver literals that have been allocated for them.
--
-- 'Applicative' is derived alongside 'Monad': since GHC 7.10 (the
-- Applicative-Monad Proposal) every 'Monad' instance requires an
-- 'Applicative' superclass instance, so deriving only 'Monad' does not
-- compile.
newtype Sat1 a b = Sat1 { runSat1_ :: ReaderT Solver (ReaderT (Watch a) (StateT (Map a Lit) IO)) b } deriving (Functor, Applicative, Monad, MonadIO)

-- | A computation over several solvers, each stored under an index of
-- type @b@.  All solvers share one watch callback; 'atIndex' runs a 'Sat1'
-- computation against the solver at a given index.
newtype Sat a b c = Sat { runSat_ :: ReaderT (Watch a) (StateT (Map b (SatState a)) IO) c } deriving (Functor, Applicative, Monad, MonadIO)

-- | Per-solver state: the solver itself and its atom-to-literal map.
data SatState a = SatState Solver (Map a Lit)

-- | Callback run by 'var' with each atom that is given a fresh literal,
-- after that literal has been recorded in the atom map.
type Watch a = a -> Sat1 a ()

-- | Propositional formulas over atoms of type @a@.  There is no negation
-- constructor: negation appears only on literals, and 'nt' pushes it
-- through 'And'/'Or' by De Morgan's laws.
data Form a
  = Lit (Signed a)      -- ^ A single, possibly negated, atom.
  | And (Seq (Form a))  -- ^ Conjunction; an empty one is 'true'.
  | Or (Seq (Form a))   -- ^ Disjunction; an empty one is 'false'.

-- | Logical negation of a formula.  The negation is pushed down to the
-- literals using De Morgan's laws, so the result has the same shape with
-- 'And' and 'Or' swapped and every literal flipped.
nt :: Form a -> Form a
nt form = case form of
  Lit l  -> Lit (neg l)
  And fs -> Or (nt <$> fs)
  Or fs  -> And (nt <$> fs)

-- | Build a conjunction ('conj') or disjunction ('disj') from any 'List'
-- container of formulas, converting it to a 'Seq' first.
conj, disj :: List f => f (Form a) -> Form a
conj fs = And (Seq.fromList fs)
disj fs = Or (Seq.fromList fs)

-- | The empty conjunction, which is always satisfied.
true :: Form a
true = And Seq.Nil

-- | The empty disjunction, which can never be satisfied.
false :: Form a
false = Or Seq.Nil

-- | An "at most one" constraint over the given formulas.  For each element
-- @x@, if @x@ holds then every element after it must be false.  No
-- "at least one" clause is added.  The encoding is pairwise, so the
-- formula grows quadratically with the number of inputs.
--
-- Empty and singleton inputs impose no constraint and yield 'true'.
unique :: List f => f (Form a) -> Form a
unique = u . Seq.toList
  where u [] = true
        u [_] = true
        u (x:xs) = conj [disj [nt x, conj (map nt xs)],
                         u xs]

-- | Run a multi-solver computation.  One solver is opened with
-- 'withNewSolver' for each index in @idxs@, and each starts with an empty
-- atom-to-literal map.  The computation runs once all of them are open,
-- with the given watch callback.  If an index is repeated, the later
-- solver replaces the earlier one in the map.
runSat :: (Hashable b, Ord b) => Watch a -> [b] -> Sat a b c -> IO c
runSat w idxs x = foldr openSolver finish idxs Map.empty
  where
    -- Every requested solver is open by the time this runs.
    finish solvers = evalStateT (runReaderT (runSat_ x) w) solvers
    -- Open one solver, register it under its index, then continue with
    -- the remaining indices inside its scope.
    openSolver idx k solvers =
      withNewSolver $ \s -> k (Map.insert idx (SatState s Map.empty) solvers)

-- | Run a single-solver computation.  This is 'runSat' with the unit
-- index as the only solver, targeted through 'atIndex'.
runSat1 :: (Ord a, Hashable a) => Watch a -> Sat1 a b -> IO b
runSat1 w = runSat w [()] . atIndex ()

-- | Run a single-solver computation against the solver registered under
-- @idx@ (one of the indices given to 'runSat'), using the shared watch
-- callback.  Any atoms the computation allocates are saved back into that
-- solver's atom map.  Calls 'error' if @idx@ was never registered.
atIndex :: (Ord a, Hashable a, Ord b, Hashable b) => b -> Sat1 a c -> Sat a b c
atIndex !idx m = do
  watch <- Sat ask
  -- The message names this function so a missing index is easy to trace.
  SatState s ls <- Sat (gets (Map.findWithDefault (error "atIndex: index not found") idx))
  (x, ls') <- liftIO (runStateT (runReaderT (runReaderT (runSat1_ m) s) watch) ls)
  Sat (modify (Map.insert idx (SatState s ls')))
  return x

-- | Call the solver with the given signed atoms converted to literals.
-- Presumably MiniSat treats these as assumptions for this call only;
-- confirm against the MiniSat binding.  Converting an atom that has no
-- literal yet allocates one through 'var', which also runs the watch.
solve :: (Ord a, Hashable a) => [Signed a] -> Sat1 a Bool
solve assumptions = do
  lits <- mapM lit assumptions
  solver <- Sat1 ask
  liftIO (MiniSat.solve solver lits)

-- | Snapshot the solver's current model as a total function on atoms.
-- Atoms with no literal, or whose literal has no value, map to 'False'.
-- The snapshot is taken once; later solver calls do not affect it.
model :: (Ord a, Hashable a) => Sat1 a (a -> Bool)
model = do
  solver <- Sat1 ask
  atoms <- Sat1 (lift get)
  values <- liftIO (traverse (MiniSat.modelValue solver) atoms)
  let valueOf v = case Map.lookup v values of
        Just (Just b) -> b
        _             -> False
  return valueOf

-- | Look up the value of atom @x@ in the solver's current model.  If @x@
-- has no literal yet, one is allocated through 'var', which also runs the
-- watch callback.
--
-- A literal with no value in the model ('MiniSat.modelValue' returns
-- 'Nothing') reads as 'False', matching 'model'.  Presumably this happens
-- for literals created after the last solve.  Matching explicitly also
-- avoids a failable @Just b <- ...@ pattern, which would need a
-- 'MonadFail' instance for 'Sat1' that does not exist.
modelValue :: (Ord a, Hashable a) => a -> Sat1 a Bool
modelValue x = do
  s <- Sat1 ask
  l <- var x
  mb <- liftIO (MiniSat.modelValue s l)
  return (fromMaybe False mb)

-- | Assert a formula: flatten it to clauses and add each clause to the
-- solver.
-- NOTE(review): results of 'MiniSat.addClause' are discarded; confirm
-- that nothing meaningful (e.g. an early-conflict flag) is lost.
addForm :: (Ord a, Hashable a) => Form a -> Sat1 a ()
addForm f = do
  clauses <- flatten f
  solver <- Sat1 ask
  _ <- liftIO (Seq.mapM (MiniSat.addClause solver . Seq.toList) clauses)
  return ()

-- | Convert a formula to clauses (a sequence of disjunctions of
-- literals), allocating literals for atoms as needed.
--
-- * A literal becomes a single unit clause.
-- * A conjunction concatenates its children's clause sets.
-- * A disjunction combines its children's clause sets with
--   'Seq.sequence' and concatenates each combination into one clause.
--   Presumably 'Seq.sequence' enumerates every choice of one clause per
--   child (a distributive, potentially exponential conversion); confirm in
--   "Jukebox.Seq".
flatten :: (Ord a, Hashable a) => Form a -> Sat1 a (Seq (Seq Lit))
flatten form = case form of
  Lit l  -> Seq.Unit . Seq.Unit <$> lit l
  And fs -> Seq.concat <$> Seq.mapM flatten fs
  Or fs  -> fmap Seq.concat . Seq.sequence <$> Seq.mapM flatten fs

-- | Solver literal for a signed atom.  A negative sign negates the atom's
-- literal.
lit :: (Ord a, Hashable a) => Signed a -> Sat1 a Lit
lit signed = case signed of
  Pos x -> var x
  Neg x -> MiniSat.neg <$> var x

-- | Solver literal for an atom, allocating it on first use.  A new literal
-- is recorded in the atom map before the watch callback runs, so the
-- callback sees the atom as already known.
var :: (Ord a, Hashable a) => a -> Sat1 a Lit
var x = do
  known <- Sat1 get
  case Map.lookup x known of
    Just existing -> return existing
    Nothing -> do
      solver <- Sat1 ask
      fresh <- liftIO (MiniSat.newLit solver)
      Sat1 (put (Map.insert x fresh known))
      watch <- Sat1 (lift ask)
      watch x
      return fresh