-- | ABC-backed AIG networks exposed through the generic "Data.AIG"
-- interface: this module defines the 'AIG' graph and 'Lit' literal types,
-- their 'AIG.IsAIG' / 'AIG.IsLit' instances, and a few helpers for reading
-- AIGER files, writing CNF, and checking satisfiability.
module Data.ABC.AIG
  ( AIG
  , newAIG
  , readAiger
  , proxy
  , Lit
  , true
  , false
  , writeAIGManToCNFWithMapping
  , checkSat'
    
    -- Re-exports from "Data.AIG".
  , AIG.Network(..)
  , AIG.networkInputCount
  , AIG.IsAIG(..)
  , AIG.IsLit(..)
  , AIG.SatResult(..)
  , AIG.VerifyResult(..)
  , AIG.SomeGraph(..)
  ) where
import Prelude ()
import Prelude.Compat hiding (and, or, not)
import Foreign
import Control.Exception
import Control.Monad.Compat
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as VM
import System.IO
import qualified System.IO.Unsafe as Unsafe
import qualified Data.Map as Map
import           Data.IORef
import Data.ABC.Internal.ABC
import Data.ABC.Internal.AIG
import Data.ABC.Internal.CNF
import Data.ABC.Internal.Field
import Data.ABC.Internal.FRAIG
import Data.ABC.Internal.IO
import Data.ABC.Internal.Main
import Data.ABC.Internal.Orphan
import Data.ABC.Internal.VecInt
import Data.ABC.Internal.VecPtr
import qualified Data.AIG as AIG
import           Data.AIG.Interface (LitView(..))
import qualified Data.AIG.Trace as Tr
import Data.ABC.Util
-- | Handle to an ABC network.  The underlying @Abc_Ntk_t_@ is held through a
-- 'ForeignPtr'; the phantom parameter @s@ ties literals to the graph they
-- belong to.
newtype AIG s = AIG
  { _ntkPtr :: ForeignPtr Abc_Ntk_t_
  }
-- | A literal in an 'AIG': a thin wrapper around an ABC object pointer
-- (possibly complemented, see the 'AIG.IsLit' instance).  The phantom
-- parameter @s@ matches the graph the literal was created in.
newtype Lit s = Lit
  { unLit :: Abc_Obj_t
  }
  deriving (Eq, Storable, Ord)
-- | Proxy value selecting the ABC-backed 'Lit' / 'AIG' implementation of
-- the generic AIG interface.
proxy :: AIG.Proxy Lit AIG
proxy = AIG.Proxy id
-- | Create a fresh, empty structurally-hashed AIG network.
--
-- ABC is started first.  The network is deleted via 'abcNtkDelete' if
-- attaching the finalizer fails; otherwise ownership passes to the
-- 'ForeignPtr' finalizer.
newAIG :: IO (AIG.SomeGraph AIG)
newAIG = do
  abcStart
  bracketOnError allocate abcNtkDelete wrap
  where
    allocate = abcNtkAlloc AbcNtkStrash AbcFuncAig True
    wrap ntk = do
      fp <- newForeignPtr p_abcNtkDelete ntk
      return (AIG.SomeGraph (AIG fp))
-- | Run an action on every primary output of the network, in vector order,
-- and collect the results.  The output count is read once, before the
-- first action runs.
foreachPo :: Abc_Ntk_t -> (Abc_Obj_t -> IO a) -> IO [a]
foreachPo ntk f = do
  outputs <- abcNtkPos ntk
  count <- vecPtrSize outputs
  forN count (f <=< vecPtrEntry outputs)
-- | Run an action on every primary output of the network, in vector order,
-- discarding the results.
foreachPo_ :: Abc_Ntk_t -> (Abc_Obj_t -> IO ()) -> IO ()
foreachPo_ ntk f = do
  outputs <- abcNtkPos ntk
  count <- vecPtrSize outputs
  forN_ count $ \ix ->
    vecPtrEntry outputs ix >>= f
-- | Remove every primary output from the network: each PO object is
-- deleted, then the PO and CO vectors are cleared.
deletePos :: Abc_Ntk_t -> IO ()
deletePos ntk = do
  foreachPo_ ntk abcNtkDeleteObjPo
  abcNtkPos ntk >>= clearVecPtr
  abcNtkCos ntk >>= clearVecPtr
