module CLaSH.Normalize where
import           Control.Concurrent.Supply (Supply)
import           Control.Lens              ((.=))
import qualified Control.Lens              as Lens
import qualified Control.Monad.State       as State
import           Data.Either               (partitionEithers)
import           Data.HashMap.Strict       (HashMap)
import qualified Data.HashMap.Strict       as HashMap
import           Data.List                 (mapAccumL)
import qualified Data.Map                  as Map
import qualified Data.Maybe                as Maybe
import qualified Data.Set                  as Set
import           Unbound.LocallyNameless   (unembed)
import           CLaSH.Core.FreeVars       (termFreeIds)
import           CLaSH.Core.Pretty         (showDoc)
import           CLaSH.Core.Subst          (substTms)
import           CLaSH.Core.Term           (Term (..), TmName)
import           CLaSH.Core.Type           (Type)
import           CLaSH.Core.TyCon          (TyCon, TyConName)
import           CLaSH.Core.Util           (collectArgs, mkApps, termType)
import           CLaSH.Core.Var            (Id,varName)
import           CLaSH.Netlist.Types       (HWType)
import           CLaSH.Netlist.Util        (splitNormalized)
import           CLaSH.Normalize.Strategy
import           CLaSH.Normalize.Transformations ( bindConstantVar, topLet )
import           CLaSH.Normalize.Types
import           CLaSH.Normalize.Util
import           CLaSH.Rewrite.Combinators ((!->),repeatR,topdownR)
import           CLaSH.Rewrite.Types       (DebugLevel (..), RewriteState (..),
                                            bindings, dbgLevel, tcCache)
import           CLaSH.Rewrite.Util        (liftRS, runRewrite,
                                            runRewriteSession)
import           CLaSH.Util
-- | Run a 'NormalizeSession' to completion and return its result.
--
-- The session is unwrapped with 'runRewriteSession' using a freshly built
-- 'RewriteState'; the resulting state computation is then evaluated against a
-- freshly built 'NormalizeState'.
runNormalization :: DebugLevel
                 -- ^ Debug level handed to 'runRewriteSession'
                 -> Supply
                 -- ^ Supply stored in the initial 'RewriteState'
                 -> HashMap TmName (Type,Term)
                 -- ^ Global binders: name mapped to its type and body
                 -> (HashMap TyConName TyCon -> Type -> Maybe (Either String HWType))
                 -- ^ Function from a type to an optional 'HWType' (or an error
                 -- message), stored in the initial 'RewriteState'
                 -> HashMap TyConName TyCon
                 -- ^ Type constructors, keyed by name
                 -> (HashMap TyConName TyCon -> Term -> Term)
                 -- ^ Term-to-term function stored in the initial 'RewriteState'
                 -- (presumably an evaluator, going by the argument name)
                 -> NormalizeSession a
                 -- ^ The session to run
                 -> a
runNormalization lvl supply globals typeTrans tcm eval session
  = State.evalState (runRewriteSession lvl initialRewriteState session)
                    initialNormalizeState
  where
    initialRewriteState = RewriteState 0 globals supply typeTrans tcm eval

    -- The fields are positional; their meaning is fixed by the declaration of
    -- 'NormalizeState' (not visible here). The last field deliberately starts
    -- out as an error thunk: 'normalize'' assigns 'curFun' before rewriting.
    initialNormalizeState =
      NormalizeState
        HashMap.empty
        Map.empty
        HashMap.empty
        100
        HashMap.empty
        100
        (error "Report as bug: no curFun")
-- | Normalize the given binders, and then, recursively, every binder that
-- 'normalize'' reports as still needing normalization.
--
-- The result maps every binder normalized along the way to its type and
-- normalized body. An empty list of binders yields an empty map.
normalize :: [TmName]
          -> NormalizeSession (HashMap TmName (Type,Term))
normalize top
  | null top  = return HashMap.empty
  | otherwise = do
      results <- mapM normalize' top
      let (pending,done) = unzip results
      -- Everything the current batch still depends on forms the next batch.
      rest <- normalize (concat pending)
      return (HashMap.fromList done `HashMap.union` rest)
-- | Normalize a single global binder.
--
-- The body of the binder is looked up in 'bindings', rewritten with the
-- @normalization@ strategy and re-typed with 'termType'; this work is wrapped
-- in 'makeCachedT3S' keyed on the binder name and the 'normalized' map
-- (presumably so that a binder is only normalized once).
--
-- Returns a pair of:
--
-- * the free identifiers of the normalized body that have no entry in the
--   'normalized' map yet, in the order 'termFreeIds' reports them;
-- * the binder name together with its new type and normalized body.
--
-- Calls 'error' when the binder is not present in 'bindings', or when the
-- normalized body still refers to the binder itself.
normalize' :: TmName
           -> NormalizeSession ([TmName],(TmName,(Type,Term)))
