{-# LANGUAGE DeriveFoldable    #-}
{-# LANGUAGE DeriveFunctor     #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo       #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TupleSections     #-}
{-# LANGUAGE ViewPatterns      #-}
module Clash.Normalize.DEC
  (collectGlobals
  ,isDisjoint
  ,mkDisjointGroup
  )
where
import           Control.Concurrent.Supply        (splitSupply)
import qualified Control.Lens                     as Lens
import           Data.Bits                        ((.&.),complement)
import qualified Data.Either                      as Either
import qualified Data.Foldable                    as Foldable
import qualified Data.HashMap.Strict              as HashMap
import qualified Data.IntMap.Strict               as IM
import qualified Data.List                        as List
import qualified Data.Map.Strict                  as Map
import qualified Data.Maybe                       as Maybe
import           Data.Set                         (Set)
import qualified Data.Set                         as Set
import qualified Data.Set.Lens                    as Lens
import           Unbound.Generics.LocallyNameless
  (Bind, bind, embed, fv, unbind, unembed, unrec)
import qualified Unbound.Generics.LocallyNameless as Unbound
import Clash.Core.DataCon    (DataCon, dcTag)
import Clash.Core.Evaluator  (whnf')
import Clash.Core.FreeVars   (termFreeIds, typeFreeVars)
import Clash.Core.Name       (Name (..), string2InternalName)
import Clash.Core.Literal    (Literal (..))
import Clash.Core.Term       (LetBinding, Pat (..), Term (..), TmOccName)
import Clash.Core.TyCon      (tyConDataCons)
import Clash.Core.Type       (Type, isPolyFunTy, mkTyConApp, splitFunForallTy)
import Clash.Core.Util       (collectArgs, mkApps, termType)
import Clash.Normalize.Types (NormalizeState)
import Clash.Normalize.Util  (isConstant)
import Clash.Rewrite.Types
  (RewriteMonad, bindings, evaluator, globalHeap, tcCache, tupleTcCache, uniqSupply)
import Clash.Rewrite.Util    (mkInternalVar, mkSelectorCase,
                              isUntranslatableType)
import Clash.Util
data CaseTree a
  = Leaf a
  | LB [LetBinding] (CaseTree a)
  | Branch Term [(Pat,CaseTree a)]
  deriving (Eq,Show,Functor,Foldable)
isDisjoint :: CaseTree ([Either Term Type])
           -> Bool
isDisjoint (Branch _ [_]) = False
isDisjoint ct = go ct
  where
    go (Leaf _)             = False
    go (LB _ ct')           = go ct'
    go (Branch _ [])        = False
    go (Branch _ [(_,x)])   = go x
    go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b))
removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a]
removeEmpty l@(Leaf _) = l
removeEmpty (LB lb ct) =
  case removeEmpty ct of
    Leaf [] -> Leaf []
    ct'     -> LB lb ct'
removeEmpty (Branch s bs) =
  case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of
    []  -> Leaf []
    bs' -> Branch s bs'
allEqual :: Eq a => [a] -> Bool
allEqual []     = True
allEqual (x:xs) = all (== x) xs
collectGlobals ::
     Set TmOccName
  -> [(Term,Term)] 
                   
  -> [Term] 
  -> Term 
  -> RewriteMonad NormalizeState
                  (Term,[(Term,([Term],CaseTree [(Either Term Type)]))])
collectGlobals inScope substitution seen (Case scrut ty alts) = do
  rec (alts' ,collected)  <- collectGlobalsAlts inScope substitution seen scrut'
                                                alts
      (scrut',collected') <- collectGlobals inScope substitution
                                            (map fst collected ++ seen) scrut
  return (Case scrut' ty alts',collected ++ collected')
collectGlobals inScope substitution seen e@(collectArgs -> (fun, args@(_:_)))
  | not (isConstant e) = do
    tcm <- Lens.view tcCache
    bndrs <- Lens.use bindings
    primEval <- Lens.view evaluator
    gh <- Lens.use globalHeap
    ids <- Lens.use uniqSupply
    let (ids1,ids2) = splitSupply ids
    uniqSupply Lens..= ids2
    let eval = snd . whnf' primEval bndrs tcm gh ids1 False
    eTy <- termType tcm e
    untran <- isUntranslatableType False eTy
    case untran of
      
      
      False -> case interestingToLift inScope eval fun args of
        Just fun' | fun' `notElem` seen -> do
          (args',collected) <- collectGlobalsArgs inScope substitution
                                                  (fun':seen) args
          let e' = Maybe.fromMaybe (mkApps fun' args') (List.lookup fun' substitution)
          
          
          
          
          return (e',(fun',(seen,Leaf args')):collected)
        _ -> do (args',collected) <- collectGlobalsArgs inScope substitution
                                                        seen args
                return (mkApps fun args',collected)
      _ -> return (e,[])
collectGlobals inScope substitution seen (Letrec b) = do
  (unrec -> lbs,body) <- unbind b
  (body',collected)   <- collectGlobals    inScope substitution seen body
  (lbs',collected')   <- collectGlobalsLbs inScope substitution
                                           (map fst collected ++ seen)
                                           lbs
  return (Letrec (bind (Unbound.rec lbs') body')
         ,map (second (second (LB lbs'))) (collected ++ collected')
         )
collectGlobals _ _ _ e = return (e,[])
collectGlobalsArgs ::
     Set TmOccName
  -> [(Term,Term)] 
                   
  -> [Term] 
  -> [Either Term Type] 
  -> RewriteMonad NormalizeState
                  ([Either Term Type]
                  ,[(Term,([Term],CaseTree [(Either Term Type)]))]
                  )
collectGlobalsArgs inScope substitution seen args = do
    (_,(args',collected)) <- second unzip <$> mapAccumLM go seen args
    return (args',concat collected)
  where
    go s (Left tm) = do
      (tm',collected) <- collectGlobals inScope substitution s tm
      return (map fst collected ++ s,(Left tm',collected))
    go s (Right ty) = return (s,(Right ty,[]))
collectGlobalsAlts ::
     Set TmOccName
  -> [(Term,Term)] 
                   
  -> [Term] 
  -> Term 
  -> [Bind Pat Term] 
  -> RewriteMonad NormalizeState
                  ([Bind Pat Term]
                  ,[(Term,([Term],CaseTree [(Either Term Type)]))]
                  )
collectGlobalsAlts inScope substitution seen scrut alts = do
    (alts',collected) <- unzip <$> mapM go alts
    let collectedM  = map (Map.fromList . map (second (second (:[])))) collected
        collectedUN = Map.unionsWith (\(l1,r1) (l2,r2) -> (List.nub (l1 ++ l2),r1 ++ r2)) collectedM
        collected'  = map (second (second (Branch scrut))) (Map.toList collectedUN)
    return (alts',collected')
  where
    go pe = do (p,e) <- unbind pe
               (e',collected) <- collectGlobals inScope substitution seen e
               return (bind p e',map (second (second (p,))) collected)
collectGlobalsLbs ::
     Set TmOccName
  -> [(Term,Term)] 
                   
  -> [Term] 
  -> [LetBinding] 
  -> RewriteMonad NormalizeState
                  ([LetBinding]
                  ,[(Term,([Term],CaseTree [(Either Term Type)]))]
                  )
collectGlobalsLbs inScope substitution seen lbs = do
    (_,(lbs',collected)) <- second unzip <$> mapAccumLM go seen lbs
    return (lbs',concat collected)
  where
    go :: [Term] -> LetBinding
       -> RewriteMonad NormalizeState
                  ([Term]
                  ,(LetBinding
                   ,[(Term,([Term],CaseTree [(Either Term Type)]))]
                   )
                  )
    go s (id_,unembed -> e) = do
      (e',collected) <- collectGlobals inScope substitution s e
      return (map fst collected ++ s,((id_,embed e'),collected))
mkDisjointGroup :: Set TmOccName 
                -> (Term,([Term],CaseTree [(Either Term Type)]))
                   
                -> RewriteMonad NormalizeState (Term,[Term])
mkDisjointGroup fvs (fun,(seen,cs)) = do
    let argss    = Foldable.toList cs
        argssT   = zip [0..] (List.transpose argss)
        (commonT,uncommonT) = List.partition (isCommon fvs . snd) argssT
        common   = map (second head) commonT
        uncommon = map (Either.lefts) (List.transpose (map snd uncommonT))
        cs'      = fmap (zip [0..]) cs
        cs''     = removeEmpty
                 $ fmap (Either.lefts . map snd)
                        (if null common
                           then cs'
                           else fmap (filter (`notElem` common)) cs')
    tcm <- Lens.view tcCache
    (uncommonCaseM,uncommonProjections) <- case uncommon of
      
      [] -> return (Nothing,[])
      
      (uc:_) -> do
        argTys <- mapM (termType tcm) uc
        disJointSelProj argTys cs''
    let newArgs = mkDJArgs 0 common uncommonProjections
    case uncommonCaseM of
      Just lb -> return (Letrec (bind (Unbound.rec [lb]) (mkApps fun newArgs)), seen)
      Nothing -> return (mkApps fun newArgs, seen)
disJointSelProj :: [Type] 
                -> CaseTree [Term] 
                -> RewriteMonad NormalizeState (Maybe LetBinding,[Term])
disJointSelProj _ (Leaf []) = return (Nothing,[])
disJointSelProj argTys cs = do
    let maxIndex = length argTys - 1
        css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex]
    (untran,tran) <- partitionM (isUntranslatableType False . snd) (zip [0..] argTys)
    let untranCs   = map (css!!) (map fst untran)
        untranSels = zipWith (\(_,ty) cs' -> genCase ty Nothing []  cs')
                             untran untranCs
    (lbM,projs) <- case tran of
      []       -> return (Nothing,[])
      [(i,ty)] -> return (Nothing,[genCase ty Nothing [] (css!!i)])
      tys      -> do
        tcm    <- Lens.view tcCache
        tupTcm <- Lens.view tupleTcCache
        let m            = length tys
            Just tupTcNm = IM.lookup m tupTcm
            Just tupTc   = HashMap.lookup (nameOcc tupTcNm) tcm
            [tupDc]      = tyConDataCons tupTc
            (tyIxs,tys') = unzip tys
            tupTy        = mkTyConApp tupTcNm tys'
            cs'          = fmap (\es -> map (es !!) tyIxs) cs
            djCase       = genCase tupTy (Just tupDc) tys' cs'
        (scrutId,scrutVar) <- mkInternalVar (string2InternalName "tupIn") tupTy
        projections <- mapM (mkSelectorCase ($(curLoc) ++ "disJointSelProj")
                                            tcm scrutVar (dcTag tupDc)) [0..m-1]
        return (Just (scrutId,embed djCase),projections)
    let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs
    return (lbM,selProjs)
  where
    tranOrUnTran _ []       projs     = projs
    tranOrUnTran _ sels     []        = map snd sels
    tranOrUnTran n ((ut,s):uts) (p:projs)
      | n == ut   = s : tranOrUnTran (n+1) uts          (p:projs)
      | otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs
isCommon :: Set TmOccName -> [Either Term Type] -> Bool
isCommon _   []             = True
isCommon _   (Right ty:tys) = Set.null (Lens.setOf typeFreeVars ty) &&
                              allEqual (Right ty:tys)
isCommon fvs (Left tm:tms)  = Set.null (Lens.setOf termFreeIds tm Set.\\ fvs) &&
                              allEqual (Left tm:tms)
mkDJArgs :: Int 
         -> [(Int,Either Term Type)] 
         -> [Term] 
         -> [Either Term Type]
mkDJArgs _ cms []   = map snd cms
mkDJArgs _ [] uncms = map Left uncms
mkDJArgs n ((m,x):cms) (y:uncms)
  | n == m    = x       : mkDJArgs (n+1) cms (y:uncms)
  | otherwise = Left y  : mkDJArgs (n+1) ((m,x):cms) uncms
genCase :: Type 
        -> Maybe DataCon 
        -> [Type] 
        -> CaseTree [Term] 
        -> Term
genCase ty dcM argTys = go
  where
    go (Leaf tms) =
      case dcM of
        Just dc -> mkApps (Data dc) (map Right argTys ++ map Left tms)
        _ -> head tms
    go (LB lb ct) =
      Letrec (bind (Unbound.rec lb) (go ct))
    go (Branch scrut [(p,ct)]) =
      let ct' = go ct
          alt = bind p ct'
      in  case Lens.setOf termFreeIds ct' == Lens.setOf fv alt of
            True -> ct'
            _    -> Case scrut ty [alt]
    go (Branch scrut pats) =
      Case scrut ty (map (\(p,ct) -> bind p (go ct)) pats)
interestingToLift
  :: Set TmOccName
  
  -> (Term -> Term)
  
  -> Term
  
  -> [Either Term Type]
  
  -> Maybe Term
interestingToLift inScope _ e@(Var _ nm) _ =
  if nameOcc nm `Set.member` inScope
     then Just e
     else Nothing
interestingToLift inScope eval e@(Prim nm pty) args =
    case List.lookup nm interestingPrims of
      Just t | t || not (all isConstant lArgs) -> Just e
      _ -> if isHOTy pty
              then if not . null . Maybe.catMaybes $
                      map (uncurry (interestingToLift inScope eval) .
                           collectArgs
                          ) lArgs
                      then Just e
                      else Nothing
              else Nothing
  where
    interestingPrims =
      [("Clash.Sized.Internal.BitVector.*#",tailNonPow2)
      ,("Clash.Sized.Internal.BitVector.times#",tailNonPow2)
      ,("Clash.Sized.Internal.BitVector.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.BitVector.rem#",lastNotPow2)
      ,("Clash.Sized.Internal.Index.*#",tailNonPow2)
      ,("Clash.Sized.Internal.Index.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.Index.rem#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.*#",tailNonPow2)
      ,("Clash.Sized.Internal.Signed.times#",tailNonPow2)
      ,("Clash.Sized.Internal.Signed.rem#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.div#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.mod#",lastNotPow2)
      ,("Clash.Sized.Internal.Unsigned.*#",tailNonPow2)
      ,("Clash.Sized.Internal.Unsigned.times#",tailNonPow2)
      ,("Clash.Sized.Internal.Unsigned.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.Unsigned.rem#",lastNotPow2)
      ,("GHC.Base.quotInt",lastNotPow2)
      ,("GHC.Base.remInt",lastNotPow2)
      ,("GHC.Base.divInt",lastNotPow2)
      ,("GHC.Base.modInt",lastNotPow2)
      ,("GHC.Classes.divInt#",lastNotPow2)
      ,("GHC.Classes.modInt#",lastNotPow2)
      ,("GHC.Integer.Type.timesInteger",allNonPow2)
      ,("GHC.Integer.Type.divInteger",lastNotPow2)
      ,("GHC.Integer.Type.modInteger",lastNotPow2)
      ,("GHC.Integer.Type.quotInteger",lastNotPow2)
      ,("GHC.Integer.Type.remInteger",lastNotPow2)
      ,("GHC.Prim.*#",allNonPow2)
      ,("GHC.Prim.quotInt#",lastNotPow2)
      ,("GHC.Prim.remInt#",lastNotPow2)
      ]
    lArgs       = Either.lefts args
    allNonPow2  = all (not . termIsPow2) lArgs
    tailNonPow2 = case lArgs of
                    [] -> True
                    _  -> all (not . termIsPow2) (tail lArgs)
    lastNotPow2 = case lArgs of
                    [] -> True
                    _  -> not (termIsPow2 (last lArgs))
    termIsPow2 e' = case eval e' of
      Literal (IntegerLiteral n) -> isPow2 n
      a -> case collectArgs a of
        (Prim nm' _,[Right _,Left _,Left (Literal (IntegerLiteral n))])
          | isFromInteger nm' -> isPow2 n
        _ -> False
    isPow2 x = x /= 0 && (x .&. (complement x + 1)) == x
    isFromInteger x = x `elem` ["Clash.Sized.Internal.BitVector.fromInteger##"
                               ,"Clash.Sized.Internal.BitVector.fromInteger#"
                               ,"Clash.Sized.Integer.Index.fromInteger"
                               ,"Clash.Sized.Internal.Signed.fromInteger#"
                               ,"Clash.Sized.Internal.Unsigned.fromInteger#"
                               ]
    isHOTy t = case splitFunForallTy t of
      (args',_) -> any isPolyFunTy (Either.rights args')
interestingToLift _ _ _ _ = Nothing