{-# LANGUAGE CPP                  #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE UndecidableInstances #-}
module Agda.TypeChecking.Serialise.Base where
import Control.Exception (evaluate)
import Control.Monad.Catch (catchAll)
import Control.Monad.Reader
import Control.Monad.State.Strict (StateT, gets)
import Data.Proxy
import Data.Array.IArray
import qualified Data.ByteString.Lazy as L
import Data.Hashable
import qualified Data.HashTable.IO as H
import Data.Int (Int32)
import Data.Maybe
import qualified Data.Binary as B
import qualified Data.Binary.Get as B
import Data.Text.Lazy (Text)
import Data.Typeable ( cast, Typeable, typeOf, TypeRep )
import Agda.Syntax.Common (NameId)
import Agda.Syntax.Internal (Term, QName(..), ModuleName(..), nameId)
import Agda.TypeChecking.Monad.Base (TypeError(GenericError), ModuleToSource)
import Agda.Utils.FileName
import Agda.Utils.IORef
import Agda.Utils.Lens
import Agda.Utils.Monad
import Agda.Utils.Pointer
import Agda.Utils.Except (ExceptT, throwError)
import Agda.Utils.TypeLevel
-- | A serialised node: a flat list of 'Int32' codes (for instance,
-- 'icodeN' builds one as a tag followed by the codes of the arguments).
type Node = [Int32]
-- | Mutable hash tables used for the encoder and decoder state.
-- On 64-bit Windows the cuckoo-hashing implementation is selected instead
-- of the basic one.
-- NOTE(review): presumably a workaround for a platform-specific problem
-- with 'H.BasicHashTable' — confirm and link the relevant issue.
#if defined(mingw32_HOST_OS) && defined(x86_64_HOST_ARCH)
type HashTable k v = H.CuckooHashTable k v
#else
type HashTable k v = H.BasicHashTable k v
#endif
#ifdef DEBUG
-- | Per-table counters: how many fresh indices were handed out and, in
-- @DEBUG@ builds, how many lookups hit an already assigned index.
data FreshAndReuse = FreshAndReuse
  { farFresh :: !Int32 -- ^ Number of fresh entries so far.
  , farReuse :: !Int32 -- ^ Number of lookups that reused an existing entry.
  }
#else
-- | Per-table counter of fresh entries (reuse is only tracked in
-- @DEBUG@ builds).
newtype FreshAndReuse = FreshAndReuse
  { farFresh :: Int32 -- ^ Number of fresh entries so far.
  }
#endif
-- | Counters with every count at zero (the second argument only exists
-- in @DEBUG@ builds).
farEmpty :: FreshAndReuse
farEmpty = FreshAndReuse 0
#ifdef DEBUG
                           0
#endif
-- | Lens focusing on the fresh-entry counter 'farFresh'.
lensFresh :: Lens' Int32 FreshAndReuse
lensFresh f r = fmap (\ new -> r { farFresh = new }) (f (farFresh r))
#ifdef DEBUG
-- | Lens focusing on the reuse counter 'farReuse' (@DEBUG@ builds only).
lensReuse :: Lens' Int32 FreshAndReuse
lensReuse f r = fmap (\ new -> r { farReuse = new }) (f (farReuse r))
#endif
-- | Key identifying a qualified name: the 'NameId' of the unqualified
-- name followed by the 'NameId's of its module components.
type QNameId = [NameId]
-- | Compute the 'QNameId' of a qualified name.
qnameId :: QName -> QNameId
qnameId (QName (MName modParts) base) = nameId base : map nameId modParts
-- | Encoder state, made available to every 'S' computation via 'ReaderT'.
--
-- Each @…D@ table maps a value that has already been encoded to the
-- 'Int32' index assigned to it, so repeated occurrences share one index.
-- Each @…C@ field is the 'FreshAndReuse' counter of the matching table.
data Dict = Dict
  -- Tables for nodes and primitive values. 'icodeNode', 'icodeString',
  -- 'icodeInteger' and 'icodeDouble' use the node, string, integer and
  -- double tables respectively.
  -- NOTE(review): 'textD' is presumably filled through 'icodeX' — confirm.
  { nodeD        :: !(HashTable Node    Int32)    
  , stringD      :: !(HashTable String  Int32)    
  , textD        :: !(HashTable Text    Int32)    
  , integerD     :: !(HashTable Integer Int32)    
  , doubleD      :: !(HashTable Double  Int32)    
  -- Table keyed by term pointer.
  -- NOTE(review): presumably filled through 'icodeMemo' — confirm.
  , termD        :: !(HashTable (Ptr Term) Int32) 
  -- Tables for names and qualified names (keyed by 'qnameId').
  -- NOTE(review): presumably filled through 'icodeMemo' — confirm.
  , nameD        :: !(HashTable NameId  Int32)    
  , qnameD       :: !(HashTable QNameId Int32)    
  -- Counters, one per table above.
  , nodeC        :: !(IORef FreshAndReuse)  
  , stringC      :: !(IORef FreshAndReuse)
  , textC        :: !(IORef FreshAndReuse)
  , integerC     :: !(IORef FreshAndReuse)
  , doubleC      :: !(IORef FreshAndReuse)
  , termC        :: !(IORef FreshAndReuse)
  , nameC        :: !(IORef FreshAndReuse)
  , qnameC       :: !(IORef FreshAndReuse)
  , stats        :: !(HashTable String Int)
    -- ^ Number of 'icode' calls per type, keyed by @"icode " ++ show ty@
    --   (maintained by 'tickICode').
  , collectStats :: Bool
    -- ^ Whether 'tickICode' updates 'stats' at all.
  , absPathD     :: !(HashTable AbsolutePath Int32) 
    -- ^ Table for absolute paths.
    --   NOTE(review): presumably filled like the other @…D@ tables — confirm.
  }
-- | Create an encoder state with empty tables and all counters at zero.
emptyDict
  :: Bool
     -- ^ Whether 'tickICode' should collect per-type statistics.
  -> IO Dict
emptyDict wantStats = do
  -- Lookup tables, in field order.
  nodeTbl    <- H.new
  stringTbl  <- H.new
  textTbl    <- H.new
  integerTbl <- H.new
  doubleTbl  <- H.new
  termTbl    <- H.new
  nameTbl    <- H.new
  qnameTbl   <- H.new
  -- One counter per lookup table.
  nodeCtr    <- newIORef farEmpty
  stringCtr  <- newIORef farEmpty
  textCtr    <- newIORef farEmpty
  integerCtr <- newIORef farEmpty
  doubleCtr  <- newIORef farEmpty
  termCtr    <- newIORef farEmpty
  nameCtr    <- newIORef farEmpty
  qnameCtr   <- newIORef farEmpty
  statTbl    <- H.new
  pathTbl    <- H.new
  return $ Dict
    nodeTbl stringTbl textTbl integerTbl doubleTbl termTbl nameTbl qnameTbl
    nodeCtr stringCtr textCtr integerCtr doubleCtr termCtr nameCtr qnameCtr
    statTbl wantStats pathTbl
-- | A decoded value of some 'Typeable' type, with the type hidden.
-- The field is strict, so wrapping forces the value to WHNF.
data U = forall a . Typeable a => U !a
-- | Memo table used by 'vcase': maps a node index together with the
-- requested result type to the value already decoded for it.
type Memo = HashTable (Int32, TypeRep) U    
-- | Decoder state, threaded through every 'R' computation via 'StateT'.
data St = St
  { nodeE     :: !(Array Int32 Node)     -- ^ Nodes by 'Int32' index (read by 'vcase').
  , stringE   :: !(Array Int32 String)   -- ^ Strings by 'Int32' index.
  , textE     :: !(Array Int32 Text)     -- ^ Texts by 'Int32' index.
  , integerE  :: !(Array Int32 Integer)  -- ^ Integers by 'Int32' index.
  , doubleE   :: !(Array Int32 Double)   -- ^ Doubles by 'Int32' index.
  , nodeMemo  :: !Memo
    -- ^ Values already produced by 'vcase', keyed by node index and
    --   requested result type.
  , modFile   :: !ModuleToSource
    -- ^ NOTE(review): presumably the module-to-source-file map, possibly
    --   extended while decoding — confirm.
  , includes  :: [AbsolutePath]
    -- ^ NOTE(review): presumably the include directories used while
    --   decoding — confirm.
  }
-- | Encoding monad: shared read access to the 'Dict', on top of 'IO'.
type S a = ReaderT Dict IO a
-- | Decoding monad: may fail with a 'TypeError' and threads an 'St'.
type R a = ExceptT TypeError (StateT St IO) a
-- | Abort decoding because the serialised data is not well formed.
malformed :: R a
malformed = throwError (GenericError "Malformed input.")
-- | Types that can be encoded to ('icode') and decoded from ('value')
-- the serialised format.
class Typeable a => EmbPrj a where
  -- | Encode a value; records the call via 'tickICode', then runs 'icod_'.
  icode :: a -> S Int32
  -- | The encoder proper, called by 'icode'.
  icod_ :: a -> S Int32
  -- | Decode a value from its 'Int32' code.
  value :: Int32 -> R a

  icode x = tickICode x >> icod_ x

  -- Default for 'Enum' types: the code is the 'toEnum' argument.  Any
  -- exception raised while forcing the result (for example an
  -- out-of-range code) is reported as 'malformed'.
  default value :: (Enum a) => Int32 -> R a
  value code = catchAll (liftIO (evaluate (toEnum (fromIntegral code)))) (\ _ -> malformed)

  -- Default for 'Enum' types: encode as 'fromEnum'.
  default icod_ :: (Enum a) => a -> S Int32
  icod_ x = return (fromIntegral (fromEnum x))
-- | If statistics are enabled ('collectStats'), increment the 'stats'
-- entry @"icode " ++ show ty@ for the argument's type @ty@.  Only the type
-- is inspected; the argument itself is never evaluated.
tickICode :: forall a. Typeable a => a -> S ()
tickICode x = whenM (asks collectStats) $ do
  table <- asks stats
  let key = "icode " ++ show (typeOf x)
  liftIO $ do
    old <- H.lookup table key
    H.insert table key $! maybe 1 (+ 1) old
-- | Run a 'B.Get' parser on a lazy 'L.ByteString', feeding it chunk by
-- chunk.  Returns the result, the unconsumed input, and the starting
-- offset @n@ plus the number of bytes consumed.  A parse failure is
-- raised with 'error'.
runGetState :: B.Get a -> L.ByteString -> B.ByteOffset -> (a, L.ByteString, B.ByteOffset)
runGetState g s n = go (B.runGetIncremental g) (L.toChunks s)
  where
    go decoder chunks = case decoder of
      B.Done leftover used x -> (x, L.fromChunks (leftover : chunks), n + used)
      B.Fail _ _ msg         -> error msg
      B.Partial resume       -> case chunks of
        c : cs -> go (resume (Just c)) cs
        []     -> go (resume Nothing) []
-- | Hash-consed encoding of a key: if the table selected by @dict@
-- already holds the key, return its index; otherwise take the current
-- fresh count from the counter selected by @counter@ as the new index,
-- bump the counter, and record the key.
icodeX :: (Eq k, Hashable k)
  =>  (Dict -> HashTable k Int32)
  -> (Dict -> IORef FreshAndReuse)
  -> k -> S Int32
icodeX dict counter key = do
  table <- asks dict
  ref   <- asks counter
  liftIO $ H.lookup table key >>= \ hit -> case hit of
    Nothing -> do
      old <- readModifyIORef' ref $ over lensFresh (+1)
      let fresh = old ^. lensFresh
      H.insert table key fresh
      return fresh
    Just i -> do
#ifdef DEBUG
      modifyIORef' ref $ over lensReuse (+1)
#endif
      return i
-- | 'icodeX' written out by hand for the 'Integer' table ('integerD' /
-- 'integerC'): reuse an existing index or allocate a fresh one.
icodeInteger :: Integer -> S Int32
icodeInteger key = do
  table <- asks integerD
  ref   <- asks integerC
  liftIO $ H.lookup table key >>= \ hit -> case hit of
    Nothing -> do
      old <- readModifyIORef' ref $ over lensFresh (+1)
      let fresh = old ^. lensFresh
      H.insert table key fresh
      return fresh
    Just i -> do
#ifdef DEBUG
      modifyIORef' ref $ over lensReuse (+1)
#endif
      return i
-- | 'icodeX' written out by hand for the 'Double' table ('doubleD' /
-- 'doubleC'): reuse an existing index or allocate a fresh one.
icodeDouble :: Double -> S Int32
icodeDouble key = do
  table <- asks doubleD
  ref   <- asks doubleC
  liftIO $ H.lookup table key >>= \ hit -> case hit of
    Nothing -> do
      old <- readModifyIORef' ref $ over lensFresh (+1)
      let fresh = old ^. lensFresh
      H.insert table key fresh
      return fresh
    Just i -> do
#ifdef DEBUG
      modifyIORef' ref $ over lensReuse (+1)
#endif
      return i
-- | 'icodeX' written out by hand for the 'String' table ('stringD' /
-- 'stringC'): reuse an existing index or allocate a fresh one.
icodeString :: String -> S Int32
icodeString key = do
  table <- asks stringD
  ref   <- asks stringC
  liftIO $ H.lookup table key >>= \ hit -> case hit of
    Nothing -> do
      old <- readModifyIORef' ref $ over lensFresh (+1)
      let fresh = old ^. lensFresh
      H.insert table key fresh
      return fresh
    Just i -> do
#ifdef DEBUG
      modifyIORef' ref $ over lensReuse (+1)
#endif
      return i
-- | 'icodeX' written out by hand for the 'Node' table ('nodeD' /
-- 'nodeC'): reuse an existing index or allocate a fresh one.
icodeNode :: Node -> S Int32
icodeNode key = do
  table <- asks nodeD
  ref   <- asks nodeC
  liftIO $ H.lookup table key >>= \ hit -> case hit of
    Nothing -> do
      old <- readModifyIORef' ref $ over lensFresh (+1)
      let fresh = old ^. lensFresh
      H.insert table key fresh
      return fresh
    Just i -> do
#ifdef DEBUG
      modifyIORef' ref $ over lensReuse (+1)
#endif
      return i
-- | Memoised encoding.  If @a@ already has an index in the table selected
-- by @getDict@, return it.  Otherwise run @icodeP@ to encode the value and
-- remember the index it returns.
--
-- The counter selected by @getCounter@ only tallies fresh entries (and
-- reuses in @DEBUG@ builds); the index itself always comes from @icodeP@.
-- Only 'Eq' and 'Hashable' are required, since the table is a hash table.
icodeMemo
  :: (Eq a, Hashable a)
  => (Dict -> HashTable a Int32)    -- ^ Memo table.
  -> (Dict -> IORef FreshAndReuse)  -- ^ Statistics counter for that table.
  -> a        -- ^ Key to memoise on.
  -> S Int32  -- ^ Encoder to run on a cache miss.
  -> S Int32  -- ^ Index of the encoded value.
icodeMemo getDict getCounter a icodeP = do
    h  <- asks getDict
    mi <- liftIO $ H.lookup h a
    st <- asks getCounter
    case mi of
      Just i  -> liftIO $ do
#ifdef DEBUG
        modifyIORef' st $ over lensReuse (+ 1)
#endif
        return i
      Nothing -> do
        liftIO $ modifyIORef' st $ over lensFresh (+1)
        -- The key is recorded only after @icodeP@ returns, so a nested
        -- encoding of the same key inside @icodeP@ would not see it.
        i <- icodeP
        liftIO $ H.insert h a i
        return i
{-# INLINE vcase #-}
-- | Decode the node stored at the given index with @valu@, memoising the
-- result in 'nodeMemo' per (index, requested result type).
--
-- Fails with 'malformed' if the index lies outside the node table, or if
-- a memoised value cannot be cast to the requested type.
vcase :: forall a . EmbPrj a => (Node -> R a) -> Int32 -> R a
vcase valu = \ix -> do
    memo <- gets nodeMemo
    -- The memo key includes the requested type, so a hit should always
    -- cast successfully; 'malformed' on a failed cast is only a safeguard.
    let aTyp = typeOf (undefined :: a)
    maybeU <- liftIO $ H.lookup memo (ix, aTyp)
    case maybeU of
      Just (U u) -> maybe malformed return (cast u)
      Nothing    -> do
          nodes <- gets nodeE
          -- The index comes from the serialised data.  Check it explicitly
          -- so an out-of-range index is reported as 'malformed' instead of
          -- an array-indexing 'error' from '(!)'.
          if not (inRange (bounds nodes) ix)
            then malformed
            else do
              v <- valu (nodes ! ix)
              -- 'U' has a strict field, so storing forces @v@ to WHNF.
              liftIO $ H.insert memo (ix, aTyp) (U v)
              return v
-- | Encode the arguments of a curried function type @t@, supplied as
-- nested pairs, into the list of their codes.  The index @b@ records
-- whether @t@ is already the base type (no arguments left).
class ICODE t b where
  icodeArgs :: IsBase t ~ b => All EmbPrj (Domains t) =>
               Proxy t -> Products (Domains t) -> S [Int32]
-- | No arguments remain: nothing to encode.
instance IsBase t ~ 'True => ICODE t 'True where
  icodeArgs _ _ = pure []
-- | Encode the first argument, then the remaining ones, left to right.
instance ICODE t (IsBase t) => ICODE (a -> t) 'False where
  icodeArgs _ (x, rest) = do
    code  <- icode x
    codes <- icodeArgs (Proxy :: Proxy t) rest
    pure (code : codes)
{-# INLINE icodeN #-}
-- | Build an encoder for a constructor.  Takes a node tag and a function
-- (used only for its type) and yields a curried function that encodes
-- its arguments and stores the node @tag : argumentCodes@.
icodeN :: forall t. ICODE t (IsBase t) => Currying (Domains t) (S Int32) =>
          All EmbPrj (Domains t) =>
          Int32 -> t -> Arrows (Domains t) (S Int32)
icodeN tag _ =
  currys (Proxy :: Proxy (Domains t)) (Proxy :: Proxy (S Int32)) $ \ args -> do
    codes <- icodeArgs (Proxy :: Proxy t) args
    icodeNode (tag : codes)
{-# INLINE icodeN' #-}
-- | Like 'icodeN' but without a tag: the node consists of the argument
-- codes only.
icodeN' :: forall t. ICODE t (IsBase t) => Currying (Domains t) (S Int32) =>
           All EmbPrj (Domains t) =>
           t -> Arrows (Domains t) (S Int32)
icodeN' _ =
  currys (Proxy :: Proxy (Domains t)) (Proxy :: Proxy (S Int32)) $ \ args ->
    icodeArgs (Proxy :: Proxy t) args >>= icodeNode
-- | Decode the arguments of a curried function type @t@ and apply it.
-- The index @b@ records whether @t@ is already the base type.
class VALU t b where
  -- | Decode each argument code and apply the function to the results.
  valuN' :: b ~ IsBase t =>
            All EmbPrj (Domains t) =>
            t -> Products (Constant Int32 (Domains t)) -> R (CoDomain t)
  -- | Split a node into exactly as many codes as @t@ has arguments;
  -- 'Nothing' on an arity mismatch.
  valueArgs :: b ~ IsBase t =>
               All EmbPrj (CoDomain t ': Domains t) =>
               Proxy t -> Node -> Maybe (Products (Constant Int32 (Domains t)))
-- | No arguments left: the node must be exhausted.
instance VALU t 'True where
  valuN' c () = pure c
  valueArgs _ [] = Just ()
  valueArgs _ _  = Nothing
-- | Consume one code for the first argument, then recurse.
instance VALU t (IsBase t) => VALU (a -> t) 'False where
  valuN' c (code, codes) = do
    v <- value code
    valuN' (c v) codes
  valueArgs _ (x : rest) = fmap ((,) x) (valueArgs (Proxy :: Proxy t) rest)
  valueArgs _ _          = Nothing
{-# INLINE valuN #-}
-- | Curried decoder: given a function @f@, returns a function taking one
-- 'Int32' code per argument of @f@, decoding each and applying @f@.
valuN :: forall t. VALU t (IsBase t) =>
         Currying (Constant Int32 (Domains t)) (R (CoDomain t)) =>
         All EmbPrj (Domains t) =>
         t -> Arrows (Constant Int32 (Domains t)) (R (CoDomain t))
valuN f = currys argCodes result (valuN' f)
  where
    argCodes = Proxy :: Proxy (Constant Int32 (Domains t))
    result   = Proxy :: Proxy (R (CoDomain t))
{-# INLINE valueN #-}
-- | Decode the node at the given index (memoised via 'vcase') by applying
-- @t@ to its decoded arguments.  Fails with 'malformed' when the node's
-- length does not match the arity of @t@.
valueN :: forall t. VALU t (IsBase t) =>
          All EmbPrj (CoDomain t ': Domains t) =>
          t -> Int32 -> R (CoDomain t)
valueN t = vcase $ \ int32s ->
  maybe malformed (valuN' t) (valueArgs (Proxy :: Proxy t) int32s)