-- | Check that a file can be opened for reading by opening and immediately
-- closing it.  Any exception raised by 'withFile' propagates to the caller.
checkReadable :: FilePath -> IO ()
checkReadable path = withFile path ReadMode (const (return ()))
-- | Load a combinational network from an AIGER file.
--
-- The file is first checked for readability.  Networks containing latches
-- are rejected.  The network's name manager is replaced with a fresh one,
-- and every primary output object is removed from the network; the nodes
-- they referred to are returned as the network's output literals.
--
-- If anything fails before the finalizer is attached, the network is
-- deleted with 'abcNtkDelete'.
readAiger :: FilePath -> IO (AIG.Network Lit AIG)
readAiger path = do
  abcStart
  checkReadable path
  bracketOnError (ioReadAiger path True) abcNtkDelete $ \ntk -> do
    latchCount <- abcNtkLatchNum ntk
    when (latchCount > 0) $
      fail "Reading networks with latches is not yet supported."

    -- Swap the existing name manager for a fresh one.
    oldNames <- readAt abcNtkManName ntk
    nmManFree oldNames
    newNames <- nmManCreate 1
    writeAt abcNtkManName ntk newNames

    -- Detach each output, remembering the node that was driving it.
    outputs <- foreachPo ntk (detachOutput ntk)

    abcNtkCos ntk >>= clearVecPtr
    abcNtkPos ntk >>= clearVecPtr

    fp <- newForeignPtr p_abcNtkDelete ntk
    return (AIG.Network (AIG fp) outputs)
  where
    -- Look up the id of the output's first fanin, delete the output object,
    -- and return the object with that id as a literal.
    -- NOTE(review): the literal is rebuilt from the fanin id alone; confirm
    -- whether a complemented fanin edge needs to be accounted for here.
    detachOutput ntk po = do
      faninId <- vecIntEntry (abcObjFanins po) 0
      abcNtkDeleteObjPo po
      Lit <$> abcNtkObj ntk (fromIntegral faninId)
-- | Run an action with the raw network pointer of a graph, keeping the
-- underlying 'ForeignPtr' alive for the duration of the action.
withAIGPtr :: AIG s -> (Abc_Ntk_t -> IO a) -> IO a
withAIGPtr g act = withForeignPtr (_ntkPtr g) act
-- | Negation is delegated to 'abcObjNot'; literal identity is equality of
-- the wrapped object pointers.
instance AIG.IsLit Lit where
  not = Lit . abcObjNot . unLit
  a === b = unLit a == unLit b
-- | The constant-true literal of a graph, obtained from 'abcAigConst1'.
--
-- The lookup runs under 'Unsafe.unsafePerformIO' so the constant can be
-- used as a pure value.
true :: AIG s -> Lit s
true g = Lit (Unsafe.unsafePerformIO (withAIGPtr g abcAigConst1))
-- | The constant-false literal of a graph: the negation of 'true'.
false :: AIG s -> Lit s
false = AIG.not . true
-- | Decide satisfiability of the single output of the network stored at
-- the given pointer.
--
-- The pointer is passed by reference because 'abcNtkIvyProve' receives the
-- address itself; the network is re-read from it before extracting a model.
--
-- * If the miter is already known to be constant, the answer is returned
--   directly (an all-'False' assignment is used as the witness for SAT).
-- * Otherwise 'abcNtkIvyProve' is run.  Its return code is interpreted as
--   @-1@ = undecided (raises via 'fail'), @0@ = satisfiable (the model is
--   read from the network), @1@ = unsatisfiable.
--
-- The network is asserted to have exactly one primary output.
checkSat' :: Ptr Abc_Ntk_t -> IO AIG.SatResult
checkSat' pp = do
  p <- peek pp
  ic <- vecPtrSize =<< abcNtkPis p
  oc <- vecPtrSize =<< abcNtkPos p
  assert (oc == 1) $ do
    isConstant <- checkIsConstant p
    case isConstant of
      Just True -> return (AIG.Sat (replicate ic False))
      Just False -> return AIG.Unsat
      Nothing -> do
        let params = proveParamsDefault
                     { nItersMax'Prove_Params = 5
                     , fUseRewriting'Prove_Params = False
                     }
        with params $ \pParams -> do
          r <- abcNtkIvyProve pp (castPtr pParams)
          case r of
            -- Undecided must be matched as -1; matching 1 here would shadow
            -- the unsatisfiable case below.
            -1 -> fail "Could not decide equivalence."
            0 -> do
              -- Re-read the network through the pointer before fetching
              -- the model.
              p1 <- peek pp
              pModel <- abcNtkModel p1
              AIG.Sat . fmap toBool <$> peekArray ic pModel
            1 -> return AIG.Unsat
            _ -> error $ "Unrecognized return code " ++ show r ++ " from abcNtkIvyProve"
