{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Rewrite.Util where
import           Control.Concurrent.Supply   (splitSupply)
import           Control.DeepSeq
import           Control.Exception           (throw)
import           Control.Lens
  (Lens', (%=), (+=), (^.), _Left)
import qualified Control.Lens                as Lens
import qualified Control.Monad               as Monad
#if !MIN_VERSION_base(4,13,0)
import           Control.Monad.Fail          (MonadFail)
#endif
import qualified Control.Monad.State.Strict  as State
import qualified Control.Monad.Writer        as Writer
import           Data.Bool                   (bool)
import           Data.Bifunctor              (bimap)
import           Data.Coerce                 (coerce)
import           Data.Functor.Const          (Const (..))
import           Data.List                   (group, partition, sort)
import           Data.List.Extra             (allM, partitionM)
import qualified Data.Map                    as Map
import           Data.Maybe                  (catMaybes,isJust,mapMaybe)
import qualified Data.Monoid                 as Monoid
import qualified Data.Set                    as Set
import qualified Data.Set.Lens               as Lens
import qualified Data.Set.Ordered            as OSet
import qualified Data.Set.Ordered.Extra      as OSet
import           Data.Text                   (Text)
import qualified Data.Text                   as Text
#ifdef HISTORY
import           Data.Binary                 (encode)
import qualified Data.ByteString             as BS
import qualified Data.ByteString.Lazy        as BL
import           System.IO.Unsafe            (unsafePerformIO)
#endif
import           BasicTypes                  (InlineSpec (..))
import           Clash.Core.DataCon          (dcExtTyVars)
import           Clash.Core.Evaluator        (whnf')
import           Clash.Core.Evaluator.Types  (PureHeap)
import           Clash.Core.FreeVars
  (freeLocalVars, hasLocalFreeVars, localIdDoesNotOccurIn, localIdOccursIn,
   typeFreeVars, termFreeVars')
import           Clash.Core.Name
import           Clash.Core.Pretty           (showPpr)
import           Clash.Core.Subst
  (substTmEnv, aeqTerm, aeqType, extendIdSubst, mkSubst, substTm)
import           Clash.Core.Term
import           Clash.Core.TermInfo
import           Clash.Core.TyCon
  (TyConMap, tyConDataCons)
import           Clash.Core.Type             (KindOrType, Type (..),
                                              TypeView (..), coreView1,
                                              normalizeType,
                                              typeKind, tyView, isPolyFunTy)
import           Clash.Core.Util
  (dataConInstArgTysE, isClockOrReset, isEnable)
import           Clash.Core.Var
  (Id, IdScope (..), TyVar, Var (..), isLocalId, mkGlobalId, mkLocalId, mkTyVar)
import           Clash.Core.VarEnv
  (InScopeSet, VarEnv, elemVarSet, extendInScopeSetList, mkInScopeSet,
   uniqAway, uniqAway', mapVarEnv)
import           Clash.Debug (traceIf)
import           Clash.Driver.Types
  (DebugLevel (..), BindingMap, Binding(..))
import           Clash.Netlist.Util          (representableType)
import           Clash.Pretty                (clashPretty, showDoc)
import           Clash.Rewrite.Types
import           Clash.Unique
import           Clash.Util
import qualified Clash.Util.Interpolate as I
-- | Run a pure 'State.State' computation against the @extra@ component of the
-- rewrite state and write the resulting @extra@ value back. The rewrite
-- environment is ignored and the writer output is passed through untouched.
zoomExtra :: State.State extra a
          -> RewriteMonad extra a
zoomExtra m = R step
 where
  step _env st acc =
    case State.runState m (st ^. extra) of
      (res, ext1) -> (res, st {_extra = ext1}, acc)
-- | Collect the binders that are bound more than once by a single
-- let-expression, or by a single case-alternative pattern, anywhere in the
-- given term. Each inner list holds all occurrences of one duplicated 'Id'.
findAccidentialShadows :: Term -> [[Id]]
findAccidentialShadows term = case term of
  Var {}            -> []
  Data {}           -> []
  Literal {}        -> []
  Prim {}           -> []
  Lam _ body        -> recur body
  TyLam _ body      -> recur body
  App fun arg       -> recur fun ++ recur arg
  TyApp fun _       -> recur fun
  Cast body _ _     -> recur body
  Tick _ body       -> recur body
  Case subj _ alts  ->
    concatMap (patternDups . fst) alts ++ concatMap recur (subj : map snd alts)
  Letrec binds body -> duplicates (map fst binds) ++ recur body
 where
  recur = findAccidentialShadows

  -- Only data patterns bind variables.
  patternDups :: Pat -> [[Id]]
  patternDups pat = case pat of
    DataPat _ _ ids -> duplicates ids
    LitPat _        -> []
    DefaultPat      -> []

  -- Groups of equal binders that have at least two members.
  duplicates :: [Id] -> [[Id]]
  duplicates ids = [grp | grp@(_:_:_) <- group (sort ids)]
-- | Apply a named rewrite to an expression.
--
-- The rewrite reports a change through the writer ('Monoid.Any'). When it
-- reports a change, its result is used and 'transformCounter' is incremented.
-- When it does not, the original expression is returned and the rewrite's
-- result is discarded. At any debug level other than 'DebugNone' the step is
-- additionally checked and traced by 'applyDebug'.
apply
  :: String
  -- ^ Name of the transformation
  -> Rewrite extra
  -- ^ Transformation to be applied
  -> Rewrite extra
apply = \s rewrite ctx expr0 -> do
  lvl <- Lens.view dbgLevel
  dbgTranss <- Lens.view dbgTransformations
  -- "Trying:" traces are emitted at exactly DebugTry, or at DebugAll and
  -- above, and only for selected transformations (an empty set selects all).
  let isTryLvl = lvl == DebugTry || lvl >= DebugAll
      isRelevantTrans = s `Set.member` dbgTranss || Set.null dbgTranss
  traceIf (isTryLvl && isRelevantTrans) ("Trying: " ++ s) (pure ())
  -- Capture the change flag the rewrite wrote to the writer.
  (expr1,anyChanged) <- Writer.listen (rewrite ctx expr0)
  let hasChanged = Monoid.getAny anyChanged
      -- Only use the rewrite's result when it reported a change.
      !expr2     = if hasChanged then expr1 else expr0
  Monad.when hasChanged (transformCounter += 1)
#ifdef HISTORY
  -- Append every step that changed the expression to "history.dat",
  -- serialized with 'encode'. The bang pattern forces the unsafePerformIO.
  Monad.when hasChanged $ do
    (curBndr, _) <- Lens.use curFun
    let !_ = unsafePerformIO
             $ BS.appendFile "history.dat"
             $ BL.toStrict
             $ encode RewriteStep
                 { t_ctx    = tfContext ctx
                 , t_name   = s
                 , t_bndrS  = showPpr (varName curBndr)
                 , t_before = expr0
                 , t_after  = expr1
                 }
    return ()
#endif
  if lvl == DebugNone
    then return expr2
    else applyDebug lvl dbgTranss s expr0 hasChanged expr2
{-# INLINE apply #-}
-- | Debug-checking and tracing of a single rewrite step, used by 'apply'
-- whenever the debug level is not 'DebugNone'.
--
-- When the changed rewrite reported a change and the level is above
-- 'DebugNone', it raises an 'error' in two cases:
--
-- * the rewrite introduced new free local variables;
-- * the rewrite created duplicate let/case binders (see
--   'findAccidentialShadows').
--
-- At 'DebugApplied' and above it traces type changes. It also raises an
-- 'error' when the expression changed although the rewrite reported no
-- change.
applyDebug
  :: DebugLevel
  -- ^ The current debugging level
  -> Set.Set String
  -- ^ Transformations to debug; an empty set means: debug all of them
  -> String
  -- ^ Name of the transformation
  -> Term
  -- ^ Original expression
  -> Bool
  -- ^ Whether the rewrite indicated change
  -> Term
  -- ^ New expression
  -> RewriteMonad extra Term
-- When specific transformations were selected, debug only those: any other
-- transformation continues with the level lowered to 'DebugNone'.
applyDebug lvl transformations name exprOld hasChanged exprNew
  | not (Set.null transformations) =
    let newLvl = bool DebugNone lvl (name `Set.member` transformations) in
    applyDebug newLvl Set.empty name exprOld hasChanged exprNew
applyDebug lvl _transformations name exprOld hasChanged exprNew =
 traceIf (lvl >= DebugAll) ("Tried: " ++ name ++ " on:\n" ++ before) $ do
  Monad.when (lvl > DebugNone && hasChanged) $ do
    tcm                  <- Lens.view tcCache
    let beforeTy          = termType tcm exprOld
        beforeFV          = Lens.setOf freeLocalVars exprOld
        afterTy           = termType tcm exprNew
        afterFV           = Lens.setOf freeLocalVars exprNew
        -- A rewrite must never introduce free local variables.
        newFV             = not (afterFV `Set.isSubsetOf` beforeFV)
        accidentalShadows = findAccidentialShadows exprNew
    Monad.when newFV $
            error ( concat [ $(curLoc)
                           , "Error when applying rewrite ", name
                           , " to:\n" , before
                           , "\nResult:\n" ++ after ++ "\n"
                           , "It introduces free variables."
                           , "\nBefore: " ++ showPpr (Set.toList beforeFV)
                           , "\nAfter: " ++ showPpr (Set.toList afterFV)
                           ]
                  )
    Monad.when (not (null accidentalShadows)) $
      error ( concat [ $(curLoc)
                     , "Error when applying rewrite ", name
                     , " to:\n" , before
                     , "\nResult:\n" ++ after ++ "\n"
                     , "It accidentally creates shadowing let/case-bindings:\n"
                     , " ", showPpr accidentalShadows, "\n"
                     , "This usually means that a transformation did not extend "
                     , "or incorrectly extended its InScopeSet before applying a "
                     , "substitution."
                     ])
    -- A type change is only traced, not raised as an error.
    traceIf (lvl >= DebugApplied && (not (beforeTy `aeqType` afterTy)))
            ( concat [ $(curLoc)
                     , "Error when applying rewrite ", name
                     , " to:\n" , before
                     , "\nResult:\n" ++ after ++ "\n"
                     , "Changes type from:\n", showPpr beforeTy
                     , "\nto:\n", showPpr afterTy
                     ]
            ) (return ())
  -- The rewrite claimed "no change" but returned a different term.
  Monad.when (lvl >= DebugApplied && not hasChanged && not (exprOld `aeqTerm` exprNew)) $
    error $ $(curLoc) ++ "Expression changed without notice(" ++ name ++  "): before:\n"
                      ++ before ++ "\nafter:\n" ++ after
  traceIf (lvl >= DebugName && hasChanged) name $
    traceIf (lvl >= DebugApplied && hasChanged) ("Changes when applying rewrite to:\n"
                      ++ before ++ "\nResult:\n" ++ after ++ "\n") $
      traceIf (lvl >= DebugAll && not hasChanged) ("No changes when applying rewrite "
                        ++ name ++ " to:\n" ++ after ++ "\n") $
        return exprNew
 where
  before = showPpr exprOld
  after  = showPpr exprNew
-- | Run a single named rewrite on an expression. The rewrite starts from the
-- given in-scope set and an empty context, i.e. no enclosing binders.
runRewrite
  :: String
  -- ^ Name of the transformation
  -> InScopeSet
  -> Rewrite extra
  -- ^ Transformation to perform
  -> Term
  -- ^ Term to transform
  -> RewriteMonad extra Term
runRewrite name is rewrite = apply name rewrite topCtx
 where
  topCtx = TransformContext is []
-- | Run a rewrite session with the given environment and initial state, and
-- return only its result. When debugging is enabled, the number of applied
-- transformations is traced.
runRewriteSession :: RewriteEnv
                  -> RewriteState extra
                  -> RewriteMonad extra a
                  -> a
runRewriteSession r s m = traceIf verbose msg result
  where
    (result, finalState, _) = runR m r s
    verbose = _dbgLevel r > DebugNone
    msg = "Clash: Applied " ++ show (finalState ^. transformCounter) ++ " transformations"
-- | Record in the writer that the current transformation changed the
-- expression.
setChanged :: RewriteMonad extra ()
setChanged = Writer.tell $ Monoid.Any { Monoid.getAny = True }
-- | Return the given value, and record in the writer that the expression was
-- changed.
changed :: a -> RewriteMonad extra a
changed val = val <$ Writer.tell (Monoid.Any True)
-- | The binder of the first 'LetBinding' entry in a context, if there is one.
-- Presumably this is the innermost enclosing let-binder.
closestLetBinder :: Context -> Maybe Id
closestLetBinder ctx = case [bndr | LetBinding bndr _ <- ctx] of
  (bndr:_) -> Just bndr
  []       -> Nothing
-- | Make a name derived from the closest let-binder in the context: that
-- binder's name with @_<sf>@ appended. Without such a binder, an internal
-- name @sf@ with unique 0 is produced.
mkDerivedName :: TransformContext -> OccName -> TmName
mkDerivedName (TransformContext _ ctx) sf =
  maybe (mkUnsafeInternalName sf 0)
        (\bndr -> appendToName (varName bndr) ('_' `Text.cons` sf))
        (closestLetBinder ctx)
-- | Make a fresh term binder for a term, unique with respect to the
-- 'InScopeSet' and typed with the type of that term.
mkTmBinderFor
  :: (MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap
  -- ^ TyCon cache
  -> Name a
  -- ^ Name of the new binder
  -> Term
  -- ^ Term to bind
  -> m Id
mkTmBinderFor is tcm name term = do
  -- mkBinderFor always answers a 'Left' for a 'Left' input.
  Left bndr <- mkBinderFor is tcm name (Left term)
  pure bndr
-- | Make a fresh binder, unique with respect to the 'InScopeSet'. A term
-- yields an 'Id' typed with the term's type. A type yields a 'TyVar' with
-- the type's kind.
mkBinderFor
  :: (MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap
  -- ^ TyCon cache
  -> Name a
  -- ^ Name of the new binder
  -> Either Term Type
  -- ^ Term or type to bind
  -> m (Either Id TyVar)
mkBinderFor is tcm name termOrType = case termOrType of
  Left term -> do
    nm <- cloneNameWithInScopeSet is name
    pure (Left (mkLocalId (termType tcm term) (coerce nm)))
  Right ty -> do
    nm <- cloneNameWithInScopeSet is name
    pure (Right (mkTyVar (typeKind tcm ty) (coerce nm)))
-- | Make a new local variable with an internal name. The name uses a fresh
-- unique from the supply and is then made unique with respect to the
-- 'InScopeSet'.
mkInternalVar
  :: (MonadUnique m)
  => InScopeSet
  -> OccName
  -- ^ Name of the new binder
  -> KindOrType
  -> m Id
mkInternalVar inScope name ty = do
  u <- getUniqueM
  pure . uniqAway inScope . mkLocalId ty $ mkUnsafeInternalName name u
-- | Inline the let-bindings for which the condition holds. The condition is
-- given the whole let-expression and the binding. Bindings that turn out to
-- be self-recursive after substitution are kept as let-bindings. When no
-- binding is selected, the expression is returned unchanged. Non-let
-- expressions are also returned unchanged.
inlineBinders
  :: (Term -> LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test: should this binding be inlined?
  -> Rewrite extra
inlineBinders condition (TransformContext inScope0 _) expr@(Letrec binds body) = do
  (selected, others) <- partitionM (condition expr) binds
  if null selected
    then return expr
    else do
      let inScope1 = extendInScopeSetList inScope0 (map fst binds)
          (recursive, (others1, body1)) =
            substituteBinders inScope1 selected others body
          remaining = recursive ++ others1
      if null remaining
        then changed body1
        else changed (Letrec remaining body1)
inlineBinders _ _ e = return e
-- | Is the binder tail-called more than once, with no non-tail uses, in the
-- given expression (as counted by 'tailCalls')? If so, it is treated as a
-- join point.
isJoinPointIn :: Id   -- ^ 'Id' of the binder
              -> Term -- ^ Expression in which the binder is bound
              -> Bool
isJoinPointIn id_ e = maybe False (> 1) (tailCalls id_ e)
-- | Count the number of tail calls of a function in an expression.
-- 'Nothing' means the function also occurs in a position this analysis does
-- not accept as a tail position, such as an application argument or a case
-- scrutinee.
--
-- NOTE(review): 'Tick' and 'Cast' fall into the final @_ -> Just 0@ case, so
-- occurrences underneath them are neither counted nor rejected. Confirm that
-- this is intended.
tailCalls :: Id   -- ^ Function to check
          -> Term -- ^ Expression to check it in
          -> Maybe Int
tailCalls id_ = \case
  -- A bare reference counts as one call only if it is the function itself.
  Var nm | id_ == nm -> Just 1
         | otherwise -> Just 0
  Lam _ e -> tailCalls id_ e
  TyLam _ e -> tailCalls id_ e
  -- The argument must not mention the function at all; only the head is
  -- searched further.
  App l r  -> case tailCalls id_ r of
                Just 0 -> tailCalls id_ l
                _      -> Nothing
  TyApp l _ -> tailCalls id_ l
  Letrec bs e ->
    let (bsIds,bsExprs) = unzip bs
        bsTls           = map (tailCalls id_) bsExprs
        -- Binders whose RHS analysis succeeded; in the branch where this is
        -- used, that is every binder of the group.
        bsIdsUsed       = mapMaybe (\(l,r) -> pure l <* r) (zip bsIds bsTls)
        bsIdsTls        = map (`tailCalls` e) bsIdsUsed
        bsCount         = pure . sum $ catMaybes bsTls
    in  case (all isJust bsTls) of
          -- Some RHS uses the function in a non-tail position.
          False -> Nothing
          True  -> case (all (==0) $ catMaybes bsTls) of
            -- Some RHS tail-calls the function: then every let-binder must
            -- itself only be tail-called in the body. The result is the
            -- sum of the calls in the RHSs and in the body.
            False  -> case all isJust bsIdsTls of
              False -> Nothing
              True  -> (+) <$> bsCount <*> tailCalls id_ e
            -- No RHS mentions the function: only the body counts.
            True -> tailCalls id_ e
  -- The scrutinee must not mention the function; every alternative must
  -- succeed, and their counts are summed.
  Case scrut _ alts ->
    let scrutTl = tailCalls id_ scrut
        altsTl  = map (tailCalls id_ . snd) alts
    in  case scrutTl of
          Just 0 | all (/= Nothing) altsTl -> Just (sum (catMaybes altsTl))
          _ -> Nothing
  _ -> Just 0
-- | Is the term a lambda whose body is an application headed by a variable,
-- where the lambda-bound variable does not occur in that body?
isVoidWrapper :: Term -> Bool
isVoidWrapper = \case
  Lam bndr body
    | (Var _, _) <- collectArgs body -> bndr `localIdDoesNotOccurIn` body
  _ -> False
-- | Substitute the RHSs of the first set of let-binders for references to
-- those binders in the second set of let-binders and in the body.
--
-- A binder from the first set whose RHS becomes self-recursive after
-- substitution is not substituted. It is returned separately instead, with
-- its RHS substituted using the final substitution.
substituteBinders
  :: InScopeSet
  -> [LetBinding]
  -- ^ Let-binders to substitute
  -> [LetBinding]
  -- ^ Let-binders where substitution takes place
  -> Term
  -- ^ Expression where substitution takes place
  -> ([LetBinding],([LetBinding],Term))
  -- ^ (self-recursive binders from the first set that could not be
  -- substituted, (the substituted second set, the substituted expression))
substituteBinders inScope toInline toKeep body =
  let (subst,toInlRec) = go (mkSubst inScope) [] toInline
  in  ( map (second (substTm "substToInlRec" subst)) toInlRec
      , ( map (second (substTm "substToKeep" subst)) toKeep
        , substTm "substBody" subst body) )
 where
  -- Process the binders left to right, accumulating the substitution and
  -- the list of self-recursive binders.
  go subst inlRec [] = (subst,inlRec)
  go !subst !inlRec ((x,e):toInl) =
    let -- Apply what has been substituted so far to this RHS.
        e1      = substTm "substInl" subst e
        -- Also apply [x := e1] to the existing substitution range, so that
        -- earlier entries that mention x are updated.
        substE  = extendIdSubst (mkSubst inScope) x e1
        subst1  = subst { substTmEnv = mapVarEnv (substTm "substSubst" substE)
                                                 (substTmEnv subst)}
        subst2  = extendIdSubst subst1 x e1
    in  if x `localIdOccursIn` e1 then
          -- Self-recursive: keep as a binding, do not extend the substitution.
          go subst ((x,e1):inlRec) toInl
        else
          go subst2 inlRec toInl
-- | Lift the first set of let-binders to global functions (via
-- 'liftBinding'). References to them in the second set of let-binders and in
-- the body are replaced by calls to the lifted functions.
--
-- Throws a 'ClashException' when a lifted binding would still refer to
-- itself after substitution.
liftAndSubsituteBinders
  :: InScopeSet
  -> [LetBinding]
  -- ^ Let-binders to lift
  -> [LetBinding]
  -- ^ Let-binders that remain
  -> Term
  -- ^ Expression where substitution takes place
  -> RewriteMonad extra ([LetBinding],Term)
liftAndSubsituteBinders inScope toLift toKeep body = do
  subst <- go (mkSubst inScope) toLift
  pure ( map (second (substTm "liftToKeep" subst)) toKeep
       , substTm "keepBody" subst body
       )
 where
  go subst [] = pure subst
  go !subst ((x,e):inl) = do
    -- Apply the substitution built so far, then lift the binding. e2 is the
    -- expression that replaces references to x.
    let e1 = substTm "liftInl" subst e
    (_,e2) <- liftBinding (x,e1)
    -- Also apply [x := e2] to the existing substitution range.
    let substE = extendIdSubst (mkSubst inScope) x e2
        subst1 = subst { substTmEnv = mapVarEnv (substTm "liftSubst" substE)
                                                (substTmEnv subst) }
        subst2 = extendIdSubst subst1 x e2
    if x `localIdOccursIn` e2 then do
      (_,sp) <- Lens.use curFun
      throw (ClashException sp [I.i|
        Internal error: inlineOrLiftBInders failed on:
        #{showPpr (x,e)}
        creating a self-recursive let-binding:
        #{showPpr (x,e2)}
        given the already built subtitution:
        #{showDoc (clashPretty (substTmEnv subst))}
      |] Nothing)
    else
      go subst2 inl
-- | Syntactic check whether a term is considered work free. The decision is
-- made on the head of the term (after 'collectArgs') and, for most heads, on
-- the arguments as well.
isWorkFree
  :: Term
  -> Bool
isWorkFree (collectArgs -> (fun,args)) = case fun of
  -- Local, non-polymorphic-function variables; the arguments are not checked.
  Var i            -> isLocalId i && not (isPolyFunTy (varType i))
  Data {}          -> all isWorkFreeArg args
  Literal {}       -> True
  Prim pInfo -> case primWorkInfo pInfo of
    WorkConstant   -> True -- Work free regardless of the arguments
    WorkNever      -> all isWorkFreeArg args
    -- Work free only when all term arguments are constant.
    WorkVariable   -> all isConstantArg args
    WorkAlways     -> False -- Always considered to perform work
  Lam _ e          -> isWorkFree e && all isWorkFreeArg args
  TyLam _ e        -> isWorkFree e && all isWorkFreeArg args
  Letrec bs e ->
    isWorkFree e && all (isWorkFree . snd) bs && all isWorkFreeArg args
  -- Only single-alternative case expressions are considered.
  Case s _ [(_,a)] -> isWorkFree s && isWorkFree a && all isWorkFreeArg args
  Cast e _ _       -> isWorkFree e && all isWorkFreeArg args
  _                -> False
 where
  -- Type arguments never perform work.
  isWorkFreeArg = either isWorkFree (const True)
  isConstantArg = either isConstant (const True)
-- | Is the name one of the @fromInteger@ primitives of the sized number types
-- (BitVector, Index, Signed, Unsigned)?
isFromInt :: Text -> Bool
isFromInt = (`elem` fromIntegerPrims)
 where
  fromIntegerPrims :: [Text]
  fromIntegerPrims =
    [ "Clash.Sized.Internal.BitVector.fromInteger##"
    , "Clash.Sized.Internal.BitVector.fromInteger#"
    , "Clash.Sized.Internal.Index.fromInteger#"
    , "Clash.Sized.Internal.Signed.fromInteger#"
    , "Clash.Sized.Internal.Unsigned.fromInteger#"
    ]
-- | Syntactic check whether a term is constant.
--
-- * Data constructors and primitives are constant when all their term
--   arguments are.
-- * Lambdas (whatever their arguments) are constant when the term has no
--   free local variables.
-- * Literals are constant.
-- * Anything else is not.
isConstant :: Term -> Bool
isConstant e =
  let (fun, args) = collectArgs e
  in  case fun of
        Data _    -> allConstArgs args
        Prim _    -> allConstArgs args
        Lam _ _   -> not (hasLocalFreeVars e)
        Literal _ -> True
        _         -> False
 where
  allConstArgs :: [Either Term Type] -> Bool
  allConstArgs = all (either isConstant (const True))
-- | Like 'isConstant', but a term of clock or reset type only counts as
-- constant when it is the @Clash.Transformations.removedArg@ primitive.
isConstantNotClockReset
  :: Term
  -> RewriteMonad extra Bool
isConstantNotClockReset e = do
  tcm <- Lens.view tcCache
  pure $
    if isClockOrReset tcm (termType tcm e)
      then case collectArgs e of
        (Prim p, _) -> primName p == "Clash.Transformations.removedArg"
        _           -> False
      else isConstant e
-- | For terms of clock, reset or enable type: 'Just' whether the term is work
-- free. This holds for the removedArg primitive, a bare variable, a nullary
-- data constructor, or a literal. Terms of any other type give 'Nothing'.
isWorkFreeClockOrResetOrEnable
  :: TyConMap
  -> Term
  -> Maybe Bool
isWorkFreeClockOrResetOrEnable tcm e
  | isClockOrReset tcm eTy || isEnable tcm eTy = case collectArgs e of
      (Prim p, _)    -> Just (primName p == "Clash.Transformations.removedArg")
      (Var _, [])    -> Just True
      (Data _, [])   -> Just True
      (Literal _, _) -> Just True
      _              -> Just False
  | otherwise = Nothing
 where
  eTy = termType tcm e
-- | A more liberal version of 'isWorkFree'. Clock, reset and enable terms are
-- decided by 'isWorkFreeClockOrResetOrEnable'. Otherwise:
--
-- * constructor applications and primitives are judged by their arguments;
-- * closed lambdas are work free;
-- * literals are work free.
isWorkFreeIsh
  :: Term
  -> RewriteMonad extra Bool
isWorkFreeIsh e = do
  tcm <- Lens.view tcCache
  maybe (byHead (collectArgs e)) pure (isWorkFreeClockOrResetOrEnable tcm e)
 where
  byHead (fun, args) = case fun of
    Data _ -> allM isWorkFreeIshArg args
    Prim pInfo -> case primWorkInfo pInfo of
      WorkAlways   -> pure False
      WorkVariable -> pure (all isConstantArg args)
      _            -> allM isWorkFreeIshArg args
    Lam _ _   -> pure (not (hasLocalFreeVars e))
    Literal _ -> pure True
    _         -> pure False
  -- Type arguments are always fine.
  isWorkFreeIshArg = either isWorkFreeIsh (pure . const True)
  isConstantArg    = either isConstant (const True)
-- | Replace the let-bindings selected by @condition@, either by inlining or by
-- lifting them to global functions:
--
-- * @inlineOrLift@ (given the whole let-expression) returns 'True' for
--   bindings to inline and 'False' for bindings to lift.
-- * Inlined bindings are substituted into the lifted bindings, the kept
--   bindings and the body.
-- * Bindings that turn out to be self-recursive are lifted as well.
--
-- Non-let expressions, and let-expressions with no selected binding, are
-- returned unchanged.
inlineOrLiftBinders
  :: (LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test: should this binding be replaced (inlined or lifted)?
  -> (Term -> LetBinding -> Bool)
  -- ^ Given the whole let-expression and a binding to replace: 'True' to
  -- inline the binding, 'False' to lift it.
  -> Rewrite extra
inlineOrLiftBinders condition inlineOrLift (TransformContext inScope0 _) e@(Letrec bndrs body) = do
  (toReplace,toKeep) <- partitionM condition bndrs
  case toReplace of
    [] -> return e
    _  -> do
      let inScope1 = extendInScopeSetList inScope0 (map fst bndrs)
      let (toInline,toLift) = partition (inlineOrLift e) toReplace
      -- Substitute the inlined bindings into the to-be-lifted bindings, the
      -- kept bindings and the body. Self-recursive ones come back as
      -- toLiftExtra.
      let (toLiftExtra,(toReplace1,body1)) =
            substituteBinders inScope1 toInline (toLift ++ toKeep) body
          (toLift1,toKeep1) = splitAt (length toLift) toReplace1
      -- Lift both the self-recursive and the originally selected bindings.
      (toKeep2,body2) <- liftAndSubsituteBinders inScope1
                           (toLiftExtra ++ toLift1)
                           toKeep1 body1
      case toKeep2 of
        [] -> changed body2
        _  -> changed (Letrec toKeep2 body2)
inlineOrLiftBinders _ _ _ e = return e
-- | Lift a let-binding to a new global function.
--
-- The new global abstracts over the binding's free type variables and free
-- local term variables. References to the binder inside its own RHS are
-- replaced by a call to the new global. If an alpha-equivalent global binding
-- already exists, it is reused instead. Returns the binder paired with the
-- call expression that should replace it.
liftBinding :: LetBinding
            -> RewriteMonad extra LetBinding
liftBinding (var@Id {varName = idName} ,e) = do
  -- Collect the free type variables and the free local term variables of
  -- the RHS. Global ids and the binder itself are excluded.
  let unitFV :: Var a -> Const (UniqSet TyVar,UniqSet Id) (Var a)
      unitFV v@(Id {})    = Const (emptyUniqSet,unitUniqSet (coerce v))
      unitFV v@(TyVar {}) = Const (unitUniqSet (coerce v),emptyUniqSet)
      interesting :: Var a -> Bool
      interesting Id {idScope = GlobalId} = False
      interesting v@(Id {idScope = LocalId}) = varUniq v /= varUniq var
      interesting _ = True
      (boundFTVsSet,boundFVsSet) =
        getConst (Lens.foldMapOf (termFreeVars' interesting) unitFV e)
      boundFTVs = eltsUniqSet boundFTVsSet
      boundFVs  = eltsUniqSet boundFVsSet
  -- Make a new global Id. Its name is the current function's name with
  -- "_<binder name>" appended, made unique against the existing bindings.
  tcm       <- Lens.view tcCache
  let newBodyTy = termType tcm $ mkTyLams (mkLams e boundFVs) boundFTVs
  (cf,sp)   <- Lens.use curFun
  binders <- Lens.use bindings
  newBodyNm <-
    cloneNameWithBindingMap
      binders
      (appendToName (varName cf) ("_" `Text.append` nameOcc idName))
  let newBodyId = mkGlobalId newBodyTy newBodyNm {nameSort = Internal}
  -- The replacement expression: the new global applied to the free type
  -- variables and then the free term variables.
  let newExpr = mkTmApps
                  (mkTyApps (Var newBodyId)
                            (map VarTy boundFTVs))
                  (map Var boundFVs)
      inScope0 = mkInScopeSet (coerce boundFVsSet)
      inScope1 = extendInScopeSetList inScope0 [var,newBodyId]
  let subst    = extendIdSubst (mkSubst inScope1) var newExpr
      -- Replace recursive references to the binder by the new expression.
      e' = substTm "liftBinding" subst e
      -- The new global's body abstracts over the free variables.
      newBody = mkTyLams (mkLams e' boundFVs) boundFTVs
  -- Look for an existing global binding with an alpha-equivalent body.
  aeqExisting <- (eltsUniqMap . filterUniqMap ((`aeqTerm` newBody) . bindingTerm)) <$> Lens.use bindings
  case aeqExisting of
    -- None found: add the new global binding.
    [] -> do 
             bindings %= extendUniqMap newBodyNm
                                    -- The binding is keyed by newBodyNm. It
                                    -- takes the source span of the current
                                    -- function and no user inline pragma.
                                    (Binding
                                      newBodyId
                                      sp
#if MIN_VERSION_ghc(8,4,1)
                                      NoUserInline
#else
                                      EmptyInlineSpec
#endif
                                      newBody)
             -- Replace the binding by a call to the new global.
             return (var, newExpr)
    -- Reuse the first alpha-equivalent global, applied to the same arguments.
    (b:_) ->
      let newExpr' = mkTmApps
                      (mkTyApps (Var $ bindingId b)
                                (map VarTy boundFTVs))
                      (map Var boundFVs)
      in  return (var, newExpr')
liftBinding _ = error $ $(curLoc) ++ "liftBinding: invalid core, expr bound to tyvar"
-- | Make a name unique with respect to the keys of a 'BindingMap', starting
-- from the name's own unique (via 'uniqAway'').
uniqAwayBinder
  :: BindingMap
  -> Name a
  -> Name a
uniqAwayBinder binders nm = uniqAway' isBound (nameUniq nm) nm
 where
  isBound u = u `elemUniqMapDirectly` binders
-- | Make a new global function for a body. The given name is made unique
-- against the existing bindings, the binding is added to 'bindings', and the
-- global 'Id' of the new function is returned.
mkFunction
  :: TmName
  -- ^ Name of the function
  -> SrcSpan
  -> InlineSpec
  -> Term
  -- ^ Term bound to the function
  -> RewriteMonad extra Id
  -- ^ Name with a proper unique and the type of the function
mkFunction bndrNm sp inl body = do
  bodyNm <- Lens.use bindings >>= (`cloneNameWithBindingMap` bndrNm)
  tcm <- Lens.view tcCache
  let bodyTy = termType tcm body
  addGlobalBind bodyNm bodyTy sp inl body
  pure (mkGlobalId bodyTy bodyNm)
-- | Add a global binding to 'bindings', keyed by its name. The type and body
-- are fully evaluated ('deepseq') before the binding map is updated.
addGlobalBind
  :: TmName
  -> Type
  -> SrcSpan
  -> InlineSpec
  -> Term
  -> RewriteMonad extra ()
addGlobalBind vNm ty sp inl body =
  (ty, body) `deepseq` (bindings %= extendUniqMap vNm binding)
 where
  binding = Binding (mkGlobalId ty vNm) sp inl body
-- | Give a name a fresh unique from the supply, then make it unique with
-- respect to the 'InScopeSet'.
cloneNameWithInScopeSet
  :: (MonadUnique m)
  => InScopeSet
  -> Name a
  -> m (Name a)
cloneNameWithInScopeSet is nm = uniqAway is . setUnique nm <$> getUniqueM
-- | Give a name a fresh unique from the supply, then make it unique with
-- respect to the keys of the 'BindingMap'.
cloneNameWithBindingMap
  :: (MonadUnique m)
  => BindingMap
  -> Name a
  -> m (Name a)
cloneNameWithBindingMap binders nm = do
  u <- getUniqueM
  let fresh = setUnique nm u
  pure (uniqAway' (`elemUniqMapDirectly` binders) u fresh)
{-# INLINE isUntranslatable #-}
-- | Is the type of the term not representable in the netlist? This is the
-- negation of 'representableType'.
isUntranslatable
  :: Bool
  -- ^ String considered representable
  -> Term
  -> RewriteMonad extra Bool
isUntranslatable stringRepresentable tm = do
  tcm        <- Lens.view tcCache
  translator <- Lens.view typeTranslator
  reprs      <- Lens.view customReprs
  let ty = termType tcm tm
  pure (not (representableType translator reprs stringRepresentable tcm ty))
{-# INLINE isUntranslatableType #-}
-- | Is the type not representable in the netlist? This is the negation of
-- 'representableType'.
isUntranslatableType
  :: Bool
  -- ^ String considered representable
  -> Type
  -> RewriteMonad extra Bool
isUntranslatableType stringRepresentable ty = do
  translator <- Lens.view typeTranslator
  reprs      <- Lens.view customReprs
  tcm        <- Lens.view tcCache
  pure (not (representableType translator reprs stringRepresentable tcm ty))
-- | Make a fresh internal binder named "wild" of the given type, unique with
-- respect to the 'InScopeSet'.
mkWildValBinder
  :: (MonadUnique m)
  => InScopeSet
  -> Type
  -> m Id
mkWildValBinder is ty = mkInternalVar is "wild" ty
-- | Make a case-decomposition that extracts a field out of a
-- (Sum-of-)Product type.
--
-- The DataCon index is 1-based and the field index is 0-based. Any invalid
-- combination raises an 'error' via @cantCreate@ with a descriptive message.
-- Invalid means:
--
-- * the subject is not a datatype;
-- * the DataCon index is out of range;
-- * the field types cannot be instantiated;
-- * the field index is out of range.
mkSelectorCase
  :: HasCallStack
  => (Functor m, MonadUnique m)
  => String -- ^ Name of the caller of this function
  -> InScopeSet
  -> TyConMap -- ^ TyCon cache
  -> Term -- ^ Subject of the case-composition
  -> Int -- ^ n'th DataCon (1-based)
  -> Int -- ^ n'th field (0-based)
  -> m Term
mkSelectorCase caller inScope tcm scrut dcI fieldI = go (termType tcm scrut)
  where
    -- Look through type synonyms / newtypes exposed by coreView1 first.
    go (coreView1 tcm -> Just ty') = go ty'
    go scrutTy@(tyView -> TyConApp tc args) =
      case tyConDataCons (lookupUniqMap' tcm tc) of
        [] -> cantCreate $(curLoc) ("TyCon has no DataCons: " ++ show tc ++ " " ++ showPpr tc) scrutTy
        dcs
          -- dcI is 1-based; a non-positive index would make indexNote fail
          -- with a negative index.
          | dcI < 1 -> cantCreate $(curLoc) "DC index must be at least 1" scrutTy
          | dcI > length dcs -> cantCreate $(curLoc) "DC index exceeds max" scrutTy
          | otherwise -> do
              let dc = indexNote ($(curLoc) ++ "No DC with tag: " ++ show (dcI-1)) dcs (dcI-1)
              case dataConInstArgTysE inScope tcm dc args of
                Nothing ->
                  cantCreate $(curLoc) ("Could not instantiate the field types of DC with tag: " ++ show (dcI-1)) scrutTy
                Just fieldTys
                  | fieldI < 0 -> cantCreate $(curLoc) "Field index must be non-negative" scrutTy
                  | fieldI >= length fieldTys -> cantCreate $(curLoc) "Field index exceed max" scrutTy
                  | otherwise -> do
                      -- Wildcard binders for every field, except the
                      -- selected one, which gets a "sel" binder.
                      wildBndrs <- mapM (mkWildValBinder inScope) fieldTys
                      let ty = indexNote ($(curLoc) ++ "No DC field#: " ++ show fieldI) fieldTys fieldI
                      selBndr <- mkInternalVar inScope "sel" ty
                      let bndrs  = take fieldI wildBndrs ++ [selBndr] ++ drop (fieldI+1) wildBndrs
                          pat    = DataPat dc (dcExtTyVars dc) bndrs
                          retVal = Case scrut ty [ (pat, Var selBndr) ]
                      return retVal
    go scrutTy = cantCreate $(curLoc) ("Type of subject is not a datatype: " ++ showPpr scrutTy) scrutTy
    cantCreate loc info scrutTy = error $ loc ++ "Can't create selector " ++ show (caller,dcI,fieldI) ++ " for: (" ++ showPpr scrut ++ " :: " ++ showPpr scrutTy ++ ")\nAdditional info: " ++ info
-- | Specialise an application on its last argument, either a type argument
-- or a term argument. Any other term is returned unchanged.
specialise :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id) -- ^ Lens into previous specialisations
           -> Lens' extra (VarEnv Int) -- ^ Lens into the specialisation history
           -> Lens' extra Int -- ^ Lens into the specialisation limit
           -> Rewrite extra
specialise specMapLbl specHistLbl specLimitLbl ctx e
  | TyApp e1 ty <- e = specOn e1 (Right ty)
  | App e1 e2 <- e   = specOn e1 (Left e2)
  | otherwise        = return e
 where
  -- Hand the function part, split into head/arguments/ticks, to specialise'.
  specOn fun =
    specialise' specMapLbl specHistLbl specLimitLbl ctx e (collectArgsTicks fun)
-- | Specialise a function application on an argument.
--
-- * For a global function 'Var' head: create a specialised copy of the
--   function with the argument baked in, or reuse an earlier one. The copy
--   abstracts over the existing arguments and over the argument's free
--   variables.
-- * For any other head and a term argument: lift the argument into a new
--   global function over its free variables.
-- * Otherwise: return the term unchanged.
specialise' :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id) -- ^ Lens into previous specialisations
            -> Lens' extra (VarEnv Int) -- ^ Lens into the specialisation history
            -> Lens' extra Int -- ^ Lens into the specialisation limit
            -> TransformContext -- ^ Transformation context
            -> Term -- ^ Original term
            -> (Term, [Either Term Type], [TickInfo]) -- ^ Function part of the term, split into head, arguments and ticks
            -> Either Term Type -- ^ Argument to specialize on
            -> RewriteMonad extra Term
specialise' specMapLbl specHistLbl specLimitLbl (TransformContext is0 _) e (Var f, args, ticks) specArgIn = do
  lvl <- Lens.view dbgLevel
  tcm <- Lens.view tcCache
  -- Top entities are never specialised.
  topEnts <- Lens.view topEntities
  if f `elemVarSet` topEnts
  then do
    case specArgIn of
      -- NOTE(review): @lvl >= DebugNone@ always holds, so this message is
      -- always traced (presumably intended as an unconditional warning).
      Left _ -> traceIf (lvl >= DebugNone) ("Not specializing TopEntity: " ++ showPpr (varName f)) (return e)
      Right tyArg -> traceIf (lvl >= DebugApplied) ("Dropping type application on TopEntity: " ++ showPpr (varName f) ++ "\ntype:\n" ++ showPpr tyArg) $
        -- Drop the type application: give 'f' the type instantiated at
        -- tyArg (piResultTy), and re-apply only the remaining arguments.
        let newVarTy = piResultTy tcm (varType f) tyArg
        in  changed (mkApps (mkTicks (Var f{varType = newVarTy}) ticks) args)
  else do
  let specArg = bimap (normalizeTermTypes tcm) (normalizeType tcm) specArgIn
      -- Create binders and variable references for the free variables in
      -- the (type-normalized) specArg.
      (specBndrsIn,specVars) = specArgBndrsAndVars specArg
      argLen  = length args
      specBndrs :: [Either Id TyVar]
      specBndrs = map (Lens.over _Left (normalizeId tcm)) specBndrsIn
      specAbs :: Either Term Type
      specAbs = either (Left . (`mkAbstraction` specBndrs)) (Right . id) specArg
  -- Was 'f' already specialised at this position on the same argument?
  specM <- Map.lookup (f,argLen,specAbs) <$> Lens.use (extra.specMapLbl)
  case specM of
    -- Use the previously specialised function.
    Just f' ->
      traceIf (lvl >= DebugApplied)
        ("Using previous specialization of " ++ showPpr (varName f) ++ " on " ++
          (either showPpr showPpr) specAbs ++ ": " ++ showPpr (varName f')) $
        changed $ mkApps (mkTicks (Var f') ticks) (args ++ specVars)
    -- Create a new specialised function.
    Nothing -> do
      -- Only possible when 'f' has a global binding.
      bodyMaybe <- fmap (lookupUniqMap (varName f)) $ Lens.use bindings
      case bodyMaybe of
        Just (Binding _ sp inl bodyTm) -> do
          -- Stop once 'f' has been specialised more often than the limit,
          -- which suggests specialisation on an ever-growing argument.
          specHistM <- lookupUniqMap f <$> Lens.use (extra.specHistLbl)
          specLim   <- Lens.use (extra . specLimitLbl)
          if maybe False (> specLim) specHistM
            then throw (ClashException
                        sp
                        (unlines [ "Hit specialisation limit " ++ show specLim ++ " on function `" ++ showPpr (varName f) ++ "'.\n"
                                 , "The function `" ++ showPpr f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
                                 , "Body of `" ++ showPpr f ++ "':\n" ++ showPpr bodyTm ++ "\n"
                                 , "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showPpr showPpr) specArg
                                 , "Run with '-fclash-spec-limit=N' to increase the specialisation limit to N."
                                 ])
                        Nothing)
            else do
              let existingNames = collectBndrsMinusApps bodyTm
                  newNames      = [ mkUnsafeInternalName ("pTS" `Text.append` Text.pack (show n)) n
                                  | n <- [(0::Int)..]
                                  ]
              -- Make new binders for the existing arguments. They reuse the
              -- body's binder names first, then fresh "pTS<n>" names.
              (boundArgs,argVars) <- fmap (unzip . map (either (Left &&& Left . Var) (Right &&& Right . VarTy))) $
                                     Monad.zipWithM
                                       (mkBinderFor is0 tcm)
                                       (existingNames ++ newNames)
                                       args
              -- Determine the function whose name and inline spec the
              -- specialisation takes, and the argument to bake in.
              (fId,inl',specArg') <- case specArg of
                Left a@(collectArgsTicks -> (Var g,gArgs,_gTicks)) -> if isPolyFun tcm a
                    then do
                      -- The argument is a polymorphic function value: a
                      -- variable 'g' applied to gArgs. When 'g' has a
                      -- global binding, then:
                      --   * the name and inline spec come from 'g';
                      --   * 'g's body (applied to gArgs) is baked in instead
                      --     of the reference to 'g'.
                      -- Without a binding, fall back to 'f's inline spec
                      -- and the original argument.
                      gTmM <- fmap (lookupUniqMap (varName g)) $ Lens.use bindings
                      return (g,maybe inl bindingSpec gTmM, maybe specArg (Left . (`mkApps` gArgs) . bindingTerm) gTmM)
                    else return (f,inl,specArg)
                _ -> return (f,inl,specArg)
              -- Create the specialised function.
              let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs)
              newf <- mkFunction (varName fId) sp inl' newBody
              -- Record the specialisation in the history and the cache.
              (extra.specHistLbl) %= extendUniqMapWith f 1 (+)
              (extra.specMapLbl)  %= Map.insert (f,argLen,specAbs) newf
              -- Call the specialised function.
              let newExpr = mkApps (mkTicks (Var newf) ticks) (args ++ specVars)
              newf `deepseq` changed newExpr
        Nothing -> return e
  where
    -- Names of the leading (type-)lambda binders of a term. For an
    -- application, the binders of the function part are used minus the
    -- outermost one, which the application consumes.
    collectBndrsMinusApps :: Term -> [Name a]
    collectBndrsMinusApps = reverse . go []
      where
        go bs (Lam v e')    = go (coerce (varName v):bs)  e'
        go bs (TyLam tv e') = go (coerce (varName tv):bs) e'
        go bs (App e' _) = case go [] e' of
          []  -> bs
          bs' -> init bs' ++ bs
        go bs (TyApp e' _) = case go [] e' of
          []  -> bs
          bs' -> init bs' ++ bs
        go bs _ = bs
specialise' _ _ _ _ctx _ (appE,args,ticks) (Left specArg) = do
  -- Create binders and variable references for the free variables in specArg.
  let (specBndrs,specVars) = specArgBndrsAndVars (Left specArg)
  -- The new function's body abstracts specArg over its free variables.
      newBody = mkAbstraction specArg specBndrs
  -- Look for an existing global binding whose body is alpha-equivalent to
  -- newBody.
  existing <- filterUniqMap ((`aeqTerm` newBody) . bindingTerm) <$> Lens.use bindings
  -- Create a new "<curFun>_specF" function only if none exists.
  newf <- case eltsUniqMap existing of
    [] -> do (cf,sp) <- Lens.use curFun
             mkFunction (appendToName (varName cf) "_specF")
                        sp
#if MIN_VERSION_ghc(8,4,1)
                        NoUserInline
#else
                        EmptyInlineSpec
#endif
                        newBody
    (b:_) -> return (bindingId b)
  -- The specialised argument: the function applied to the free variables.
  let newArg  = Left $ mkApps (Var newf) specVars
  -- Re-apply the original head and arguments to the new argument.
  let newExpr = mkApps (mkTicks appE ticks) (args ++ [newArg])
  changed newExpr
specialise' _ _ _ _ e _ _ = return e
-- | Normalize the types appearing in a cast, recursively through the cast's
-- subject, and the type of a variable. Any other term is returned as is.
normalizeTermTypes :: TyConMap -> Term -> Term
normalizeTermTypes tcm = \case
  Cast inner ty1 ty2 -> Cast (normalizeTermTypes tcm inner) (normTy ty1) (normTy ty2)
  Var v              -> Var (normalizeId tcm v)
  other              -> other
 where
  normTy = normalizeType tcm
-- | Normalize the type of a term variable; type variables are left alone.
normalizeId :: TyConMap -> Id -> Id
normalizeId tcm v = case v of
  Id {} -> v {varType = normalizeType tcm (varType v)}
  _     -> v
-- | Create binders and variable references for the free local variables of
-- an argument to specialise on. Type variables come first, then term
-- variables.
specArgBndrsAndVars
  :: Either Term Type
  -> ([Either Id TyVar], [Either Term Type])
specArgBndrsAndVars specArg =
  -- For a term argument, free variables are gathered into ordered sets
  -- (Data.Set.Ordered) and listed with 'OSet.toListL'. Presumably this gives
  -- a deterministic, occurrence-based order; verify against ordered-containers.
  -- For a type argument, its free type variables come from 'eltsUniqSet'
  -- and there are no term variables.
  let unitFV :: Var a -> Const (OSet.OLSet TyVar, OSet.OLSet Id) (Var a)
      unitFV v@(Id {}) = Const (mempty, coerce (OSet.singleton (coerce v)))
      unitFV v@(TyVar {}) = Const (coerce (OSet.singleton (coerce v)), mempty)
      (specFTVs,specFVs) = case specArg of
        Left tm  -> (OSet.toListL *** OSet.toListL) . getConst $
                    Lens.foldMapOf freeLocalVars unitFV tm
        Right ty -> (eltsUniqSet (Lens.foldMapOf typeFreeVars unitUniqSet ty),[] :: [Id])
      specTyBndrs = map Right specFTVs
      specTmBndrs = map Left  specFVs
      specTyVars  = map (Right . VarTy) specFTVs
      specTmVars  = map (Left . Var) specFVs
  in  (specTyBndrs ++ specTmBndrs,specTyVars ++ specTmVars)
-- | Evaluate an expression to weak-head normal form (WHNF) with 'whnf'', then
-- apply a transformation to the result. The evaluator's pure heap is bound
-- around the result if the transformation changed it (see 'bindPureHeap').
whnfRW
  :: Bool
  -- ^ Passed to 'whnf'' as its subject flag; presumably whether the
  -- expression is the subject of a case-expression.
  -> TransformContext
  -> Term
  -> Rewrite extra
  -> RewriteMonad extra Term
whnfRW isSubj ctx@(TransformContext is0 _) e rw = do
  tcm <- Lens.view tcCache
  bndrs <- Lens.use bindings
  (primEval, primUnwind) <- Lens.view evaluator
  -- Split the unique supply: one half goes to the evaluator, the other half
  -- is stored back in the state.
  ids <- Lens.use uniqSupply
  let (ids1,ids2) = splitSupply ids
  uniqSupply Lens..= ids2
  gh <- Lens.use globalHeap
  case whnf' primEval primUnwind bndrs tcm gh ids1 is0 isSubj e of
    -- Store the updated global heap (forced by the bang pattern), and hand
    -- the pure heap and the WHNF value to bindPureHeap.
    (!gh1,ph,v) -> do
      globalHeap Lens..= gh1
      bindPureHeap tcm ph rw ctx v
{-# SCC whnfRW #-}
-- | Apply a transformation to an expression in a context extended with the
-- bindings of a pure heap. The heap entries are bound as let-bindings named
-- "x" with the heap's uniques. If the transformation reports a change and
-- the heap is non-empty, the result is wrapped in a 'Letrec' of those
-- bindings. Otherwise the transformed expression is returned on its own.
bindPureHeap
  :: TyConMap
  -> PureHeap
  -> Rewrite extra
  -> Rewrite extra
bindPureHeap tcm heap rw (TransformContext is0 hist) e = do
  (e1, anyChange) <- Writer.listen (rw ctx e)
  let wrapInLet = Monoid.getAny anyChange && not (null heapBinds)
  pure $ if wrapInLet then Letrec heapBinds e1 else e1
  where
    heapBinds = map toLetBinding (toListUniqMap heap)
    heapIds = map fst heapBinds
    ctx = TransformContext (extendInScopeSetList is0 heapIds) (LetBody heapIds : hist)
    toLetBinding :: (Unique,Term) -> LetBinding
    toLetBinding (uniq, term) =
      (mkLocalId (termType tcm term) (mkUnsafeSystemName "x" uniq), term)