module Data.AIG.Interface
  ( 
    IsLit(..)
  , IsAIG(..)
  , lazyMux
    
  , Proxy(..)
  , SomeGraph(..)
  , Network(..)
  , networkInputCount
  
  , LitView(..)
  , LitTree(..)
  , toLitTree
  , fromLitTree
  , toLitForest
  , fromLitForest
  , foldAIG
  , foldAIGs
  , unfoldAIG
  , unfoldAIGs
    
  , SatResult(..)
  , VerifyResult(..)
  , toSatResult
  , toVerifyResult
    
  , genLitView
  , genLitTree
  , getMaxInput
  , buildNetwork
  , randomNetwork
  ) where
import Control.Applicative
import Control.Monad
import Prelude hiding (not, and, or)
import Test.QuickCheck (Gen, Arbitrary(..), generate, oneof, sized, choose)
-- | One level of an AIG literal: a constant, a (possibly negated) input, or a
-- (possibly negated) conjunction whose two children have type @a@.
--
-- Note: the derived 'Ord' instance follows the constructor order below.
data LitView a
  = And !a !a       -- ^ conjunction of the two children
  | NotAnd !a !a    -- ^ negation of the conjunction of the two children
  | Input !Int      -- ^ input with the given index
  | NotInput !Int   -- ^ negation of the input with the given index
  | TrueLit         -- ^ constant true
  | FalseLit        -- ^ constant false
 deriving (Eq,Show,Ord,Functor)
-- | A literal unfolded into an explicit tree: the fixed point of 'LitView'.
-- When a tree is turned back into a graph ('fromLitTree' via 'unfoldAIG'),
-- each occurrence of a subtree is rebuilt separately; there is no
-- memoization of repeated subtrees.
newtype LitTree = LitTree { unLitTree :: LitView LitTree }
 deriving (Eq,Show,Ord)
-- | Literals of an AIG.  The type index @s@ ties a literal (@l s@) to the
-- graph (@g s@) it belongs to.
class IsLit l where
  -- | Negate a literal.  This is pure: it takes no graph and runs no 'IO',
  -- so it cannot allocate graph nodes.
  not :: l s -> l s

  -- | Identity test on literals; 'asConstant' uses it to recognise the
  -- constant literals.
  -- NOTE(review): presumably a structural (same node, same polarity) check
  -- rather than logical equivalence -- confirm against instances.
  (===) :: l s -> l s -> Bool
-- | A value-level witness selecting an AIG implementation @g@ (and, via the
-- functional dependency on 'IsAIG', its literal type @l@).  Pattern matching
-- on the constructor brings the 'IsAIG' @l g@ constraint into scope.
--
-- NOTE(review): the rank-2 field @(forall a . a -> a)@ can only hold 'id'
-- (or bottom); its purpose is not visible from this module.
data Proxy l g where
  Proxy :: IsAIG l g => (forall a . a -> a) -> Proxy l g