normalize' nm = do
  exprM <- HashMap.lookup nm <$> Lens.use bindings
  let nmS = showDoc nm
  case exprM of
    Just (_,tm) -> do
      tmNorm <- makeCachedT3S nm normalized $ do
                  -- Record which binder is being normalized before rewriting.
                  liftRS $ curFun .= nm
                  tm' <- rewriteExpr ("normalization",normalization) (nmS,tm)
                  tcm <- Lens.use tcCache
                  ty' <- termType tcm tm'
                  return (ty',tm')
      let usedBndrs = termFreeIds (snd tmNorm)
      if nm `elem` usedBndrs
        then error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " remains recursive after normalization."
        else do
          -- Probe the map directly instead of scanning a list of its keys for
          -- every free identifier; the result and its order are unchanged.
          alreadyNormalized <- liftRS $ Lens.use normalized
          let toNormalize = filter (not . (`HashMap.member` alreadyNormalized)) usedBndrs
          return (toNormalize,(nm,tmNorm))
    Nothing -> error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"
-- | Apply a named rewrite to the body of a named binder.
--
-- When the debug level is at least 'DebugFinal', the pretty-printed term is
-- passed to 'traceIf' once before and once after the rewrite.
rewriteExpr :: (String,NormRewrite)
            -- ^ Name of the rewrite (used in trace output) and the rewrite itself
            -> (String,Term)
            -- ^ Name of the binder (used in trace output) and its body
            -> NormalizeSession Term
rewriteExpr (nrwS,nrw) (bndrS,expr) = do
  lvl <- Lens.view dbgLevel
  let tracing         = lvl >= DebugFinal
      -- Shared layout of the "before" and "after" trace messages.
      report phase tm = bndrS ++ phase ++ nrwS ++ ":\n\n" ++ showDoc tm ++ "\n"
  -- The "before" trace is attached to the input term itself.
  result <- runRewrite nrwS nrw (traceIf tracing (report " before " expr) expr)
  traceIf tracing (report " after " result) (return result)
-- | Return the given binders unchanged when the call graph reachable from
-- the top entity has no recursive components; call 'error', listing the
-- offending components, otherwise.
checkNonRecursive :: TmName
                  -- ^ Binder at which the call graph is rooted
                  -> HashMap TmName (Type,Term)
                  -- ^ Binders to check
                  -> HashMap TmName (Type,Term)
checkNonRecursive topEntity norm
  | null cycles = norm
  | otherwise   = error $ "Callgraph after normalisation contains following recursive cycles: " ++ show cycles
  where
    cycles = recursiveComponents (callGraph [] norm topEntity)
-- | Rebuild the binder map from the call tree rooted at the top entity:
-- build the tree with 'mkCallTree', simplify it with 'flattenCallTree', and
-- collect the binders that remain in the tree with 'callTreeToList'.
cleanupGraph :: TmName
             -- ^ Binder at which the call tree is rooted
             -> (HashMap TmName (Type,Term))
             -- ^ Binders to clean up
             -> NormalizeSession (HashMap TmName (Type,Term))
cleanupGraph topEntity norm = do
  flattened <- flattenCallTree (mkCallTree [] norm topEntity)
  let (_,binders) = callTreeToList [] flattened
  return (HashMap.fromList binders)
-- | Tree of global binders, as built by 'mkCallTree': every node carries a
-- binder name together with its type and body.
data CallTree = CLeaf   (TmName,(Type,Term))
              -- ^ A binder whose body has no free identifiers
              | CBranch (TmName,(Type,Term)) [CallTree]
              -- ^ A binder together with the sub-trees of the binders its
              -- body refers to
-- | Build the call tree rooted at the given binder.
--
-- The children of a node are the call trees of the free identifiers of its
-- body, skipping identifiers that already occur on the path from the root
-- (which keeps the construction finite for recursive binders). Calls 'error'
-- when the root is not a key of the binder map.
mkCallTree :: [TmName]
           -- ^ Binders already visited on the path to this node
           -> HashMap TmName (Type,Term)
           -- ^ All binders
           -> TmName
           -- ^ Root of the tree to build
           -> CallTree
mkCallTree visited bindingMap root
  | null used = CLeaf   rootBinding
  | otherwise = CBranch rootBinding callees
  where
    rootBinding = (root,rootTm)
    rootTm      = case HashMap.lookup root bindingMap of
                    Just binding -> binding
                    Nothing      -> error $ show root ++ " is not a global binder"
    used        = Set.toList (termFreeIds (snd rootTm))
    callees     = [ mkCallTree (root:visited) bindingMap callee
                  | callee <- used
                  , callee `notElem` visited
                  ]