-- | Build a memoising bottom-up evaluator over a graph.
--
-- The supplied @view@ function is applied to each literal's 'LitView' after
-- the operands of an and-node have themselves been evaluated.  Results are
-- cached per literal in an 'IORef'-held map that is shared by all calls to
-- the returned function, so each distinct literal is evaluated at most once.
memoFoldAIG :: AIG s -> (LitView a -> IO a) -> IO (Lit s -> IO a)
memoFoldAIG g view = do
    cache <- newIORef Map.empty
    let -- Record a computed value for a literal and hand it back.
        remember lit val = do
          modifyIORef' cache (Map.insert lit val)
          return val

        -- Evaluate one view, resolving and-node operands left to right.
        step (And x y) = do
          x' <- resolve x
          y' <- resolve y
          view (And x' y')
        step (NotAnd x y) = do
          x' <- resolve x
          y' <- resolve y
          view (NotAnd x' y')
        step (Input i)    = view (Input i)
        step (NotInput i) = view (NotInput i)
        step TrueLit      = view TrueLit
        step FalseLit     = view FalseLit

        -- Return the cached value for a literal, computing it on a miss.
        resolve lit = do
          known <- readIORef cache
          maybe (litViewInner lit >>= step >>= remember lit)
                return
                (Map.lookup lit known)

    return (\l -> withAIGPtr g (\_ -> resolve l))
-- | Classify a literal as a (possibly negated) input, constant, or and-node.
--
-- The complement flag is split off with 'abcObjIsComplement' and the
-- underlying object inspected via 'abcObjRegular'.  Objects that are
-- neither a primary input, the constant node, nor an and-node cause the
-- action to 'fail'.
litViewInner :: Lit s -> IO (LitView (Lit s))
litViewInner (Lit l) = do
  let c = abcObjIsComplement l
  let o = abcObjRegular l
  i <- abcObjId o
  ty <- abcObjType o
  case ty of
    AbcObjPi ->
      -- The input index is the object id minus one.
      -- NOTE(review): this assumes id 0 is the constant node and the
      -- primary inputs occupy the ids immediately after it; confirm.
      let inputIx = fromIntegral (i - 1)
       in if c then return (NotInput inputIx) else return (Input inputIx)
    AbcObjConst1 -> if c then return FalseLit else return TrueLit
    AbcObjNode -> do
      isand <- abcObjIsAnd o
      if isand
        then do
          x <- abcObjLit0 o
          y <- abcObjLit1 o
          if c then return (NotAnd (Lit x) (Lit y))
               else return (And (Lit x) (Lit y))
        else fail "invalid AIG literal: non-and node"
    _ -> fail ("invalid AIG literal: "++show ty++" "++show i++" "++show c)
-- | Literals are ordered by their derived 'Ord' instance and shown by
-- showing the wrapped object pointer.
instance Tr.Traceable Lit where
  compareLit = compare
  showLit = show . unLit