-- | Operations on an and-inverter graph @g s@ whose literals have type @l s@.
--
-- The functional dependency @g -> l@ means the graph type determines the
-- literal type.  Most Boolean connectives have default definitions built
-- from 'and' and 'not'; instances must supply the methods named in the
-- @MINIMAL@ pragma.
class IsLit l => IsAIG l g | g -> l where
  -- 'withNewGraph' and 'newGraph' are defined in terms of each other, so an
  -- instance overriding neither would loop forever at runtime.  The MINIMAL
  -- pragma makes GHC warn about that (and about any other missing required
  -- method) at compile time instead.
  {-# MINIMAL (withNewGraph | newGraph), aigerNetwork, trueLit, falseLit,
              newInput, and, inputCount, getInput, writeAiger, checkSat,
              cec, evaluator, abstractEvaluateAIG #-}

  -- | Run a computation with a freshly created graph.  Defaults to
  -- 'newGraph' followed by 'withSomeGraph'.
  withNewGraph :: Proxy l g
                  -- ^ selects the AIG implementation
               -> (forall s . g s -> IO a)
                  -- ^ computation to run; it must work for every @s@, so
                  -- the result type @a@ cannot mention the graph index
               -> IO a
  withNewGraph p f = newGraph p >>= (`withSomeGraph` f)

  -- | Create a fresh graph, hiding its index @s@ inside 'SomeGraph'.
  -- Defaults to 'withNewGraph'.
  newGraph :: Proxy l g
           -> IO (SomeGraph g)
  newGraph p = withNewGraph p (return . SomeGraph)

  -- | Load a 'Network' from the file at the given path.
  -- NOTE(review): presumably an AIGER file, per the name; the format is
  -- defined by each instance.
  aigerNetwork :: Proxy l g
               -> FilePath
               -> IO (Network l g)

  -- | The constant-true literal of a graph.
  trueLit :: g s -> l s

  -- | The constant-false literal of a graph.
  falseLit :: g s -> l s

  -- | Lift a 'Bool' to the matching constant literal ('trueLit' or
  -- 'falseLit').
  constant :: g s -> Bool -> l s
  constant g True  = trueLit  g
  constant g False = falseLit g

  -- | Recognise a constant literal: @Just True@ if it is '===' to 'trueLit',
  -- @Just False@ if it is '===' to 'falseLit', otherwise 'Nothing'.
  asConstant :: g s -> l s -> Maybe Bool
  asConstant g l | l === trueLit g = Just True
                 | l === falseLit g = Just False
                 | otherwise = Nothing

  -- | Allocate a new input literal in the graph.
  newInput :: g s -> IO (l s)

  -- | Conjunction of two literals.
  and :: g s -> l s -> l s -> IO (l s)

  -- | Conjunction of a list of literals; 'trueLit' for the empty list.
  ands :: g s -> [l s] -> IO (l s)
  ands g [] = return (trueLit g)
  ands g (x:r) = foldM (and g) x r

  -- | Disjunction, via De Morgan: @not (and (not x) (not y))@.
  or :: g s -> l s -> l s -> IO (l s)
  or g x y = not <$> and g (not x) (not y)

  -- | Equivalence (XNOR): the negation of 'xor'.
  eq :: g s -> l s -> l s -> IO (l s)
  eq g x y = not <$> xor g x y

  -- | Implication @x -> y@, encoded as @or (not x) y@.
  implies :: g s -> l s -> l s -> IO (l s)
  implies g x y = or g (not x) y

  -- | Exclusive or, encoded as @(x or y) and not (x and y)@.
  xor :: g s -> l s -> l s -> IO (l s)
  xor g x y = do
    o <- or g x y
    a <- and g x y
    and g o (not a)

  -- | If-then-else on literals: @mux g c x y@ is @(c and x) or (not c and y)@.
  mux :: g s -> l s -> l s -> l s -> IO (l s)
  mux g c x y = do
   x' <- and g c x
   y' <- and g (not c) y
   or g x' y'

  -- | Number of inputs of the graph.
  inputCount :: g s -> IO Int

  -- | Look up an input literal by index.
  -- NOTE(review): 'buildNetwork' allocates inputs @0..maxInput@ before
  -- resolving @Input i@, which suggests 0-based indices; behaviour on an
  -- out-of-range index is instance-defined.
  getInput :: g s -> Int -> IO (l s)

  -- | Write a network to the given file path.
  -- NOTE(review): presumably in AIGER format, per the name.
  writeAiger :: FilePath -> Network l g -> IO ()

  -- | Satisfiability query for a single literal.
  -- NOTE(review): presumably 'Sat' carries a satisfying input assignment.
  checkSat :: g s -> l s -> IO SatResult

  -- | Compare two networks.
  -- NOTE(review): presumably combinational equivalence checking, per the
  -- name, with 'Invalid' carrying a distinguishing input assignment.
  cec :: Network l g -> Network l g -> IO VerifyResult

  -- | Given a list of input values, return a function that evaluates
  -- literals of the graph under that assignment.
  evaluator :: g s
            -> [Bool]
            -> IO (l s -> Bool)

  -- | Evaluate every output of a network under the given input values,
  -- using 'evaluator'.
  evaluate :: Network l g
           -> [Bool]
           -> IO [Bool]
  evaluate (Network g outputs) inputs = do
    f <- evaluator g inputs
    return (f <$> outputs)

  -- | Build a bottom-up interpreter for literals.  The callback receives one
  -- 'LitView' node whose children have already been interpreted as @a@.
  -- NOTE(review): whether shared sub-graphs are interpreted only once is
  -- up to each instance.
  abstractEvaluateAIG
          :: g s
          -> (LitView a -> IO a)
          -> IO (l s -> IO a)
-- | Interpret a single literal bottom-up with the given per-node callback,
-- using the graph's 'abstractEvaluateAIG'.
foldAIG :: IsAIG l g
        => g s
        -> (LitView a -> IO a)
        -> l s
        -> IO a
foldAIG n view l = abstractEvaluateAIG n view >>= \interpret -> interpret l
-- | Interpret several literals with the same callback.  One interpreter is
-- built by 'abstractEvaluateAIG' and applied to each literal in order.
foldAIGs :: IsAIG l g
        => g s
        -> (LitView a -> IO a)
        -> [l s]
        -> IO [a]
foldAIGs n view ls =
  abstractEvaluateAIG n view >>= \interpret -> traverse interpret ls
-- | Build a literal in the graph by repeatedly expanding a seed with
-- @unfold@.  For conjunctions, the left child is built before the right.
-- Nothing is memoized, so a seed reached twice is expanded twice.
unfoldAIG :: IsAIG l g
          => g s
          -> (a -> IO (LitView a))
          -> a -> IO (l s)
unfoldAIG n unfold = build
  where
    -- Expand one seed, then turn the resulting node into a literal.
    build seed = unfold seed >>= realize

    realize node = case node of
      And x y    -> conj x y
      NotAnd x y -> not <$> conj x y
      Input i    -> getInput n i
      NotInput i -> not <$> getInput n i
      TrueLit    -> pure (trueLit n)
      FalseLit   -> pure (falseLit n)

    -- Build both children (left first), then conjoin them.
    conj lhs rhs = do
      litL <- build lhs
      litR <- build rhs
      and n litL litR
-- | 'unfoldAIG' applied to each seed in turn.
unfoldAIGs :: IsAIG l g
          => g s
          -> (a -> IO (LitView a))
          -> [a] -> IO [l s]
unfoldAIGs n unfold seeds = traverse (unfoldAIG n unfold) seeds
-- | Convert a literal into an explicit 'LitTree' (via 'foldAIG').
toLitTree :: IsAIG l g => g s -> l s -> IO LitTree
toLitTree g l = foldAIG g (\node -> pure (LitTree node)) l
-- | Rebuild a literal in the graph from a 'LitTree' (via 'unfoldAIG').
fromLitTree :: IsAIG l g => g s -> LitTree -> IO (l s)
fromLitTree g tree = unfoldAIG g (\(LitTree node) -> pure node) tree
-- | Convert several literals into 'LitTree's (via 'foldAIGs').
toLitForest :: IsAIG l g => g s -> [l s] -> IO [LitTree]
toLitForest g ls = foldAIGs g (\node -> pure (LitTree node)) ls
-- | Rebuild several literals in the graph from 'LitTree's (via 'unfoldAIGs').
fromLitForest :: IsAIG l g => g s -> [LitTree] -> IO [l s]
fromLitForest g trees = unfoldAIGs g (\(LitTree node) -> pure node) trees
-- | If-then-else that skips building a branch when the condition is a known
-- constant.
--
-- * If @c@ is '===' to 'trueLit', only the then-branch action runs.
-- * If @c@ is '===' to 'falseLit', only the else-branch action runs.
-- * Otherwise both actions run (then-branch first) and the results are
--   combined with 'mux'.
lazyMux :: IsAIG l g => g s -> l s -> IO (l s) -> IO (l s) -> IO (l s)
lazyMux g c
  | isTrue    = \thenBranch _ -> thenBranch
  | isFalse   = \_ elseBranch -> elseBranch
  | otherwise = \thenBranch elseBranch -> do
      t <- thenBranch
      e <- elseBranch
      mux g c t e
  where
    isTrue  = c === trueLit g
    isFalse = c === falseLit g
-- | A graph together with a list of output literals.  The graph index @s@ is
-- existentially hidden.  The constructor also packages the 'IsAIG' @l g@
-- constraint, so matching on it makes the AIG operations available.
data Network l g where
   Network :: IsAIG l g => g s -> [l s] -> Network l g
-- | Number of inputs of the network's underlying graph ('inputCount').
networkInputCount :: Network l g -> IO Int
networkInputCount net =
  case net of
    Network graph _ -> inputCount graph
-- | A graph whose index @s@ is existentially hidden.  Use 'withSomeGraph' to
-- operate on the wrapped graph.
data SomeGraph g where
  SomeGraph :: g s -> SomeGraph g
-- | Unwrap a 'SomeGraph' and pass the graph to a continuation.  The
-- continuation must be polymorphic in @s@, so the index cannot leak out.
withSomeGraph :: SomeGraph g
              -> (forall s . g s -> IO a)
              -> IO a
withSomeGraph wrapped k =
  case wrapped of
    SomeGraph graph -> k graph
-- | Outcome of a satisfiability query (see 'checkSat').
data SatResult
   = Unsat             -- ^ no satisfying assignment exists
   | Sat !([Bool])     -- ^ satisfiable; NOTE(review): presumably the list is an input assignment
   | SatUnknown        -- ^ the query was not decided
  deriving (Eq,Show)
-- | Outcome of a verification query (see 'cec').  It is the dual of
-- 'SatResult': 'toVerifyResult' maps 'Unsat' to 'Valid' and @Sat l@ to
-- @Invalid l@, and 'toSatResult' maps them back.
data VerifyResult
   = Valid             -- ^ the property holds
   | Invalid [Bool]    -- ^ the property fails; the list is the payload a 'Sat' result would carry
   | VerifyUnknown     -- ^ the query was not decided
  deriving (Eq, Show)
-- | Read a satisfiability result as a validity result: an unsatisfiable
-- negation means the property is valid, and a model is a counterexample.
toVerifyResult :: SatResult -> VerifyResult
toVerifyResult result =
  case result of
    Unsat      -> Valid
    Sat model  -> Invalid model
    SatUnknown -> VerifyUnknown
-- | Inverse of 'toVerifyResult': a valid property corresponds to
-- unsatisfiability, and a counterexample to a model.
toSatResult :: VerifyResult -> SatResult
toSatResult result =
  case result of
    Valid          -> Unsat
    Invalid cex    -> Sat cex
    VerifyUnknown  -> SatUnknown
-- | Generate one 'LitView' layer, using @gen@ for any children.
--
-- 'oneof' picks each of the six constructors with equal probability.  Input
-- indices are drawn from @[0, n-1]@, where @n@ is the QuickCheck size.  The
-- upper bound is clamped to 0 so that at size 0 the range is @(0,0)@ rather
-- than the inverted @(0,-1)@, which could yield the index -1.  Index 0 is
-- always valid: 'buildNetwork' always allocates at least input 0.
genLitView :: Gen a -> Gen (LitView a)
genLitView gen = oneof
     [ return TrueLit
     , return FalseLit
     , Input    <$> genInputIndex
     , NotInput <$> genInputIndex
     , And    <$> gen <*> gen
     , NotAnd <$> gen <*> gen
     ]
  where
    genInputIndex = sized $ \n -> choose (0, max 0 (n - 1))
-- | Generate a random 'LitTree' by recursively generating 'LitView' layers.
-- Four of the six constructors are leaves, so the expected number of
-- children per node is 2/3 and generation terminates with probability 1.
genLitTree :: Gen LitTree
genLitTree = LitTree <$> genLitView genLitTree
-- | Largest input index mentioned anywhere in the tree; 0 for a tree that
-- mentions no inputs.
getMaxInput :: LitTree -> Int
getMaxInput (LitTree node) =
  case node of
    Input i      -> i
    NotInput i   -> i
    And lhs rhs    -> getMaxInput lhs `max` getMaxInput rhs
    NotAnd lhs rhs -> getMaxInput lhs `max` getMaxInput rhs
    _            -> 0   -- TrueLit / FalseLit
-- | Random trees come from 'genLitTree'.  Constants do not shrink.  Inputs
-- shrink to the constants.  A conjunction shrinks to the constants, then to
-- either child, then to conjunctions of shrunk children.
instance Arbitrary LitTree where
  arbitrary = genLitTree

  shrink (LitTree node) =
    case node of
      TrueLit    -> []
      FalseLit   -> []
      Input _    -> constants
      NotInput _ -> constants
      And x y    -> constants ++ [x, y] ++ rebuild And x y
      NotAnd x y -> constants ++ [x, y] ++ rebuild NotAnd x y
    where
      constants = [LitTree TrueLit, LitTree FalseLit]
      rebuild con a b = [ LitTree (con a' b') | (a', b') <- shrink (a, b) ]
-- | Build a 'Network' in a fresh graph whose outputs are the given trees.
--
-- Inputs @0 .. m@ are allocated first, where @m@ is the largest input index
-- mentioned by any tree (0 if none), so at least one input always exists.
-- The trees are then rebuilt as literals with 'fromLitForest'.
buildNetwork :: IsAIG l g => Proxy l g -> [LitTree] -> IO (Network l g)
buildNetwork proxy forest = do
  SomeGraph g <- newGraph proxy
  let highestInput = foldr (max . getMaxInput) 0 forest
  mapM_ (\_ -> newInput g) [0 .. highestInput]
  outputs <- fromLitForest g forest
  pure (Network g outputs)
-- | Build a network from a randomly generated list of 'LitTree's
-- (QuickCheck 'generate' with the 'Arbitrary' instance).
randomNetwork :: IsAIG l g => Proxy l g -> IO (Network l g)
randomNetwork proxy = do
  forest <- generate arbitrary
  buildNetwork proxy forest