-- | Strip a prefix of variable arguments that matches the given identifiers.
--
-- Identifiers and arguments are consumed pairwise from the front: each
-- identifier must be matched by a term argument that is exactly a variable
-- with that identifier's name. (Callers that want to strip /trailing/
-- arguments, as 'flattenNode' does, pass both lists reversed.)
--
-- Returns @Just@ the arguments left over once all identifiers are consumed,
-- and @Nothing@ when
--
-- * the arguments run out before the identifiers do,
-- * an argument is not the expected variable, or
-- * a left-over term argument still mentions one of @allIds@ anywhere among
--   its free identifiers, not only when the argument is itself such a
--   variable. Dropping the identifiers would otherwise leave a dangling
--   reference inside that argument.
stripArgs :: [TmName]
          -- ^ Names of all identifiers being stripped
          -> [Id]
          -- ^ Identifiers still to be matched
          -> [Either Term Type]
          -- ^ Arguments still to be matched
          -> Maybe [Either Term Type]
stripArgs _      (_:_) []   = Nothing
stripArgs allIds []    args
  | any mentionsId args = Nothing
  | otherwise           = Just args
  where
    -- Type arguments are ignored, exactly as before.
    mentionsId (Left tm) = any (`elem` allIds) (freeIds tm)
    mentionsId (Right _) = False

    -- Pin the collection type of 'termFreeIds' to a list.
    freeIds :: Term -> [TmName]
    freeIds = termFreeIds
stripArgs allIds (id_:ids) (Left (Var _ nm):args)
      | varName id_ == nm = stripArgs allIds ids args
      | otherwise         = Nothing
stripArgs _ _ _ = Nothing
-- | Decide whether the binder at the root of a call tree can be replaced by
-- a simpler expression.
--
-- The body is taken apart with 'splitNormalized'. When that succeeds with
-- exactly one let-binding, the bound expression is split into a function and
-- its arguments, and 'stripArgs' is asked to remove the trailing arguments
-- that match the identifiers returned by 'splitNormalized' (presumably the
-- binder's own arguments; confirm against 'splitNormalized').
--
-- * @Right ((name, expr), callees)@: the binder can be replaced by @expr@,
--   the function applied to the remaining arguments; @callees@ are the
--   sub-trees of the node (empty for a leaf).
-- * @Left tree@: the node is returned untouched.
flattenNode :: CallTree
            -> NormalizeSession (Either CallTree ((TmName,Term),[CallTree]))
flattenNode tree = case tree of
    CLeaf   (nm,(_,e))         -> attempt nm e []
    CBranch (nm,(_,e)) callees -> attempt nm e callees
  where
    attempt nm body callees = do
      tcm   <- Lens.use tcCache
      split <- splitNormalized tcm body
      case split of
        Right (ids,[(_,bExpr)],_) ->
          let (fun,args) = collectArgs (unembed bExpr)
              -- Both lists are reversed so that matching starts at the
              -- last argument.
              stripped   = stripArgs (map varName ids) (reverse ids) (reverse args)
          in  case stripped of
                Just remainder ->
                  return (Right ((nm,mkApps fun (reverse remainder)),callees))
                Nothing ->
                  return (Left tree)
        _ -> return (Left tree)
-- | Flatten a call tree bottom-up.
--
-- Leaves are returned unchanged. For a branch, the sub-trees are flattened
-- first and each is then offered to 'flattenNode'. Sub-trees that
-- 'flattenNode' turns into a replacement expression are substituted into the
-- body of the branch with 'substTms', after which the body is rewritten with
-- the @bindConstants@ rewrite; their own callees become direct callees of
-- the branch. All other sub-trees are kept as they are.
flattenCallTree :: CallTree
                -> NormalizeSession CallTree
flattenCallTree tree = case tree of
  CLeaf _ -> return tree
  CBranch (nm,(ty,tm)) used -> do
    flattenedUsed <- mapM flattenCallTree used
    verdicts      <- mapM flattenNode flattenedUsed
    let (kept,inlined)         = partitionEithers verdicts
        (substs,adoptedTrees)  = unzip inlined
    newExpr <-
      if null substs
        then return tm
        else rewriteExpr
               ("bindConstants",(topdownR (repeatR $ bindConstantVar)) !-> topLet)
               (showDoc nm, substTms substs tm)
    return (CBranch (nm,(ty,newExpr)) (kept ++ concat adoptedTrees))
-- | Collect the binders of a call tree, root first, listing every binder
-- name at most once.
--
-- A node whose name was already visited contributes nothing, and its
-- sub-trees are not traversed. The list of visited names is threaded
-- left-to-right through the sub-trees and returned alongside the binders.
callTreeToList :: [TmName]
               -- ^ Names visited so far
               -> CallTree
               -> ([TmName],[(TmName,(Type,Term))])
callTreeToList visited tree
  | nm `elem` visited = (visited,[])
  | otherwise         = case tree of
      CLeaf _        -> (nm:visited,[entry])
      CBranch _ used ->
        let (visited',others) = mapAccumL callTreeToList (nm:visited) used
        in  (visited',entry : concat others)
  where
    -- The binder stored at the root, whichever constructor holds it.
    entry@(nm,_) = case tree of
      CLeaf   binding   -> binding
      CBranch binding _ -> binding