-- | ABC-backed implementation of the generic AIG interface.
instance AIG.IsAIG Lit AIG where
  newGraph _ = newAIG
  aigerNetwork _ = readAiger
  trueLit = true
  falseLit = false

  -- Create a new primary-input object in the network.
  newInput a =
    withAIGPtr a $ \p -> do
      Lit <$> abcNtkCreateObj p AbcObjPi

  -- The gate constructors all go through the network's function manager.
  and a x y = do
    withAIGPtr a $ \p -> do
      manFunc <- castPtr <$> abcNtkManFunc p
      Lit <$> abcAigAnd manFunc (unLit x) (unLit y)
  xor a x y = do
    withAIGPtr a $ \p -> do
      manFunc <- castPtr <$> abcNtkManFunc p
      Lit <$> abcAigXor manFunc (unLit x) (unLit y)
  mux a c t f = do
    withAIGPtr a $ \p -> do
      manFunc <- castPtr <$> abcNtkManFunc p
      Lit <$> abcAigMux manFunc (unLit c) (unLit t) (unLit f)

  inputCount a = withAIGPtr a (vecPtrSize <=< abcNtkPis)

  -- Fetch the i-th primary input; the index is asserted to be in range.
  getInput a i = do
    withAIGPtr a $ \p -> do
      v <- abcNtkPis p
      sz <- vecPtrSize v
      assert (0 <= i && i < sz) $
        Lit . castPtr <$> vecPtrEntry v i

  -- Write the network (with its outputs temporarily installed) to a file.
  writeAiger path a = do
    withNetworkPtr a $ \p -> do
      ioWriteAiger p path True False False

  -- Write the CNF for a single literal to a file and return, for each
  -- combinational input of the derived AIG manager (in order), the entry
  -- of the CNF mapping at that input's object id.
  writeCNF aig l path =
    withNetworkPtr (AIG.Network aig [l]) $ \pNtk -> do
      withAbcNtkToDar pNtk False False $ \pMan -> do
        vars <- writeAIGManToCNFWithMapping pMan path
        ciCount <- aigManCiNum pMan
        -- Indices run from 0 to ciCount - 1 inclusive.
        forM [0 .. (ciCount - 1)] $ \i -> do
          ci <- aigManCi pMan (fromIntegral i)
          ((vars V.!) . fromIntegral) `fmap` (aigObjId ci)

  -- Check satisfiability of one literal on a duplicate of the network; the
  -- network found at the pointer afterwards is deleted.
  checkSat g l = do
    withNetworkPtr (AIG.Network g [l]) $ \p ->
      alloca $ \pp ->
        bracket_
          (poke pp =<< abcNtkDup p)
          (abcNtkDelete =<< peek pp)
          (checkSat' pp)

  litView g l = withAIGPtr g $ \_ -> litViewInner l
  abstractEvaluateAIG = memoFoldAIG

  -- Combinational equivalence check: build a miter of the two networks and
  -- run the satisfiability check on it.  Both networks are asserted to
  -- have matching input and output counts.
  cec x y = do
    ix <- networkInputCount x
    iy <- networkInputCount y
    assert (ix == iy) $
      assert (outputCount x == outputCount y) $
        withNetworkPtr x $ \xp ->
          withNetworkPtr y $ \yp ->
            alloca $ \pp ->
              bracket_
                (poke pp =<< abcNtkMiter xp yp False 0 False False)
                (abcNtkDelete =<< peek pp)
                (AIG.toVerifyResult <$> checkSat' pp)

  -- Evaluate every and-node under the given input assignment and return a
  -- pure lookup function over the resulting value table.
  evaluator g inputs_l = do
    withAIGPtr g $ \ntk -> do
      objs <- abcNtkObjs ntk
      -- One slot per entry of the object vector, indexed by object id.
      var_count <- vecPtrSize objs
      v <- VM.new var_count
      -- Slot 0 is set to True.
      -- NOTE(review): presumably the constant-true node's id; confirm.
      VM.write v 0 True

      -- Store the supplied input values at the ids of the primary inputs.
      pis <- abcNtkPis ntk
      pi_count <- vecPtrSize pis
      let inputs = V.fromList inputs_l
      when (V.length inputs /= pi_count) $
        fail "evaluate given unexpected number of inputs."
      forI_ pi_count $ \pi_idx -> do
        o <- vecPtrEntry pis pi_idx
        idx <- fromIntegral <$> abcObjId o
        VM.write v idx (inputs V.! pi_idx)

      -- Walk the object vector in order, skipping null entries, and compute
      -- each and-node from the already-stored values of its two fanins.
      forI_ var_count $ \i -> do
        o <- vecPtrEntry objs i
        unless (o == nullPtr) $ do
          is_and <- abcObjIsAnd o
          when is_and $ do
            r0 <- evaluateFn v . Lit =<< abcObjLit0 o
            r1 <- evaluateFn v . Lit =<< abcObjLit1 o
            VM.write v i (r0 && r1)

      pureEvaluateFn g <$> V.freeze v
-- | @forI_ n f@ runs @f 0@, @f 1@, ..., @f (n - 1)@ in order.  Nothing is
-- run when @n <= 0@.
forI_ :: Monad m => Int -> (Int -> m ()) -> m ()
forI_ n f = loop 0
  where
    loop i
      | i >= n    = return ()
      | otherwise = f i >> loop (i + 1)
-- | Look up the value of a literal in a frozen value table indexed by
-- object id, flipping the stored value when the literal is complemented.
--
-- The lookup runs under 'Unsafe.unsafePerformIO' with the graph kept alive.
-- It fails if the literal's object id lies beyond the end of the table.
pureEvaluateFn :: AIG s -> V.Vector Bool -> Lit s -> Bool
pureEvaluateFn g values lit =
  Unsafe.unsafePerformIO (withAIGPtr g (const lookupValue))
  where
    raw = unLit lit
    negated = abcObjIsComplement raw
    lookupValue = do
      idx <- fromIntegral <$> abcObjId (abcObjRegular raw)
      when (idx >= V.length values) $
        fail "Literal created after evaluator was created."
      return ((values V.! idx) /= negated)
-- | Read the value of a literal from a mutable value table indexed by
-- object id, flipping the stored value when the literal is complemented.
-- Fails if the literal's object id lies beyond the end of the table.
evaluateFn :: VM.IOVector Bool
           -> Lit s
           -> IO Bool
evaluateFn table lit = do
  idx <- fromIntegral <$> abcObjId (abcObjRegular raw)
  when (idx >= VM.length table) $
    fail "Literal created after evaluator was created."
  stored <- VM.read table idx
  return (stored /= abcObjIsComplement raw)
  where
    raw = unLit lit
-- | Number of inputs of the graph underlying a network.
networkInputCount :: AIG.Network l g -> IO Int
networkInputCount ntk =
  case ntk of
    AIG.Network graph _ -> AIG.inputCount graph
-- | Number of output literals carried by a network.
outputCount :: AIG.Network l g -> Int
outputCount ntk =
  case ntk of
    AIG.Network _ outs -> length outs
-- | Run an action on the raw network pointer with the network's output
-- literals temporarily installed as primary outputs.  All primary outputs
-- are removed again afterwards, whether or not the action throws.
withNetworkPtr :: AIG.Network Lit AIG
               -> (Abc_Ntk_t -> IO a)
               -> IO a
withNetworkPtr (AIG.Network graph outs) act =
  withAIGPtr graph $ \ntk ->
    let install = forM_ outs (addPo ntk)
        remove  = deletePos ntk
     in bracket_ install remove (act ntk)
-- | Create a new primary-output object in the network and connect the
-- given literal to it as a fanin.
addPo :: Abc_Ntk_t -> Lit s -> IO ()
addPo ntk lit =
  abcNtkCreateObj ntk AbcObjPo >>= \po ->
    abcObjAddFanin po (unLit lit)
-- | Ask ABC whether the miter's output is already known to be constant.
--
-- The return code of 'abcNtkMiterIsConstant' is interpreted as:
--
-- * @-1@: undecided, giving 'Nothing';
-- * @0@: satisfiable, giving @'Just' 'True'@;
-- * @1@: unsatisfiable, giving @'Just' 'False'@.
--
-- Any other code is reported with 'error'.
checkIsConstant :: Abc_Ntk_t -> IO (Maybe Bool)
checkIsConstant p = do
  c <- abcNtkMiterIsConstant p
  case c of
    -- Undecided must be matched as -1; matching 1 here would shadow the
    -- unsatisfiable case below.
    -1 -> return Nothing
    0 -> return (Just True)
    1 -> return (Just False)
    _ -> error $ "Unrecognized return code " ++ show c ++ " from abcNtkMiterIsConstant"
-- | Convert a network to an AIG manager with 'abcNtkToDar', run an action
-- on it, and stop the manager with 'aigManStop' afterwards, even if the
-- action throws.  The two flags are passed straight to 'abcNtkToDar'.
withAbcNtkToDar :: Abc_Ntk_t
                -> Bool
                -> Bool
                -> (Aig_Man_t -> IO a)
                -> IO a
withAbcNtkToDar ntk exors registers =
  bracket convert aigManStop
  where
    convert = abcNtkToDar ntk exors registers
-- | Derive a CNF from an AIG manager, write it to the given file, and
-- return the CNF mapping vector produced by 'getCNFMapping'.
writeAIGManToCNFWithMapping :: Aig_Man_t -> FilePath -> IO (V.Vector Int)
writeAIGManToCNFWithMapping pMan path = withCnfDerive pMan 0 emit
  where
    emit pCnf =
      cnfDataWriteIntoFile pCnf path 1 nullPtr nullPtr
        >> getCNFMapping pMan pCnf
-- | Copy the CNF data's variable-number array into a vector.  The vector
-- has one entry per object slot of the AIG manager (its maximum object
-- count), read in index order from the array returned by 'cnfVarNums'.
getCNFMapping :: Aig_Man_t -> Cnf_Dat_t -> IO (V.Vector Int)
getCNFMapping pMan pCnf = do
    slots <- fromIntegral <$> aigManObjNumMax pMan
    varNums <- cnfVarNums pCnf
    V.generateM slots (fmap fromIntegral . peekElemOff varNums)