{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016-2017, Myrtle Software Ltd,
                    2017-2018, Google Inc.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Transformations of the Normalization process
-}

{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE CPP               #-}
{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE MagicHash         #-}
{-# LANGUAGE MultiWayIf        #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TupleSections     #-}
{-# LANGUAGE ViewPatterns      #-}

module Clash.Normalize.Transformations
  ( appProp
  , caseLet
  , caseCon
  , caseCase
  , caseElemNonReachable
  , elemExistentials
  , inlineNonRep
  , inlineOrLiftNonRep
  , typeSpec
  , nonRepSpec
  , etaExpansionTL
  , nonRepANF
  , bindConstantVar
  , constantSpec
  , makeANF
  , deadCode
  , topLet
  , recToLetRec
  , inlineWorkFree
  , inlineHO
  , inlineSmall
  , simpleCSE
  , reduceConst
  , reduceNonRepPrim
  , caseFlat
  , disjointExpressionConsolidation
  , removeUnusedExpr
  , inlineCleanup
  , flattenLet
  , splitCastWork
  , inlineCast
  , caseCast
  , letCast
  , eliminateCastCast
  , argCastSpec
  , etaExpandSyn
  , appPropFast
  )
where

import           Control.Concurrent.Supply   (splitSupply)
import           Control.Exception           (throw)
import           Control.Lens                (_2)
import qualified Control.Lens                as Lens
import qualified Control.Monad               as Monad
import           Control.Monad.State         (StateT (..), modify)
import           Control.Monad.State.Strict  (evalState)
import           Control.Monad.Writer        (lift, listen)
import           Control.Monad.Trans.Except  (runExcept)
import           Data.Bits                   ((.&.), complement)
import           Data.Coerce                 (coerce)
import qualified Data.Either                 as Either
import qualified Data.HashMap.Lazy           as HashMap
import qualified Data.HashMap.Strict         as HashMapS
import qualified Data.List                   as List
import qualified Data.Maybe                  as Maybe
import qualified Data.Monoid                 as Monoid
import qualified Data.Primitive.ByteArray    as BA
import qualified Data.Text                   as Text
import qualified Data.Vector.Primitive       as PV
import           Debug.Trace                 (trace)
import           GHC.Integer.GMP.Internals   (Integer (..), BigNat (..))

import           BasicTypes                  (InlineSpec (..))

import           Clash.Annotations.Primitive (extractPrim)
import           Clash.Core.DataCon          (DataCon (..))
import           Clash.Core.Evaluator        (PureHeap, whnf')
import           Clash.Core.Name
  (Name (..), NameSort (..), mkUnsafeSystemName)
import           Clash.Core.FreeVars
  (localIdOccursIn, localIdsDoNotOccurIn, freeLocalIds, termFreeTyVars, typeFreeVars, localVarsDoNotOccurIn)
import           Clash.Core.Literal          (Literal (..))
import           Clash.Core.Pretty           (showPpr)
import           Clash.Core.Subst
  (substTm, mkSubst, extendIdSubst, extendIdSubstList, extendTvSubst,
   extendTvSubstList, freshenTm, substTyInVar, deShadowTerm)
import           Clash.Core.Term
  (LetBinding, Pat (..), Term (..), CoreContext (..), PrimInfo (..), TickInfo,
   isLambdaBodyCtx, isTickCtx, collectArgs, collectArgsTicks, collectTicks,
   partitionTicks)
import           Clash.Core.Type             (Type, TypeView (..), applyFunTy,
                                              isPolyFunCoreTy, isClassTy,
                                              normalizeType, splitFunForallTy,
                                              splitFunTy,
                                              tyView)
import           Clash.Core.TyCon            (TyConMap, tyConDataCons)
import           Clash.Core.Util
  (isCon, isFun, isLet, isPolyFun, isPrim,
   isSignalType, isVar, mkApps, mkLams, mkVec, piResultTy, termSize, termType,
   tyNatSize, patVars, isAbsurdAlt, altEqs, substInExistentialsList,
   solveNonAbsurds, patIds, isLocalVar, undefinedTm, stripTicks, mkTicks)
import           Clash.Core.Var
  (Id, Var (..), isGlobalId, isLocalId, mkLocalId)
import           Clash.Core.VarEnv
  (InScopeSet, VarEnv, VarSet, elemVarSet,
   emptyVarEnv, emptyVarSet, extendInScopeSet, extendInScopeSetList, lookupVarEnv,
   notElemVarSet, unionVarEnvWith, unionVarSet, unionInScope, unitVarEnv,
   unitVarSet, mkVarSet, mkInScopeSet, uniqAway)
import           Clash.Driver.Types          (DebugLevel (..))
import           Clash.Netlist.BlackBox.Util (usedArguments)
import           Clash.Netlist.Types         (HWType (..), FilteredHWType(..))
import           Clash.Netlist.Util
  (coreTypeToHWType, representableType, splitNormalized, bindsExistentials)
import           Clash.Normalize.DEC
import           Clash.Normalize.PrimitiveReductions
import           Clash.Normalize.Types
import           Clash.Normalize.Util
import           Clash.Primitives.Types
  (Primitive(..), TemplateKind(TExpr), CompiledPrimMap)
import           Clash.Rewrite.Combinators
import           Clash.Rewrite.Types
import           Clash.Rewrite.Util
import           Clash.Unique
  (Unique, lookupUniqMap, toListUniqMap)
import           Clash.Util

-- | Inline or lift let-bindings whose binder has a non-representable type.
--
-- Candidate binders are selected by @nonRepTest@. @inlineTest@ answers
-- whether a candidate may be inlined: it returns 'True' unless one of the
-- listed "do not inline" conditions holds. Both predicates are handed to
-- 'inlineOrLiftBinders' (candidates that may not be inlined are presumably
-- lifted there, as the name suggests — see that function).
inlineOrLiftNonRep :: HasCallStack => NormRewrite
inlineOrLiftNonRep = inlineOrLiftBinders nonRepTest inlineTest
  where
    -- A binder is a candidate when it is an 'Id' whose type is NOT
    -- representable in hardware; any other binder is never a candidate.
    nonRepTest :: LetBinding -> RewriteMonad extra Bool
    nonRepTest (Id {varType = ty}, _) = do
      tt    <- Lens.view typeTranslator
      reprs <- Lens.view customReprs
      tcm   <- Lens.view tcCache
      pure (not (representableType tt reprs False tcm ty))
    nonRepTest _ = return False

    -- Given the enclosing expression 'e' and a candidate binding, decide
    -- whether the binding may be inlined.
    inlineTest :: Term -> LetBinding -> RewriteMonad extra Bool
    inlineTest e (id_, e') =
      -- We do __NOT__ inline:
      pure . not $ or
        [ -- 1. recursive let-binders
          id_ `localIdOccursIn` e'
          -- 2. join points (which are not void-wrappers),
          --    see Note [join points and void wrappers]
        , isJoinPointIn id_ e && not (isVoidWrapper e')
          -- 3. binders that are used more than once in the body, because
          --    it makes CSE a whole lot more difficult.
        , freeOccurrences > 1
        ]
      where
        -- The number of free occurrences of the binder in the *body* of the
        -- let-expression (the bound right-hand sides are not counted); 0 when
        -- 'e' is not a let-expression.
        freeOccurrences :: Int
        freeOccurrences = case e of
          Letrec _ res ->
            Monoid.getSum
              (Lens.foldMapOf freeLocalIds
                              (\i -> if i == id_
                                        then Monoid.Sum 1
                                        else Monoid.Sum 0)
                              res)
          _ -> 0

{- [Note] join points and void wrappers
Join points are functions that only occur in tail-call positions within an
expression, and only when they occur in a tail-call position more than once.

Normally bindNonRep binds/inlines all non-recursive local functions. However,
doing so for join points would significantly increase compilation time, so we
avoid it. The only exception to this rule are so-called void wrappers. Void
wrappers are functions of the form:

> \(w :: Void) -> f a b c

i.e. a wrapper around the function 'f' where the argument 'w' is not used. We
do bind/inline these join-points because these void-wrappers interfere with the
'disjoint expression consolidation' (DEC) and 'common sub-expression elimination'
(CSE) transformations, sometimes resulting in circuits that are twice as big
as they'd need to be.
-}

-- | Specialize functions on their type
--
-- Fires on a type application @e1 \@ty@ when:
--
-- * the head of @e1@ is a variable,
-- * @ty@ has no free type variables, and
-- * none of the arguments already applied in @e1@ are type arguments.
--
-- The specialisation itself is delegated to 'specializeNorm'.
typeSpec :: HasCallStack => NormRewrite
typeSpec ctx e@(TyApp e1 ty)
  | (Var {}, args) <- collectArgs e1
  , Lens.nullOf typeFreeVars ty
    -- 'partitionEithers' puts type arguments on the right: require none
  , (_, []) <- Either.partitionEithers args
  = specializeNorm ctx e

typeSpec _ e = return e

-- | Specialize functions on their non-representable argument
--
-- Fires on an application @e1 e2@ when:
--
-- * the head of @e1@ is a variable,
-- * none of the arguments applied in @e1@ are type arguments, and
-- * @e2@ has no free type variables.
--
-- It then specialises (via 'specializeNorm') only if @e2@'s type is not
-- representable and @e2@ is not a local variable. Before specialising, a
-- call to an internal (compiler-generated) global function in @e2@ is
-- inlined, see @inlineInternalSpecialisationArgument@.
nonRepSpec :: HasCallStack => NormRewrite
nonRepSpec ctx@(TransformContext is0 _) e@(App e1 e2)
  | (Var {}, args) <- collectArgs e1
    -- 'partitionEithers' puts type arguments on the right: require none
  , (_, []) <- Either.partitionEithers args
  , Lens.nullOf termFreeTyVars e2
  = do tcm   <- Lens.view tcCache
       tt    <- Lens.view typeTranslator
       reprs <- Lens.view customReprs
       let e2Ty     = termType tcm e2
           nonRepE2 = not (representableType tt reprs False tcm e2Ty)
       if nonRepE2 && not (isLocalVar e2)
         then do
           e2' <- inlineInternalSpecialisationArgument e2
           specializeNorm ctx (App e1 e2')
         else return e
  where
    -- If the argument on which we're specialising is an internal function,
    -- one created by the compiler, then inline that function before we
    -- specialise.
    --
    -- We need to do this because otherwise the specialisation history won't
    -- recognize the new specialisation argument as something the function has
    -- already been specialized on
    inlineInternalSpecialisationArgument
      :: Term
      -> NormalizeSession Term
    inlineInternalSpecialisationArgument app
      | (Var f, fArgs, ticks) <- collectArgsTicks app
      = do
        fTmM <- lookupVarEnv f <$> Lens.use bindings
        case fTmM of
          Just (fNm, _, _, tm)
            | nameSort (varName fNm) == Internal
            -> do
              -- Run 'appProp' bottom-up over the inlined body applied to the
              -- original arguments (with the original ticks re-attached).
              -- 'censor (const mempty)' resets the 'Any' writer output of this
              -- inner rewrite (presumably its "changed" flag).
              tm' <- censor (const mempty)
                            (bottomupR appProp ctx
                                       (mkApps (mkTicks tm ticks) fArgs))
              -- See Note [AppProp no-shadow invariant]
              return (deShadowTerm is0 tm')
          _ -> return app
      | otherwise = return app

nonRepSpec _ e = return e

-- | Lift the let-bindings out of the subject of a Case-decomposition:
--
-- > case (let xes in e) of alts  ==>  let xes in (case e of alts)
caseLet :: HasCallStack => NormRewrite
caseLet _ (Case (Letrec xes e) ty alts) =
  -- NOTE(review): after this rewrite the binders of 'xes' also scope over
  -- 'alts'. That is only capture-free when no binder in 'xes' has the same
  -- name as a variable free in 'alts' (e.g. under a global no-shadowing
  -- invariant) — confirm that invariant holds when this rewrite runs.
  changed (Letrec xes (Case e ty alts))

caseLet _ e = return e

-- | Remove non-reachable alternatives. For example, consider:
--
-- > data STy ty where
-- >   SInt :: Int -> STy Int
-- >   SBool :: Bool -> STy Bool
-- >
-- > f :: STy ty -> ty
-- > f (SInt b) = b + 1
-- > f (SBool True) = False
-- > f (SBool False) = True
-- > {-# NOINLINE f #-}
-- >
-- > g :: STy Int -> Int
-- > g = f
--
-- @f@ is always specialized on @STy Int@. The SBool alternatives are therefore
-- unreachable. Additional information can be found at:
-- https://github.com/clash-lang/clash-compiler/pull/465
--
-- Alternatives are judged unreachable by 'isAbsurdAlt'. The remaining case is
-- passed through 'caseOneAlt'. If every alternative is absurd, the case that
-- reaches 'caseOneAlt' has no alternatives at all.
caseElemNonReachable :: HasCallStack => NormRewrite
caseElemNonReachable _ case0@(Case scrut altsTy alts0) = do
  tcm <- Lens.view tcCache
  let (altsAbsurd, altsOther) = List.partition (isAbsurdAlt tcm) alts0
  if null altsAbsurd
    -- Nothing to remove: leave the term untouched (and unmarked)
    then return case0
    else changed =<< caseOneAlt (Case scrut altsTy altsOther)

caseElemNonReachable _ e = return e

-- | Tries to eliminate existentials by using heuristics to determine what the
-- existential should be. For example, consider Vec:
--
-- > data Vec :: Nat -> Type -> Type where
-- >   Nil  :: Vec 0 a
-- >   Cons :: a -> Vec n a -> Vec (n + 1) a
--
-- Thus, @null@ (annotated with existentials) could look like:
--
-- > null :: forall n . Vec n Bool -> Bool
-- > null v =
-- >   case v of
-- >     Nil  {n ~ 0}                                     -> True
-- >     Cons {n1:Nat} {n~n1+1} (x :: a) (xs :: Vec n1 a) -> False
--
-- When it's applied to a vector of length 5, this becomes:
--
-- > null :: Vec 5 Bool -> Bool
-- > null v =
-- >   case v of
-- >     Nil  {5 ~ 0}                                     -> True
-- >     Cons {n1:Nat} {5~n1+1} (x :: a) (xs :: Vec n1 a) -> False
--
-- This function solves @n1@ and replaces every occurrence with its solution. A
-- very limited number of solutions are currently recognized: only adds (such
-- as in the example) will be solved.
elemExistentials :: HasCallStack => NormRewrite
elemExistentials (TransformContext is0 _) (Case scrut altsTy alts0) = do
  tcm <- Lens.view tcCache

  alts1 <- mapM (go is0 tcm) alts0
  caseOneAlt (Case scrut altsTy alts1)

  where
    -- Eliminate free type variables if possible
    go :: InScopeSet -> TyConMap -> (Pat, Term) -> NormalizeSession (Pat, Term)
    go is2 tcm alt@(DataPat dc exts0 xs0, term0) =
      case solveNonAbsurds tcm (altEqs tcm alt) of
        -- No equations solved:
        [] -> return alt
        -- One or more equations solved: substitute the solutions, then run
        -- 'go' again on the resulting alternative
        sols ->
          changed =<< go is2 tcm (DataPat dc exts1 xs1, term1)
          where
            -- Substitute solution in existentials and applied types
            is3   = extendInScopeSetList is2 exts0
            xs1   = map (substTyInVar (extendTvSubstList (mkSubst is3) sols)) xs0
            exts1 = substInExistentialsList is2 exts0 sols

            -- Substitute solution in term.
            is4   = extendInScopeSetList is3 xs1
            subst = extendTvSubstList (mkSubst is4) sols
            term1 = substTm "Replacing tyVar due to solved eq" subst term0

    -- Only data-constructor patterns can bind existentials
    go _ _ alt = return alt

elemExistentials _ e = return e

-- | Move a Case-decomposition from the subject of a Case-decomposition to the
-- alternatives:
--
-- > case (case s of {p1 -> e1; ..; pn -> en}) of alts2
-- >   ==> case s of {p1 -> case e1 of alts2; ..; pn -> case en of alts2}
--
-- Only fires when the type of the inner case-expression is not representable.
-- Ticks wrapped around the inner case-expression are discarded.
caseCase :: HasCallStack => NormRewrite
caseCase _ e@(Case (stripTicks -> Case scrut alts1Ty alts1) alts2Ty alts2)
  = do
    tt    <- Lens.view typeTranslator
    reprs <- Lens.view customReprs
    tcm   <- Lens.view tcCache
    if representableType tt reprs False tcm alts1Ty
      then return e
      else
        -- NOTE(review): the pattern binders of 'alts1' now scope over 'alts2'.
        -- This is only capture-free if none of them has the same name as a
        -- variable free in 'alts2' (e.g. under a no-shadowing invariant);
        -- confirm that holds when this rewrite runs.
        let pushOuterCase altE = Case altE alts2Ty alts2
        in  changed (Case scrut alts2Ty (map (second pushOuterCase) alts1))

caseCase _ e = return e

-- | Inline function with a non-representable result if it's the subject
-- of a Case-decomposition
--
-- The global function @f@ at the head of the case subject is inlined only
-- when:
--
--   * @f@ has a binding in 'bindings';
--   * the type of the subject is not representable (per 'representableType');
--   * @f@ has not already been inlined more than 'inlineLimit' times into the
--     current function. This limit is neither checked nor counted when the
--     subject's type satisfies 'isClassTy'.
--
-- When the limit is exceeded a warning is traced and the term is returned
-- unchanged.
inlineNonRep :: HasCallStack => NormRewrite
inlineNonRep (TransformContext localScope _) e@(Case scrut altsTy alts)
  | (Var f, args, ticks) <- collectArgsTicks scrut
  , isGlobalId f
  = do
    (cf,_)    <- Lens.use curFun
    isInlined <- zoomExtra (alreadyInlined f cf)
    limit     <- Lens.use (extra.inlineLimit)
    tcm       <- Lens.view tcCache
    let scrutTy     = termType tcm scrut
        -- Subjects of a class type are exempt from the inline limit
        noException = not (exception tcm scrutTy)
    if noException && Maybe.fromMaybe 0 isInlined > limit
      then
        -- Give up on this subject: warn and leave the term untouched
        traceIf True (concat [$(curLoc) ++ "InlineNonRep: " ++ showPpr (varName f)
                             ," already inlined " ++ show limit ++ " times in:"
                             , showPpr (varName cf)
                             , "\nType of the subject is: " ++ showPpr scrutTy
                             , "\nFunction " ++ showPpr (varName cf)
                             , " will not reach a normal form, and compilation"
                             , " might fail."
                             , "\nRun with '-fclash-inline-limit=N' to increase"
                             , " the inlining limit to N."
                             ])
                     (return e)
      else do
        bodyMaybe   <- lookupVarEnv f <$> Lens.use bindings
        nonRepScrut <- not <$> (representableType <$> Lens.view typeTranslator
                                                  <*> Lens.view customReprs
                                                  <*> pure False
                                                  <*> Lens.view tcCache
                                                  <*> pure scrutTy)
        case (nonRepScrut, bodyMaybe) of
          (True, Just (_,_,_,scrutBody0)) -> do
            -- Only inlines that are subject to the limit are counted
            Monad.when noException (zoomExtra (addNewInline f cf))
            -- See Note [AppProp no-shadow invariant]
            let scrutBody1 = deShadowTerm localScope scrutBody0
            changed $ Case (mkApps (mkTicks scrutBody1 ticks) args) altsTy alts
          _ -> return e
  where
    exception = isClassTy

inlineNonRep _ e = return e

-- | Specialize a Case-decomposition (replace by the RHS of an alternative) if
-- the subject is (an application of) a DataCon; or if there is only a single
-- alternative that doesn't reference variables bound by the pattern.
--
-- Note [CaseCon deshadow]
--
-- Imagine:
--
-- @
-- case D (f a b) (g x y) of
--   D a b -> h a
-- @
--
-- rewriting this to:
--
-- @
-- let a = f a b
-- in  h a
-- @
--
-- is very bad because the newly introduced let-binding now captures the free
-- variable 'a' in 'f a b'.
--
-- Instead, we must rewrite to:
--
-- @
-- let a1 = f a b
-- in  h a1
-- @
caseCon :: HasCallStack => NormRewrite
-- Subject is a saturated data constructor: select the matching alternative,
-- let-bind the used pattern variables and substitute the existential types.
caseCon (TransformContext is0 _) (Case scrut ty alts)
  | (Data dc, args) <- collectArgs scrut
  = case List.find (equalCon dc . fst) alts of
      Just (DataPat _ tvs xs, e) -> do
        let is1 = extendInScopeSetList (extendInScopeSetList is0 tvs) xs
        let fvs = Lens.foldMapOf freeLocalIds unitVarSet e
            -- Only pattern variables that the RHS actually uses get bound
            (binds,_) = List.partition ((`elemVarSet` fvs) . fst)
                      $ zip xs (Either.lefts args)
            e' = case binds of
                  [] -> e
                  _  ->
                    -- See Note [CaseCon deshadow]
                    let ((is3,substIds),binds') = List.mapAccumL newBinder
                                                    (is1,[]) binds
                        subst = extendIdSubstList (mkSubst is3) substIds
                    in  Letrec binds' (substTm "caseCon0" subst e)
        -- Type arguments after the universal ones instantiate the pattern's
        -- type variables
        let subst = extendTvSubstList (mkSubst is1)
                  $ zip tvs (drop (length (dcUnivTyVars dc)) (Either.rights args))
        changed (substTm "caseCon1" subst e')
      _ -> case alts of
             ((DefaultPat,e):_) -> changed e
             _ -> changed (undefinedTm ty)
  where
    equalCon dc (DataPat dc' _ _) = dcTag dc == dcTag dc'
    equalCon _  _                 = False

    -- Give a pattern variable a fresh name, recording the old-to-new renaming
    newBinder (isN0,substN) (x,arg) =
      let x'   = uniqAway isN0 x
          isN1 = extendInScopeSet isN0 x'
      in  ((isN1,(x,Var x'):substN),(x',arg))

-- Subject is a literal: select the matching literal alternative, otherwise
-- defer to 'matchLiteralContructor'.
caseCon _ c@(Case (stripTicks -> Literal l) _ alts) =
  case List.find (equalLit . fst) alts of
    Just (LitPat _,e) -> changed e
    _ -> matchLiteralContructor c l alts
  where
    equalLit (LitPat l') = l == l'
    equalLit _           = False

-- Subject is a primitive application: try to reduce it to WHNF with the
-- evaluator and retry on the result.
caseCon ctx@(TransformContext is0 _) e@(Case subj ty alts)
  | (Prim _ _,_) <- collectArgs subj = do
    reprs <- Lens.view customReprs
    tcm <- Lens.view tcCache
    bndrs <- Lens.use bindings
    primEval <- Lens.view evaluator
    ids <- Lens.use uniqSupply
    let (ids1,ids2) = splitSupply ids
    uniqSupply Lens..= ids2
    gh <- Lens.use globalHeap
    lvl <- Lens.view dbgLevel
    case whnf' primEval bndrs tcm gh ids1 is0 True subj of
      (gh',ph',v) -> do
        globalHeap Lens..= gh'
        bindPureHeap ctx tcm ph' $ \ctx' -> case stripTicks v of
          Literal l -> caseCon ctx' (Case (Literal l) ty alts)
          subj' -> case collectArgsTicks subj' of
            (Data _,_,_) -> caseCon ctx' (Case subj' ty alts)
#if MIN_VERSION_ghc(8,2,2)
            (Prim nm ty',_:msgOrCallStack:_,ticks)
              | nm == "Control.Exception.Base.absentError" ->
                let e' = mkApps (mkTicks (Prim nm ty') ticks)
                                [Right ty,msgOrCallStack]
                in  changed e'
#endif

            (Prim nm ty',repTy:_:msgOrCallStack:_,ticks)
              | nm `elem` ["Control.Exception.Base.patError"
#if !MIN_VERSION_ghc(8,2,2)
                          ,"Control.Exception.Base.absentError"
#endif
                          ,"GHC.Err.undefined"] ->
                let e' = mkApps (mkTicks (Prim nm ty') ticks)
                                [repTy,Right ty,msgOrCallStack]
                in  changed e'
            (Prim nm ty',[_],ticks)
              | nm `elem` [ "Clash.Transformations.undefined"
                          , "Clash.GHC.Evaluator.undefined"
                          , "EmptyCase"] ->
                let e' = mkApps (mkTicks (Prim nm ty') ticks) [Right ty]
                in changed e'
            _ -> do
              let subjTy = termType tcm subj
              tran <- Lens.view typeTranslator
              case (`evalState` HashMapS.empty) (coreTypeToHWType tran reprs tcm subjTy) of
                -- Zero-width subject: it can only ever be 0
                Right (FilteredHWType (Void (Just hty)) _areVoids)
                  | hty `elem` [BitVector 0, Unsigned 0, Signed 0, Index 1]
                  -> caseCon ctx' (Case (Literal (IntegerLiteral 0)) ty alts)
                _ ->
                  -- '&&' short-circuits, so 'isConstant' is only evaluated
                  -- when debugging output is enabled.
                  traceIf (lvl > DebugNone && isConstant subj)
                          ("Irreducible constant as case subject: " ++ showPpr subj ++
                           "\nCan be reduced to: " ++ showPpr subj')
                          (caseOneAlt e)

-- Any other subject: zero-width subjects become the literal 0, everything
-- else is handled by 'caseOneAlt'.
caseCon ctx e@(Case subj ty alts) = do
  reprs <- Lens.view customReprs
  tcm <- Lens.view tcCache
  let subjTy = termType tcm subj
  tran <- Lens.view typeTranslator
  case (`evalState` HashMapS.empty) (coreTypeToHWType tran reprs tcm subjTy) of
    Right (FilteredHWType (Void (Just hty)) _areVoids)
      | hty `elem` [BitVector 0, Unsigned 0, Signed 0, Index 1]
      -> caseCon ctx (Case (Literal (IntegerLiteral 0)) ty alts)
    _ -> caseOneAlt e

caseCon _ e = return e


-- | Run the rewrite @rw@ in a context that is extended with the variables bound
-- on the 'PureHeap'. When @rw@ reports a change and the heap is non-empty, the
-- result is wrapped in a 'Letrec' that binds every heap entry.
--
-- The heap bindings are only added when the rewrite changed something, so that
-- no unnecessary follow-up rewrites are triggered. The binders are rebuilt from
-- the uniques stored on the heap (see [Note: Name re-creation]).
bindPureHeap
  :: TransformContext
  -> TyConMap
  -> PureHeap
  -> (TransformContext -> RewriteMonad extra Term)
  -> RewriteMonad extra Term
bindPureHeap (TransformContext is0 ctxs) tcm heap rw = do
  (e, anyChange) <- listen (rw ctxWithHeap)
  if Monoid.getAny anyChange && not (null heapBinds)
    then pure (Letrec heapBinds e)
    else pure e
  where
    heapBinds = [ (heapId uniq term, term) | (uniq, term) <- toListUniqMap heap ]
    boundIds  = map fst heapBinds

    -- The rewrite runs underneath the (prospective) let: its binders are in
    -- scope and the context records that we are in a let body.
    ctxWithHeap =
      TransformContext (extendInScopeSetList is0 boundIds) (LetBody boundIds : ctxs)

    -- Recreate the binder of a heap entry from its unique and the entry's type
    heapId :: Unique -> Term -> Id
    heapId uniq term =
      mkLocalId (termType tcm term) (mkUnsafeSystemName "x" uniq) -- See [Note: Name re-creation]

{- [Note: Name re-creation]
The names of heap-bound variables are safely generated with mkUniqSystemId in Clash.Core.Evaluator.newLetBinding.
But only their uniqs end up in the heap, not the complete names.
So we use mkUnsafeSystemName to recreate the same Name.
-}

-- | Select the alternative of a case-expression whose subject is the literal
-- @lit@ (an 'IntegerLiteral' or a 'NaturalLiteral').
--
-- Alternatives are examined in reverse order. A 'DefaultPat' is only taken
-- once it is the last remaining alternative. Literal patterns are matched by
-- equality. Constructor patterns are matched by data-constructor tag and the
-- range of the value:
--
-- * 'IntegerLiteral': tag 1 for @-2^63 <= l < 2^63@ (the pattern variable is
--   bound to an 'IntLiteral'); tag 2 for @l >= 2^63@ and tag 3 for
--   @l < -2^63@ (the pattern variable is bound to a 'ByteArrayLiteral' holding
--   the big-number payload).
-- * 'NaturalLiteral': tag 1 for @0 <= l < 2^64@ (bound to a 'WordLiteral');
--   tag 2 for @l >= 2^64@ (bound to a 'ByteArrayLiteral').
--
-- For any other literal only a leading 'DefaultPat' is selected. When nothing
-- can be selected this calls 'error'; the term @c@ is only used in that
-- message.
matchLiteralContructor
  :: Term
  -> Literal
  -> [(Pat,Term)]
  -> NormalizeSession Term
matchLiteralContructor c (IntegerLiteral l) alts = go (reverse alts)
 where
  go [(DefaultPat,e)] = changed e
  go ((DataPat dc [] xs,e):alts')
    -- Small integer: fits in a 64-bit machine Int
    | dcTag dc == 1
    , l >= ((-2)^(63::Int)) && l < 2^(63::Int)
    = changed (bindMatchedLiteral xs (Literal (IntLiteral l)) e)
    -- Large positive integer: payload is a big-number byte array
    | dcTag dc == 2
    , l >= 2^(63::Int)
    = let !(Jp# !(BN# ba)) = l
      in  changed (bindMatchedLiteral xs (byteArrayLitTerm (BA.ByteArray ba)) e)
    -- Large negative integer: payload is a big-number byte array
    | dcTag dc == 3
    , l < ((-2)^(63::Int))
    = let !(Jn# !(BN# ba)) = l
      in  changed (bindMatchedLiteral xs (byteArrayLitTerm (BA.ByteArray ba)) e)
    | otherwise
    = go alts'
  go ((LitPat l', e):alts')
    | IntegerLiteral l == l'
    = changed e
    | otherwise
    = go alts'
  go _ = error $ $(curLoc) ++ "Report as bug: caseCon error: " ++ showPpr c

matchLiteralContructor c (NaturalLiteral l) alts = go (reverse alts)
 where
  go [(DefaultPat,e)] = changed e
  go ((DataPat dc [] xs,e):alts')
    -- Small natural: fits in a 64-bit machine Word
    | dcTag dc == 1
    , l >= 0 && l < 2^(64::Int)
    = changed (bindMatchedLiteral xs (Literal (WordLiteral l)) e)
    -- Large natural: payload is a big-number byte array
    | dcTag dc == 2
    , l >= 2^(64::Int)
    = let !(Jp# !(BN# ba)) = l
      in  changed (bindMatchedLiteral xs (byteArrayLitTerm (BA.ByteArray ba)) e)
    | otherwise
    = go alts'
  go ((LitPat l', e):alts')
    | NaturalLiteral l == l'
    = changed e
    | otherwise
    = go alts'
  go _ = error $ $(curLoc) ++ "Report as bug: caseCon error: " ++ showPpr c

matchLiteralContructor _ _ ((DefaultPat,e):_) = changed e
matchLiteralContructor c _ _ =
  error $ $(curLoc) ++ "Report as bug: caseCon error: " ++ showPpr c

-- | Bind the first pattern variable of a matched constructor alternative to a
-- literal term.
--
-- Given the pattern variables @xs@, the literal term @lit@ and the
-- alternative's right-hand side @e@, wrap @e@ in a 'Letrec' binding the first
-- variable of @xs@ to @lit@. The let is only added when that variable occurs
-- free in @e@; otherwise @e@ is returned as is.
bindMatchedLiteral :: [Id] -> Term -> Term -> Term
bindMatchedLiteral xs lit e = case binds of
    [] -> e
    _  -> Letrec binds e
  where
    fvs       = Lens.foldMapOf freeLocalIds unitVarSet e
    (binds,_) = List.partition ((`elemVarSet` fvs) . fst) (zip xs [lit])

-- | Turn a byte array (the payload of a big-number 'Integer') into a
-- 'ByteArrayLiteral' term covering the whole array.
byteArrayLitTerm :: BA.ByteArray -> Term
byteArrayLitTerm ba =
  Literal (ByteArrayLiteral (PV.Vector 0 (BA.sizeofByteArray ba) ba))

-- | Eliminate a case-expression with exactly one alternative by replacing it
-- with that alternative's right-hand side.
--
-- Default and literal alternatives are always eliminated. A data-constructor
-- alternative is only eliminated when its right-hand side mentions none of the
-- type or term variables bound by the pattern. Any other term is returned
-- unchanged.
caseOneAlt :: Term -> RewriteMonad extra Term
caseOneAlt e = case e of
  Case _ _ [(DefaultPat, altE)] -> changed altE
  Case _ _ [(LitPat _, altE)]   -> changed altE
  Case _ _ [(DataPat _ tvs xs, altE)]
    | localVarsDoNotOccurIn (coerce tvs ++ coerce xs) altE
    -> changed altE
  _ -> return e

-- | Bring an application of a DataCon or Primitive into ANF when its argument
-- is considered non-representable.
--
-- Applies to an application @f arg@ whose head, after collecting all
-- arguments, is a data constructor or a primitive, and where @arg@ is
-- untranslatable (per 'isUntranslatable'). Ticks around @arg@ are looked
-- through:
--
-- * @letrec binds in body@ is floated outward, giving
--   @letrec binds in f body@; the ticks around @arg@ are dropped. This is only
--   done when none of the let-bound variables occurs free in @f@, because
--   floating would otherwise capture those occurrences.
-- * A case-expression, lambda or type-lambda causes the application to be
--   handed to 'specializeNorm'.
-- * Anything else is returned unchanged.
nonRepANF :: HasCallStack => NormRewrite
nonRepANF ctx e@(App appConPrim arg)
  | (conPrim, _) <- collectArgs e
  , isCon conPrim || isPrim conPrim
  = do
    untranslatable <- isUntranslatable False arg
    case (untranslatable,stripTicks arg) of
      (True,Letrec binds body)
        -- Floating the let over 'appConPrim' brings that term into scope of
        -- the let-bound variables; refuse if that would capture any of its
        -- free variables.
        | map fst binds `localVarsDoNotOccurIn` appConPrim
        -> changed (Letrec binds (App appConPrim body))
      (True,Case {})  -> specializeNorm ctx e
      (True,Lam {})   -> specializeNorm ctx e
      (True,TyLam {}) -> specializeNorm ctx e
      _               -> return e

nonRepANF _ e = return e

-- | Ensure that the body of a top-level lambda is a let-expression whose body
-- is a variable reference.
--
-- Only fires when every enclosing context is a lambda body or a tick:
--
-- * A non-let expression @e@ that is not untranslatable becomes
--   @letrec result = e in result@.
-- * For @letrec binds in body@ where @body@ is neither a local variable nor
--   untranslatable, @body@ is moved into a fresh binding that the let returns.
--
-- In all other cases the term is returned unchanged.
topLet :: HasCallStack => NormRewrite
topLet (TransformContext is0 ctx) e
  | topLevel
  , not (isLet e)
  = do
    untranslatable <- isUntranslatable False e
    if untranslatable
      then return e
      else do
        tcm   <- Lens.view tcCache
        resId <- mkTmBinderFor is0 tcm (mkUnsafeSystemName "result" 0) e
        changed (Letrec [(resId, e)] (Var resId))
  | topLevel
  , Letrec binds body <- e
  = do
    untranslatable <- isUntranslatable False body
    if isLocalVar body || untranslatable
      then return e
      else do
        tcm   <- Lens.view tcCache
        -- The fresh binder must not clash with the existing let-binders
        let isBinds = extendInScopeSetList is0 (map fst binds)
        resId <- mkTmBinderFor isBinds tcm (mkUnsafeSystemName "result" 0) body
        changed (Letrec (binds ++ [(resId, body)]) (Var resId))
  | otherwise
  = return e
  where
    -- Only rewrite directly underneath top-level lambdas (and ticks)
    topLevel = all (\c -> isLambdaBodyCtx c || isTickCtx c) ctx

-- Misc rewrites

-- | Remove let-bindings that are not (transitively) referenced from the body
-- of a let-expression.
--
-- The search starts from the bindings whose binders occur free in the body.
-- Bindings referenced by already-live bindings are then added, round by round,
-- until no new ones are found. The outcome depends on how many survive:
--
-- * all of them: the term is returned unchanged (not marked as changed);
-- * none: the let is replaced by its body;
-- * some: the let keeps only the live bindings.
deadCode :: HasCallStack => NormRewrite
deadCode _ e@(Letrec xes body)
  | length live == length xes = return e
  | null live                 = changed body
  | otherwise                 = changed (Letrec live body)
  where
    freeIdsOf = Lens.foldMapOf freeLocalIds unitVarSet
    bodyFVs   = freeIdsOf body

    -- Bindings referenced directly by the body seed the search
    (roots, rest) = List.partition ((`elemVarSet` bodyFVs) . fst) xes
    live          = reachable [] roots rest

    -- @reachable acc frontier pending@: 'frontier' holds bindings newly found
    -- to be live; the members of 'pending' that they reference form the next
    -- frontier. Returns all live bindings in discovery order.
    reachable :: [(Id, Term)] -> [(Id, Term)] -> [(Id, Term)] -> [(Id, Term)]
    reachable acc []       _       = acc
    reachable acc frontier pending =
      let frontierFVs      = List.foldl' unionVarSet emptyVarSet
                                         (map (freeIdsOf . snd) frontier)
          (next, pending') = List.partition ((`elemVarSet` frontierFVs) . fst) pending
      in  reachable (acc ++ frontier) next pending'

deadCode _ e = return e

removeUnusedExpr :: HasCallStack => NormRewrite
removeUnusedExpr :: NormRewrite
removeUnusedExpr _ e :: Term
e@(Term -> (Term, [Either Term Kind], [TickInfo])
collectArgsTicks -> (p :: Term
p@(Prim nm :: Text
nm pInfo :: PrimInfo
pInfo),args :: [Either Term Kind]
args,ticks :: [TickInfo]
ticks)) = do
  Maybe GuardedCompiledPrimitive
bbM <- Text
-> HashMap Text GuardedCompiledPrimitive
-> Maybe GuardedCompiledPrimitive
forall k v. (Eq k, Hashable k) => k -> HashMap k v -> Maybe v
HashMap.lookup Text
nm (HashMap Text GuardedCompiledPrimitive
 -> Maybe GuardedCompiledPrimitive)
-> RewriteMonad
     NormalizeState (HashMap Text GuardedCompiledPrimitive)
-> RewriteMonad NormalizeState (Maybe GuardedCompiledPrimitive)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Getting
  (HashMap Text GuardedCompiledPrimitive)
  (RewriteState NormalizeState)
  (HashMap Text GuardedCompiledPrimitive)
-> RewriteMonad
     NormalizeState (HashMap Text GuardedCompiledPrimitive)
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
Lens.use ((NormalizeState
 -> Const (HashMap Text GuardedCompiledPrimitive) NormalizeState)
-> RewriteState NormalizeState
-> Const
     (HashMap Text GuardedCompiledPrimitive)
     (RewriteState NormalizeState)
forall extra extra2.
Lens (RewriteState extra) (RewriteState extra2) extra extra2
extra((NormalizeState
  -> Const (HashMap Text GuardedCompiledPrimitive) NormalizeState)
 -> RewriteState NormalizeState
 -> Const
      (HashMap Text GuardedCompiledPrimitive)
      (RewriteState NormalizeState))
-> ((HashMap Text GuardedCompiledPrimitive
     -> Const
          (HashMap Text GuardedCompiledPrimitive)
          (HashMap Text GuardedCompiledPrimitive))
    -> NormalizeState
    -> Const (HashMap Text GuardedCompiledPrimitive) NormalizeState)
-> Getting
     (HashMap Text GuardedCompiledPrimitive)
     (RewriteState NormalizeState)
     (HashMap Text GuardedCompiledPrimitive)
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(HashMap Text GuardedCompiledPrimitive
 -> Const
      (HashMap Text GuardedCompiledPrimitive)
      (HashMap Text GuardedCompiledPrimitive))
-> NormalizeState
-> Const (HashMap Text GuardedCompiledPrimitive) NormalizeState
Lens' NormalizeState (HashMap Text GuardedCompiledPrimitive)
primitives)
  case Maybe GuardedCompiledPrimitive
bbM of
    Just (GuardedCompiledPrimitive -> Maybe CompiledPrimitive
forall a. PrimitiveGuard a -> Maybe a
extractPrim ->  Just (BlackBox pNm :: Text
pNm _ _ _ _ _ _ inc :: [((Text, Text), BlackBox)]
inc templ :: BlackBox
templ)) -> do
      let usedArgs :: [Int]
usedArgs | Text -> Bool
isFromInt Text
pNm
                   = [0,1,2]
                   | Text
nm Text -> [Text] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` ["Clash.Annotations.BitRepresentation.Deriving.dontApplyInHDL"
                               ]
                   = [0,1]
                   | Bool
otherwise
                   = BlackBox -> [Int]
usedArguments BlackBox
templ [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ (((Text, Text), BlackBox) -> [Int])
-> [((Text, Text), BlackBox)] -> [Int]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (BlackBox -> [Int]
usedArguments (BlackBox -> [Int])
-> (((Text, Text), BlackBox) -> BlackBox)
-> ((Text, Text), BlackBox)
-> [Int]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Text, Text), BlackBox) -> BlackBox
forall a b. (a, b) -> b
snd) [((Text, Text), BlackBox)]
inc
      TyConMap
tcm <- Getting TyConMap RewriteEnv TyConMap
-> RewriteMonad NormalizeState TyConMap
forall s (m :: * -> *) a. MonadReader s m => Getting a s a -> m a
Lens.view Getting TyConMap RewriteEnv TyConMap
Lens' RewriteEnv TyConMap
tcCache
      [Either Term Kind]
args' <- TyConMap
-> Int
-> [Int]
-> [Either Term Kind]
-> RewriteMonad NormalizeState [Either Term Kind]
forall (m :: * -> *) (t :: * -> *) b.
(Monad m, Foldable t) =>
TyConMap -> Int -> t Int -> [Either Term b] -> m [Either Term b]
go TyConMap
tcm 0 [Int]
usedArgs [Either Term Kind]
args
      if [Either Term Kind]
args [Either Term Kind] -> [Either Term Kind] -> Bool
forall a. Eq a => a -> a -> Bool
== [Either Term Kind]
args'
         then Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
         else Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed (Term -> [Either Term Kind] -> Term
mkApps (Term -> [TickInfo] -> Term
mkTicks Term
p [TickInfo]
ticks) [Either Term Kind]
args')
    _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
  where
    arity :: Int
arity = [Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Kind] -> Int)
-> (([Either TyVar Kind], Kind) -> [Kind])
-> ([Either TyVar Kind], Kind)
-> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either TyVar Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights ([Either TyVar Kind] -> [Kind])
-> (([Either TyVar Kind], Kind) -> [Either TyVar Kind])
-> ([Either TyVar Kind], Kind)
-> [Kind]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([Either TyVar Kind], Kind) -> [Either TyVar Kind]
forall a b. (a, b) -> a
fst (([Either TyVar Kind], Kind) -> Int)
-> ([Either TyVar Kind], Kind) -> Int
forall a b. (a -> b) -> a -> b
$ Kind -> ([Either TyVar Kind], Kind)
splitFunForallTy (PrimInfo -> Kind
primType PrimInfo
pInfo)

    go :: TyConMap -> Int -> t Int -> [Either Term b] -> m [Either Term b]
go _ _ _ [] = [Either Term b] -> m [Either Term b]
forall (m :: * -> *) a. Monad m => a -> m a
return []
    go tcm :: TyConMap
tcm n :: Int
n used :: t Int
used (Right ty :: b
ty:args' :: [Either Term b]
args') = do
      [Either Term b]
args'' <- TyConMap -> Int -> t Int -> [Either Term b] -> m [Either Term b]
go TyConMap
tcm Int
n t Int
used [Either Term b]
args'
      [Either Term b] -> m [Either Term b]
forall (m :: * -> *) a. Monad m => a -> m a
return (b -> Either Term b
forall a b. b -> Either a b
Right b
ty Either Term b -> [Either Term b] -> [Either Term b]
forall a. a -> [a] -> [a]
: [Either Term b]
args'')
    go tcm :: TyConMap
tcm n :: Int
n used :: t Int
used (Left tm :: Term
tm : args' :: [Either Term b]
args') = do
      [Either Term b]
args'' <- TyConMap -> Int -> t Int -> [Either Term b] -> m [Either Term b]
go TyConMap
tcm (Int
nInt -> Int -> Int
forall a. Num a => a -> a -> a
+1) t Int
used [Either Term b]
args'
      let ty :: Kind
ty = TyConMap -> Term -> Kind
termType TyConMap
tcm Term
tm
          p' :: Term
p' = Kind -> Term
removedTm Kind
ty
      if Int
n Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
arity Bool -> Bool -> Bool
&& Int
n Int -> t Int -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` t Int
used
         then [Either Term b] -> m [Either Term b]
forall (m :: * -> *) a. Monad m => a -> m a
return (Term -> Either Term b
forall a b. a -> Either a b
Left Term
p' Either Term b -> [Either Term b] -> [Either Term b]
forall a. a -> [a] -> [a]
: [Either Term b]
args'')
         else [Either Term b] -> m [Either Term b]
forall (m :: * -> *) a. Monad m => a -> m a
return (Term -> Either Term b
forall a b. a -> Either a b
Left Term
tm Either Term b -> [Either Term b] -> [Either Term b]
forall a. a -> [a] -> [a]
: [Either Term b]
args'')

-- A case with exactly one data-constructor alternative whose pattern has an
-- empty type-variable list: when none of the pattern-bound term variables are
-- used in the alternative, the whole case is replaced by that alternative
-- (the subject is dropped).
removeUnusedExpr _ e@(Case _ _ [(DataPat _ [] xs, altExpr)])
  | xs `localIdsDoNotOccurIn` altExpr = changed altExpr
  | otherwise                         = return e

-- Within an application of the Cons constructor, replace a vector argument of
-- size 0 that is not already a constructor application by the Nil
-- constructor. The result is rebuilt with 'mkVec' as a one-element vector
-- containing the Cons head, and the original ticks are kept.
--
-- The vector type constructor and its two data constructors are obtained via
-- pattern guards. If any of them does not have the expected shape, the
-- expression is returned unchanged rather than crashing on a failed
-- irrefutable pattern binding.
removeUnusedExpr _ e@(collectArgsTicks -> (Data dc, [_,Right aTy,Right nTy,_,Left a,Left nil],ticks))
  | nameOcc (dcName dc) == "Clash.Sized.Vector.Cons"
  = do
    tcm <- Lens.view tcCache
    case runExcept (tyNatSize tcm nTy) of
      Right 0
        | (con, _) <- collectArgs nil
        , not (isCon con)
        , TyConApp vecTcNm _ <- tyView (termType tcm e)
        , Just vecTc <- lookupUniqMap vecTcNm tcm
        , [nilCon, consCon] <- tyConDataCons vecTc
        -> changed (mkTicks (mkVec nilCon consCon aTy 1 [a]) ticks)
      _ -> return e

-- Every other expression is left untouched.
removeUnusedExpr _ term = return term

-- | Inline let-bindings whose right-hand side (ticks stripped) is either a
-- reference to a local variable, or is work-free according to 'isWorkFreeIsh'
-- and has a 'termSize' no larger than the @inlineConstantLimit@ setting. A
-- limit of @0@ disables the size check.
--
-- NOTE(review): the intent is to exclude clock and reset generators; that
-- exclusion presumably lives in 'isWorkFreeIsh' — confirm.
bindConstantVar :: HasCallStack => NormRewrite
bindConstantVar = inlineBinders shouldInline
  where
    shouldInline _ (_, stripTicks -> rhs)
      | isLocalVar rhs = return True
      | otherwise = do
          workFree <- isWorkFreeIsh rhs
          if not workFree
            then return False
            else do
              limit <- Lens.use (extra.inlineConstantLimit)
              return (limit == 0 || termSize rhs <= limit)

-- | Push a cast over a case into its alternatives.
--
-- @
--   cast (case s of { p_i -> e_i })  ==>  case s of { p_i -> cast e_i }
-- @
--
-- Ticks around the case are looked through and dropped. The type stored in
-- the 'Case' node is the type of its alternatives. After the cast has been
-- pushed inwards, every alternative has the cast's target type @ty2@, so the
-- rebuilt 'Case' must be annotated with @ty2@. Keeping the old annotation
-- would make 'termType' of the result report the pre-cast type.
caseCast :: HasCallStack => NormRewrite
caseCast _ (Cast (stripTicks -> Case subj _ alts) ty1 ty2) =
  changed (Case subj ty2 [ (pat, Cast alt ty1 ty2) | (pat, alt) <- alts ])
caseCast _ e = return e

-- | Push a cast over a 'Letrec' into its body:
--
-- @
--   cast (let bs in body)  ==>  let bs in cast body
-- @
--
-- Ticks wrapped around the 'Letrec' are looked through and not re-attached.
letCast :: HasCallStack => NormRewrite
letCast _ (Cast inner ty1 ty2)
  | Letrec binds body <- stripTicks inner
  = changed (Letrec binds (Cast body ty1 ty2))
letCast _ e = return e


-- | Push a cast over an argument to a function into that function.
--
-- This is done by specializing on the casted argument.
-- Example:
-- @
--   y = f (cast a)
--     where f x = g x
-- @
-- transforms to:
-- @
--   y = f' a
--     where f' x' = (\x -> g x) (cast x')
-- @
--
-- The raison d'être for this transformation is that we hope to end up with
-- an expression where two casts are "back-to-back", after which we can
-- eliminate them in 'eliminateCastCast'.
--
-- When the casted expression is not work-free ('isWorkFree'), the
-- specialization still happens, but a warning is emitted via 'trace' because
-- work may get duplicated.
argCastSpec :: HasCallStack => NormRewrite
argCastSpec ctx e@(App _ (stripTicks -> Cast e' _ _))
  | isWorkFree e' = specialized
  | otherwise     = trace nonWorkFreeWarning specialized
 where
  specialized = specializeNorm ctx e
  nonWorkFreeWarning = unwords
    [ "WARNING:", $(curLoc), "specializing a function on a non work-free"
    , "cast. Generated HDL implementation might contain duplicate work."
    , "Please report this as a bug.", "\n\nExpression where this occured:"
    , "\n\n" ++ showPpr e
    ]
argCastSpec _ e = return e

-- | Only inline let-bindings whose right-hand side is a cast applied directly
-- to a 'Var' (ticks between the cast and the variable are ignored), because
-- these are guaranteed work-free. Such bindings are what the 'splitCastWork'
-- transformation produces.
inlineCast :: HasCallStack => NormRewrite
inlineCast = inlineBinders (\_ (_, rhs) -> return (isCastOfVar rhs))
  where
    isCastOfVar (Cast (stripTicks -> Var {}) _ _) = True
    isCastOfVar _                                 = False

-- | Eliminate two back-to-back casts where the type going in and the type
-- coming out are the same:
--
-- @
--   (cast :: b -> a) $ (cast :: a -> b) x   ==> x
-- @
--
-- All four types are compared after 'normalizeType'. If the inner target type
-- differs from the outer source type, or the outer target type differs from
-- the inner source type, a 'ClashException' is thrown at the source span of
-- the function currently being normalized ('curFun').
eliminateCastCast :: HasCallStack => NormRewrite
eliminateCastCast _ c@(Cast (stripTicks -> Cast e tyA tyB) tyB' tyC) = do
  tcm <- Lens.view tcCache
  let sameType t1 t2 = normalizeType tcm t1 == normalizeType tcm t2
  if sameType tyB tyB' && sameType tyA tyC
    then changed e
    else do
      (nm, sp) <- Lens.use curFun
      let msg = $(curLoc) ++ showPpr nm
             ++ ": Found 2 nested casts whose types don't line up:\n"
             ++ showPpr c
      throw (ClashException sp msg Nothing)
eliminateCastCast _ e = return e

-- | Make a cast work-free by splitting its work off into a separate binding:
--
-- @
-- let x = cast (f a b)
-- ==>
-- let x  = cast x'
--     x' = f a b
-- @
--
-- A binding is left as-is when its right-hand side (ticks stripped) is a cast
-- of a variable or a cast of a cast. The rewrite only reports a change when at
-- least one binding was actually split.
splitCastWork :: HasCallStack => NormRewrite
splitCastWork ctx@(TransformContext is0 _) unchanged@(Letrec binds body) = do
  (splitBinds, Monoid.getAny -> anySplit) <-
    listen (traverse (splitCastLetBinding is0) binds)
  if anySplit
    then changed (Letrec (concat splitBinds) body)
    else return unchanged
  where
    splitCastLetBinding
      :: InScopeSet
      -> LetBinding
      -> RewriteMonad extra [LetBinding]
    splitCastLetBinding isN bind@(nm, rhs) = case stripTicks rhs of
      -- Already work-free.
      Cast (Var {}) _ _  -> return [bind]
      -- Nested casts are expected to be removed (see 'eliminateCastCast').
      Cast (Cast {}) _ _ -> return [bind]
      Cast work ty1 ty2 -> do
        tcm <- Lens.view tcCache
        let derivedName = mkDerivedName ctx (nameOcc (varName nm))
        nm' <- mkTmBinderFor isN tcm derivedName work
        changed [ (nm', work)
                , (nm, Cast (Var nm') ty1 ty2)
                ]
      _ -> return [bind]

splitCastWork _ e = return e


-- | Inline work-free functions, i.e. fully applied functions that evaluate to
-- a constant.
--
-- Two shapes are handled:
--
-- * An application @f a1 .. an@ (at least one argument) of a non-local @f@.
--   The result type must be neither untranslatable nor 'Signal', and no term
--   argument may have work (free local variables or a 'Signal' type). The body
--   of @f@, as currently stored in 'bindings', is inlined and re-applied to
--   the arguments, with the original ticks preserved.
--
-- * A bare global variable @f@ whose type is not a polymorphic/function type
--   ('isPolyFunCoreTy'), not untranslatable and not a 'Signal'. @f@ is first
--   normalized with 'normalizeTopLvlBndr' and its normalized body is inlined.
--
-- Recursive binders are never inlined. Inlined bodies are deshadowed against
-- the local scope; see Note [AppProp no-shadow invariant].
inlineWorkFree :: HasCallStack => NormRewrite
inlineWorkFree (TransformContext localScope _) e@(collectArgsTicks -> (Var f,args@(_:_),ticks)) = do
  tcm <- Lens.view tcCache
  let eTy = termType tcm e
  argsHaveWork <- or <$> traverse (either expressionHasWork (const (pure False))) args
  untranslatable <- isUntranslatableType True eTy
  let blocked = untranslatable || isSignalType tcm eTy || argsHaveWork || isLocalId f
  if blocked
    then return e
    else do
      bndrs <- Lens.use bindings
      case lookupVarEnv f bndrs of
        Nothing -> return e
        Just (_, _, _, body) -> do
          -- Don't inline recursive expressions
          isRecBndr <- isRecursiveBndr f
          if isRecBndr
            then return e
            -- See Note [AppProp no-shadow invariant]
            else changed (mkApps (mkTicks (deShadowTerm localScope body) ticks) args)
  where
    -- An expression has work when it contains free local variables, or has a
    -- Signal type, i.e. it does not evaluate to a work-free constant.
    expressionHasWork e' = do
      tcm <- Lens.view tcCache
      let hasFreeLocals = not (null (Lens.toListOf freeLocalIds e'))
      return (hasFreeLocals || isSignalType tcm (termType tcm e'))

inlineWorkFree (TransformContext localScope _) e@(Var f) = do
  tcm <- Lens.view tcCache
  let fTy = varType f
  untranslatable <- isUntranslatableType True fTy
  let inlinable = not (isPolyFunCoreTy tcm fTy)
               && not untranslatable
               && not (isSignalType tcm fTy)
               && isGlobalId f
  if not inlinable
    then return e
    else do
      bndrs <- Lens.use bindings
      case lookupVarEnv f bndrs of
        Nothing -> return e
        Just top -> do
          -- Don't inline recursive expressions
          isRecBndr <- isRecursiveBndr f
          if isRecBndr
            then return e
            else do
              (_, _, _, body) <- normalizeTopLvlBndr f top
              -- See Note [AppProp no-shadow invariant]
              changed (deShadowTerm localScope body)

inlineWorkFree _ e = return e

-- | Inline small functions.
--
-- A call @f a1 .. an@ (possibly with no arguments) is replaced by @f@'s body,
-- re-applied to the arguments with the original ticks. The following must all
-- hold:
--
-- * the expression is not untranslatable;
-- * @f@ is not a top entity and not a local identifier;
-- * @f@ is not recursive and is not marked 'NoInline';
-- * the 'termSize' of @f@'s body is strictly below @inlineFunctionLimit@.
--
-- The inlined body is deshadowed against the local scope; see
-- Note [AppProp no-shadow invariant].
inlineSmall :: HasCallStack => NormRewrite
inlineSmall (TransformContext localScope _) e@(collectArgsTicks -> (Var f,args,ticks)) = do
  untranslatable <- isUntranslatable True e
  topEnts <- Lens.view topEntities
  if untranslatable || f `elemVarSet` topEnts || isLocalId f
    then return e
    else do
      bndrs <- Lens.use bindings
      sizeLimit <- Lens.use (extra.inlineFunctionLimit)
      case lookupVarEnv f bndrs of
        Nothing -> return e
        Just (_, _, inl, body) -> do
          -- Don't inline recursive expressions
          isRecBndr <- isRecursiveBndr f
          let tooBig = not (termSize body < sizeLimit)
          if isRecBndr || inl == NoInline || tooBig
            then return e
            -- See Note [AppProp no-shadow invariant]
            else changed (mkApps (mkTicks (deShadowTerm localScope body) ticks) args)

inlineSmall _ e = return e

-- | Specialise functions on arguments which are constant, except when they
-- are clock, reset generators.
--
-- Fires on an application @e1 e2@ where @e1@ is a variable applied only to
-- term arguments (no type arguments) and @e2@ has no free type variables.
-- Constant detection is delegated to 'constantSpecInfo'; per the original
-- comment, clock and reset generators are presumably not treated as constant
-- there — confirm.
--
-- * No constant parts in @e2@: the expression is returned unchanged.
-- * @e2@ is entirely constant: specialise on @e2@ itself.
-- * Only parts of @e2@ are constant: specialise on the residual term
--   ('csrNewTerm') and wrap the result in a 'Letrec' of the new bindings.
--   The new binders are put in scope while specialising.
inlineSmallUnused :: ()
inlineSmallUnused = ()
constantSpec :: HasCallStack => NormRewrite
constantSpec ctx@(TransformContext is0 tfCtx) e@(App e1 e2)
  | (Var {}, args) <- collectArgs e1
  , (_, []) <- Either.partitionEithers args
  , null (Lens.toListOf termFreeTyVars e2)
  = do
    specInfo <- constantSpecInfo ctx e2
    let newBindings = csrNewBindings specInfo
        is1 = extendInScopeSetList is0 (map fst newBindings)
    if | not (csrFoundConstant specInfo) ->
           -- e2 has no constant parts
           return e
       | null newBindings ->
           -- Whole of e2 is constant
           specializeNorm ctx (App e1 e2)
       | otherwise -> do
           -- Parts of e2 are constant
           specialized <- specializeNorm (TransformContext is1 tfCtx)
                                         (App e1 (csrNewTerm specInfo))
           -- Deshadow because appPropFast will be called after constantSpec
           return (deShadowTerm is0 (Letrec newBindings specialized))
constantSpec _ e = return e


-- Experimental

-- | Propagate arguments of application inwards; except for 'Lam' where the
-- argument becomes let-bound.
--
-- Note [AppProp deshadow]
--
-- Imagine:
--
-- @
-- (case x of
--    D a b -> h a) (f x y)
-- @
--
-- rewriting this to:
--
-- @
-- let b = f x y
-- in  case x of
--       D a b -> h a b
-- @
--
-- is very bad because 'b' in 'h a b' is now bound by the pattern instead of the
-- newly introduced let-binding.
--
-- Instead we must rewrite to:
--
-- @
-- let b1 = f x y
-- in  case x of
--       D a b -> h a b1
-- @
--
-- Note [AppProp no-shadow invariant]
--
-- Imagine
--
-- @
-- (\x -> e) u
-- @
--
-- where @u@ has a free variable named @x@, rewriting this to:
--
-- @
-- let x = u
-- in  e
-- @
--
-- would be very bad, because the let-binding suddenly captures the free
-- variable in @u@. The same for:
--
-- @
-- (let x = w in e) u
-- @
--
-- where @u@ again has a free variable @x@, rewriting this to:
--
-- @
-- let x = w in (e u)
-- @
--
-- would be bad because the let-binding now captures the free variable in @u@.
--
-- To prevent this from happening, we can either:
--
-- 1. Rename the bindings, so that they cannot capture
-- 2. Ensure that @AppProp@ is only called in a context where there is no
--    shadowing, i.e. the bindings can never collide with the current
--    inScopeSet.
--
-- We have gone for option 2 so that AppProp requires less computation and
-- because AppProp is such a commonly applied transformation. This
-- means that when normalisation starts we deshadow the expression, and when
-- we inline global binders, we ensure that the inlined expression is deshadowed,
-- taking the InScopeSet of the context into account.
appProp :: HasCallStack => NormRewrite
-- Term argument applied to a lambda: substitute it when it is work-free or a
-- variable, otherwise let-bind it under the lambda's own binder. Reusing the
-- binder relies on Note [AppProp no-shadow invariant].
appProp (TransformContext is0 _) (App (collectTicks -> (Lam v e, ticks)) arg)
  | isWorkFree arg || isVar arg
  = changed (mkTicks (substTm "appProp.AppLam" subst e) ticks)
  | otherwise
  = changed (Letrec [(v, arg)] (mkTicks e ticks))
 where
  subst = extendIdSubst (mkSubst is0) v arg

-- Term argument applied to a let-expression: move the application into the
-- body of the let.
appProp _ (App (collectTicks -> (Letrec binds e, ticks)) arg) =
  changed (Letrec binds (App (mkTicks e ticks) arg))

-- Term argument applied to a case-expression: push it into every alternative,
-- let-binding it first when it is neither work-free nor a variable.
appProp ctx@(TransformContext is0 _)
        (App (collectTicks -> (Case scrut ty alts, ticks)) arg) = do
  tcm <- Lens.view tcCache
  let ty'        = applyFunTy tcm ty (termType tcm arg)
      -- The case with 'a' applied inside each alternative, result type updated
      caseWith a = mkTicks (Case scrut ty' [(p, App rhs a) | (p, rhs) <- alts]) ticks
  if isWorkFree arg || isVar arg
    then changed (caseWith arg)
    else do
      -- See Note [AppProp deshadow]: the fresh binder must also avoid every
      -- variable bound by the patterns of the alternatives.
      let patBndrsInScope =
            mkInScopeSet (mkVarSet (concatMap (patVars . fst) alts))
      boundArg <- mkTmBinderFor (unionInScope is0 patBndrsInScope) tcm
                    (mkDerivedName ctx "app_arg") arg
      changed (Letrec [(boundArg, arg)] (caseWith (Var boundArg)))

-- Type argument applied to a type lambda: substitute it.
appProp (TransformContext is0 _) (TyApp (collectTicks -> (TyLam tv e, ticks)) t) =
  changed (mkTicks (substTm "appProp.TyAppTyLam" subst e) ticks)
 where
  subst = extendTvSubst (mkSubst is0) tv t

-- Type argument applied to a let-expression: move it into the body of the let.
appProp _ (TyApp (collectTicks -> (Letrec binds e, ticks)) t) =
  changed (Letrec binds (mkTicks (TyApp e t) ticks))

-- Type argument applied to a case-expression: push it into every alternative.
appProp _ (TyApp (collectTicks -> (Case scrut altsTy alts, ticks)) ty) = do
  tcm <- Lens.view tcCache
  let alts' = [(p, TyApp rhs ty) | (p, rhs) <- alts]
  changed (mkTicks (Case scrut (piResultTy tcm altsTy ty) alts') ticks)

appProp _ e = return e

-- | Unlike 'appProp', which propagates a single argument in an application one
-- level down (and should be called in an innermost traversal), 'appPropFast'
-- tries to propagate as many arguments as possible, down as many levels as
-- possible; and should be called in a top-down traversal.
--
-- The idea is that this reduces the number of traversals, which hopefully leads
-- to shorter compile times.
--
-- Arguments pushed through a lambda are substituted when work-free or a
-- variable and let-bound otherwise. Arguments pushed into a case-expression
-- are let-bound outside the case first, unless they are work-free or a
-- variable.
--
-- Implementation only works if terms are fully deshadowed, see
-- Note [AppProp deshadow]
appPropFast :: HasCallStack => NormRewrite
appPropFast ctx@(TransformContext is _) term = case term of
  App {}   -> propagate
  TyApp {} -> propagate
  _        -> return term
 where
  -- Split the application spine into head, arguments and ticks, and push the
  -- arguments into the head.
  propagate = uncurry3 (go is) (collectArgsTicks term)

  -- Arguments that may be duplicated instead of let-bound.
  cheap arg = isWorkFree arg || isVar arg

  go :: InScopeSet -> Term -> [Either Term Type] -> [TickInfo]
     -> NormalizeSession Term
  -- The head is itself an application: merge it into a single spine.
  go is0 (collectArgsTicks -> (fun, innerArgs@(_:_), innerTicks)) outerArgs outerTicks =
    go is0 fun (innerArgs ++ outerArgs) (innerTicks ++ outerTicks)

  -- Term argument meets a lambda.
  go is0 (Lam v body) (Left arg:rest) ticks = do
    setChanged
    if cheap arg
      then do
        let subst = extendIdSubst (mkSubst is0) v arg
        fmap (\t -> mkTicks t ticks)
             (go is0 (substTm "appPropFast.AppLam" subst body) rest [])
      else
        Letrec [(v, arg)] <$> go (extendInScopeSet is0 v) body rest ticks

  -- Arguments meet a let-expression: continue inside the let body.
  go is0 (Letrec binds body) args@(_:_) ticks = do
    setChanged
    Letrec binds <$> go (extendInScopeSetList is0 (map fst binds)) body args ticks

  -- Type argument meets a type lambda.
  go is0 (TyLam tv body) (Right ty:rest) ticks = do
    setChanged
    let subst = extendTvSubst (mkSubst is0) tv ty
    fmap (\t -> mkTicks t ticks)
         (go is0 (substTm "appPropFast.TyAppTyLam" subst body) rest [])

  -- Arguments meet a case-expression: let-bind the expensive ones outside the
  -- case, then push them all into every alternative.
  go is0 (Case scrut ty0 alts) args@(_:_) ticks = do
    setChanged
    -- See Note [AppProp deshadow]: new binders must avoid pattern variables
    let isAlts = unionInScope is0
                   (mkInScopeSet (mkVarSet (concatMap (patVars . fst) alts)))
    (ty1, newBinds, args') <- goCaseArg isAlts ty0 [] args
    let (wrap, isBody) = case newBinds of
          [] -> (id, is0)
          _  -> (Letrec newBinds, extendInScopeSetList is0 (map fst newBinds))
    wrap . (\t -> mkTicks t ticks) . Case scrut ty1
      <$> mapM (goAlt isBody args') alts

  -- Tick around the head: remember it and keep going.
  go is0 (Tick sp body) args ticks = do
    setChanged
    go is0 body args (sp:ticks)

  -- Nothing left to propagate into: rebuild the application.
  go _ fun args ticks = return (mkApps (mkTicks fun ticks) args)

  -- Push the arguments into one alternative, with the binders of its pattern
  -- added to the in-scope set.
  goAlt is0 args (pat, rhs) =
    let (tvs, ids) = patIds pat
        isAlt      = extendInScopeSetList (extendInScopeSetList is0 tvs) ids
    in  (,) pat <$> go isAlt rhs args []

  -- Walk the arguments destined for a case-expression. It returns three things:
  -- the case type after applying all of them, the accumulated let-bindings for
  -- the arguments that are not 'cheap', and the argument list to push inward,
  -- with those arguments replaced by their binders.
  goCaseArg _ ty ls [] = return (ty, ls, [])

  goCaseArg isA ty0 ls0 (Right t:rest) = do
    tcm <- Lens.view tcCache
    (ty2, ls1, rest') <- goCaseArg isA (piResultTy tcm ty0 t) ls0 rest
    return (ty2, ls1, Right t : rest')

  goCaseArg isA0 ty0 ls0 (Left arg:rest) = do
    tcm <- Lens.view tcCache
    let ty1 = applyFunTy tcm ty0 (termType tcm arg)
    if cheap arg
      then do
        (ty2, ls1, rest') <- goCaseArg isA0 ty1 ls0 rest
        return (ty2, ls1, Left arg : rest')
      else do
        boundArg <- mkTmBinderFor isA0 tcm (mkDerivedName ctx "app_arg") arg
        (ty2, ls1, rest') <- goCaseArg (extendInScopeSet isA0 boundArg) ty1 ls0 rest
        return (ty2, (boundArg, arg) : ls1, Left (Var boundArg) : rest')

-- | Flatten ridiculous case-statements generated by GHC
--
-- For case-statements in haskell of the form:
--
-- @
-- f :: Unsigned 4 -> Unsigned 4
-- f x = case x of
--   0 -> 3
--   1 -> 2
--   2 -> 1
--   3 -> 0
-- @
--
-- GHC generates Core that looks like:
--
-- @
-- f = \(x :: Unsigned 4) -> case x == fromInteger 3 of
--                             False -> case x == fromInteger 2 of
--                               False -> case x == fromInteger 1 of
--                                 False -> case x == fromInteger 0 of
--                                   False -> error "incomplete case"
--                                   True  -> fromInteger 3
--                                 True -> fromInteger 2
--                               True -> fromInteger 1
--                             True -> fromInteger 0
-- @
--
-- Which would result in a priority decoder circuit where a normal decoder
-- circuit was desired.
--
-- This transformation transforms the above Core to the saner:
--
-- @
-- f = \(x :: Unsigned 4) -> case x of
--        _ -> error "incomplete case"
--        0 -> fromInteger 3
--        1 -> fromInteger 2
--        2 -> fromInteger 1
--        3 -> fromInteger 0
-- @
caseFlat :: HasCallStack => NormRewrite
caseFlat _ e@(Case (collectEqArgs -> Just (scrut',_)) ty _)
  = case collectFlat scrut' e of
      -- 'collectFlat' lists its 'DefaultPat' alternative last; rotate it to
      -- the front to obtain the form shown above. Matching on the reversed
      -- list avoids the partial 'last'/'init'.
      Just alts'
        | (dflt:revRest) <- reverse alts'
        -> changed (Case scrut' ty (dflt : reverse revRest))
      _ -> return e

caseFlat _ e = return e

-- | Try to flatten a chain of nested two-way cases on @scrut == value@ (as
-- recognised by 'collectEqArgs') into a flat list of literal alternatives.
-- Every case in the chain must test the same @scrut@.
--
-- Each level contributes a @'LitPat' lit -> <branch taken when equal>@
-- alternative. The innermost remaining branch becomes a trailing 'DefaultPat'
-- alternative.
--
-- Returns 'Nothing' in two situations. The first is when the given term is not
-- such a case. The second is when the compared value is not a 'Literal' passed
-- as the last argument of a primitive accepted by 'isFromInt', or of
-- @GHC.Types.I#@.
collectFlat :: Term -> Term -> Maybe [(Pat,Term)]
collectFlat scrut (Case (collectEqArgs -> Just (scrut', val)) _ty [lAlt,rAlt])
  | scrut' == scrut
  = case collectArgs val of
      (Prim nm' _,args') | isFromInt nm' ->
        lastArg args' >>= go
      (Data dc,args')    | nameOcc (dcName dc) == "GHC.Types.I#" ->
        lastArg args' >>= go
      _ -> Nothing
  where
    -- Total version of 'last': a head applied to no arguments yields
    -- 'Nothing' instead of crashing.
    lastArg xs = case reverse xs of
      (x:_) -> Just x
      []    -> Nothing

    go (Left (Literal i)) = case (lAlt,rAlt) of
      ((pl,el),(pr,er))
        -- 'er' is the branch taken on equality, 'el' the other (or default) one
        | isFalseDcPat pl || isTrueDcPat pr ->
           case collectFlat scrut el of
             Just alts' -> Just ((LitPat i, er) : alts')
             Nothing    -> Just [(LitPat i, er)
                                ,(DefaultPat, el)
                                ]
        -- 'el' is the branch taken on equality, 'er' the other (or default) one
        | otherwise ->
           case collectFlat scrut er of
             Just alts' -> Just ((LitPat i, el) : alts')
             Nothing    -> Just [(LitPat i, el)
                                ,(DefaultPat, er)
                                ]
    go _ = Nothing

    isFalseDcPat (DataPat p _ _)
      = ((== "GHC.Types.False") . nameOcc . dcName) p
    isFalseDcPat _ = False

    isTrueDcPat (DataPat p _ _)
      = ((== "GHC.Types.True") . nameOcc . dcName) p
    isTrueDcPat _ = False

collectFlat _ _ = Nothing

-- | Recognise an application of one of the equality primitives
-- @Clash.Sized.Internal.{BitVector,Index,Signed,Unsigned}.eq#@ or
-- @Clash.Transformations.eqInt@, and return its last two (term) arguments.
-- The first of these is re-wrapped in any ticks that surrounded the primitive
-- application.
--
-- Returns 'Nothing' for any other term. It also returns 'Nothing' when one of
-- these primitives is applied to an unexpected number or kind of arguments,
-- instead of failing on an irrefutable pattern.
collectEqArgs :: Term -> Maybe (Term,Term)
collectEqArgs (collectArgsTicks -> (Prim nm _, args, ticks))
  | nm == "Clash.Sized.Internal.BitVector.eq#"
  , [_,_,Left scrut,Left val] <- args
  = Just (mkTicks scrut ticks,val)
  | nm == "Clash.Sized.Internal.Index.eq#"  ||
    nm == "Clash.Sized.Internal.Signed.eq#" ||
    nm == "Clash.Sized.Internal.Unsigned.eq#"
  , [_,Left scrut,Left val] <- args
  = Just (mkTicks scrut ticks,val)
  | nm == "Clash.Transformations.eqInt"
  , [Left scrut,Left val] <- args
  = Just (mkTicks scrut ticks,val)
collectEqArgs _ = Nothing

-- | The rewrite type of 'collectANF': a 'Transform' running in a 'StateT'
-- layer over the normalization 'RewriteMonad'. The state holds the
-- let-bindings collected so far, together with the 'InScopeSet' against
-- which fresh binders are created (see Note [ANF InScopeSet]).
type NormRewriteW = Transform (StateT ([LetBinding],InScopeSet) (RewriteMonad NormalizeState))

-- | Record new let-bindings in the ANF state.
--
-- The bindings are prepended to the bindings collected so far, and their
-- binders are added to the 'InScopeSet' kept in the same state, so that they
-- are taken into account when later binders are created.
--
-- See Note [ANF InScopeSet]
tellBinders :: Monad m => [LetBinding] -> StateT ([LetBinding],InScopeSet) m ()
tellBinders newBinds = modify addBinds
 where
  -- Lazy pattern: only touch the state components when they are demanded.
  addBinds ~(collected,inScope) =
    ( newBinds ++ collected
    , inScope `extendInScopeSetList` map fst newBinds
    )

-- | Turn an expression into a modified ANF-form. As opposed to standard ANF,
-- constants do not become let-bound.
--
-- The rewrite walks under lambdas, adding each lambda binder to the
-- 'InScopeSet', and leaves type-lambdas untouched. For any other expression,
-- the binders are freshened and 'collectANF' is run bottom-up. All bindings it
-- collects are then placed in a single 'Letrec' around the result. When
-- nothing was collected, the original (unfreshened) expression is returned.
makeANF :: HasCallStack => NormRewrite
makeANF (TransformContext is0 ctx) (Lam bndr e) =
  Lam bndr <$> makeANF bodyCtx e
 where
  bodyCtx = TransformContext (extendInScopeSet is0 bndr) (LamBody bndr:ctx)

makeANF _ tyLam@TyLam{} = return tyLam

makeANF ctx@(TransformContext is0 _) e0 = do
  -- We need to freshen all binders in `e` because we're shuffling them around
  -- into a single let-binder, because even when binders don't shadow, they
  -- don't have to be unique within an expression. And so lifting them all
  -- to a single let-binder will cause issues when they're not unique.
  --
  -- We cannot make freshening part of collectANF, because when we generate
  -- new binders, we need to make sure those names do not conflict with _any_
  -- of the existing binders in the expression.
  --
  -- See also Note [ANF InScopeSet]
  let (is2,e1) = freshenTm is0 e0
  (e2,(bndrs,_)) <- runStateT (bottomupR collectANF ctx e1) ([],is2)
  if null bndrs
    then return e0
    else
      -- Ensure that `AppendName` ticks still scope over the entire expression
      let (body,ticks)         = collectTicks e2
          (srcTicks,nameTicks) = partitionTicks ticks
      in  changed (mkTicks (Letrec bndrs (mkTicks body srcTicks)) nameTicks)

-- | Note [ANF InScopeSet]
--
-- The InScopeSet contains:
--
--    1. All the free variables of the expression we are traversing
--
--    2. All the bound variables of the expression we are traversing
--
--    3. The newly created let-bindings as we recurse back up the traversal
--
-- All of these are needed to create let-bindings that
--
--    * Do not shadow
--    * Are not shadowed
--    * Do not conflict with each other (i.e. do not have the same unique)
--
-- Initially we start with the local InScopeSet and add the global variables:
--
-- @
-- is1 <- unionInScope is0 <$> Lens.use globalInScope
-- @
--
-- Which gives us (a superset of) the free variables of the expression. Then
-- we call 'freshenTm'
--
-- @
-- let (is2,e1) = freshenTm is1 e0
-- @
--
-- Which extends the InScopeSet with all the bound variables in 'e1', the
-- version of 'e0' where all binders are unique (not just deshadowed).
--
-- So we start out with an InScopeSet that satisfies points 1 and 2; from then
-- on, every time we create a new binder we must add it to the InScopeSet to
-- satisfy point 3.
--
-- 'collectANF' is the per-node step of 'makeANF', applied bottom-up. Every
-- let-binding it introduces is recorded in the 'StateT' state through
-- 'tellBinders'. Every fresh binder is created against the 'InScopeSet' held
-- in that same state.
collectANF :: HasCallStack => NormRewriteW
-- An application whose head is a constructor, primitive or variable. Its last
-- argument is let-bound to a fresh "app_arg" binder unless it is a local
-- variable, is accepted by 'isConstantNotClockReset', or is untranslatable.
-- An untranslatable argument that is a let-expression (and is neither of the
-- former two) has its bindings hoisted, and only its body stays in place.
collectANF ctx e@(App appf arg)
  | (headTm, _) <- collectArgs e
  , isCon headTm || isPrim headTm || isVar headTm
  = do
    untranslatable <- lift (isUntranslatable False arg)
    let argIsLocal = isLocalVar arg
    constNoCR <- lift (isConstantNotClockReset arg)
    let trivialArg = argIsLocal || constNoCR
    case arg of
      _ | not untranslatable && not trivialArg -> do
        tcm <- Lens.view tcCache
        -- See Note [ANF InScopeSet]
        inScope <- Lens.use _2
        argId <- lift (mkTmBinderFor inScope tcm (mkDerivedName ctx "app_arg") arg)
        -- See Note [ANF InScopeSet]
        tellBinders [(argId,arg)]
        return (App appf (Var argId))
      Letrec binds body | untranslatable && not trivialArg -> do
        tellBinders binds
        return (App appf body)
      _ -> return e

-- A let-expression: its bindings are always moved into the collected
-- bindings. The body stays in place when it is a local variable or
-- untranslatable. Otherwise the body is let-bound to a fresh "result" binder
-- and replaced by a reference to it.
collectANF _ (Letrec binds body) = do
  tellBinders binds
  bodyUntranslatable <- lift (isUntranslatable False body)
  let keepBody = isLocalVar body || bodyUntranslatable
  if keepBody
    then return body
    else do
      tcm <- Lens.view tcCache
      -- See Note [ANF InScopeSet]
      inScope <- Lens.use _2
      resId <- lift (mkTmBinderFor inScope tcm (mkUnsafeSystemName "result" 0) body)
      -- See Note [ANF InScopeSet]
      tellBinders [(resId,body)]
      return (Var resId)

-- TODO: The code below special-cases ANF for the ':-' constructor for the
-- 'Signal' type. The 'Signal' type is essentially treated as a "transparent"
-- type by the Clash compiler, so observing its constructor leads to all kinds
-- of problems. In this case, "Clash.Rewrite.Util.mkSelectorCase" will try to
-- project the LHS and RHS of the ':-' constructor; however, 'mkSelectorCase'
-- uses 'coreView1' to find the "real" data-constructor. 'coreView1' looks
-- through the 'Signal' type, and hence 'mkSelectorCase' finds the data
-- constructors for the element type of Signal. This resulted in error #24
-- (https://github.com/christiaanb/clash2/issues/24), where we tried to get the
-- first field out of the 'Nil' constructor of 'Vec'.
--
-- Ultimately we should stop treating Signal as a "transparent" type and
-- handle the Signal type, and the co-recursive functions involved, properly.
-- At the moment, Clash cannot deal with this recursive type and the recursive
-- functions involved, hence the need for special-casing code. After
-- everything is done properly, we should remove the equation below.
-- A case-expression with a single alternative that matches the ':-'
-- constructor of 'Signal' is returned unchanged (see the TODO comment
-- preceding this equation).
collectANF _ e@(Case _ _ [(DataPat dc _ _,_)])
  | isSignalCons dc = return e
 where
  isSignalCons = (== "Clash.Signal.Internal.:-") . nameOcc . dcName

-- A case-expression. A scrutinee that is neither a local variable nor a
-- constant is let-bound to a fresh "case_scrut" binder, after which every
-- alternative is processed by 'doAlt'. Suppose the result has exactly one
-- constructor alternative without existential type variables, and its pattern
-- variables do not occur in its expression. Then that expression is returned
-- instead of the case-expression.
collectANF ctx (Case subj ty alts) = do
    subj' <-
      if isLocalVar subj || isConstant subj
        then return subj
        else do
          tcm <- Lens.view tcCache
          -- See Note [ANF InScopeSet]
          inScope <- Lens.use _2
          scrutId <- lift (mkTmBinderFor inScope tcm (mkDerivedName ctx "case_scrut") subj)
          -- See Note [ANF InScopeSet]
          tellBinders [(scrutId,subj)]
          return (Var scrutId)

    alts' <- traverse (doAlt subj') alts

    return $ case alts' of
      [(DataPat _ [] xs,altExpr)]
        | xs `localIdsDoNotOccurIn` altExpr
        -> altExpr
      _ -> Case subj' ty alts'
  where
    -- A constructor alternative that binds no existentials always has
    -- selector bindings collected for its pattern variables. The alternative
    -- is kept as-is in two cases: its expression is a local variable that is
    -- not one of the pattern variables (or belongs to the only alternative),
    -- or its expression is a constant. Otherwise the expression is let-bound
    -- to a fresh "case_alt" binder.
    doAlt
      :: Term -> (Pat,Term)
      -> StateT ([LetBinding],InScopeSet) (RewriteMonad NormalizeState)
                (Pat,Term)
    doAlt subj' alt@(DataPat dc exts xs,altExpr)
      | not (bindsExistentials exts xs) = do
          patSels <- Monad.zipWithM (doPatBndr subj' dc) xs [0..]
          let isPatVar = case altExpr of
                Var n -> any (== n) xs
                _     -> False
              keepAlt = (isLocalVar altExpr && (not isPatVar || length alts == 1))
                     || isConstant altExpr
          if keepAlt
            then do
              -- See Note [ANF InScopeSet]
              tellBinders patSels
              return alt
            else do
              tcm <- Lens.view tcCache
              -- See Note [ANF InScopeSet]
              inScope <- Lens.use _2
              altId <- lift (mkTmBinderFor inScope tcm (mkDerivedName ctx "case_alt") altExpr)
              -- See Note [ANF InScopeSet]
              tellBinders ((altId,altExpr):patSels)
              return (DataPat dc exts xs,Var altId)
    -- Constructor alternatives that bind existentials are left untouched.
    doAlt _ alt@(DataPat {}, _) = return alt
    -- Any other alternative: its expression is let-bound to a fresh
    -- "case_alt" binder unless it is a local variable or a constant.
    doAlt _ alt@(pat,altExpr)
      | isLocalVar altExpr || isConstant altExpr = return alt
      | otherwise = do
          tcm <- Lens.view tcCache
          -- See Note [ANF InScopeSet]
          inScope <- Lens.use _2
          altId <- lift (mkTmBinderFor inScope tcm (mkDerivedName ctx "case_alt") altExpr)
          tellBinders [(altId,altExpr)]
          return (pat,Var altId)

    -- Bind pattern variable 'pId' (field 'i' of constructor 'dc') to a
    -- selector expression on the (possibly let-bound) scrutinee.
    doPatBndr
      :: Term -> DataCon -> Id -> Int
      -> StateT ([LetBinding],InScopeSet) (RewriteMonad NormalizeState)
                LetBinding
    doPatBndr subj' dc pId i = do
      tcm <- Lens.view tcCache
      -- See Note [ANF InScopeSet]
      inScope <- Lens.use _2
      -- No need to 'tellBinders' here because 'pId' is already in the ANF
      -- InScopeSet.
      --
      -- See also Note [ANF InScopeSet]
      (pId,) <$> lift (mkSelectorCase ($(curLoc) ++ "doPatBndr") inScope tcm subj' (dcTag dc) i)

-- Every other expression is returned unchanged.
collectANF _ other = return other

-- | Eta-expand top-level lambda's (DON'T use in a traversal!)
--
-- Existing lambdas are walked under. For a let-expression, the body is
-- eta-expanded, and any lambdas that then lead the body are moved out so
-- that they scope over the whole let-expression. Any other expression for
-- which 'isFun' holds is applied to a fresh "arg" binder and wrapped in a
-- lambda over it. This repeats until 'isFun' no longer holds.
etaExpansionTL :: HasCallStack => NormRewrite
etaExpansionTL (TransformContext is0 ctx) (Lam bndr e) =
  Lam bndr <$> etaExpansionTL bodyCtx e
 where
  bodyCtx = TransformContext (extendInScopeSet is0 bndr) (LamBody bndr:ctx)

etaExpansionTL (TransformContext is0 ctx) (Letrec xes e) = do
  e' <- etaExpansionTL bodyCtx e
  case stripLambda e' of
    ([], _) ->
      return (Letrec xes e')
    (bs, e2) ->
      changed (mkLams (Letrec xes e2) bs)
 where
  bndrs   = map fst xes
  bodyCtx = TransformContext (extendInScopeSetList is0 bndrs) (LetBody bndrs:ctx)

  -- Split off the leading lambda binders of a term, outermost first.
  stripLambda :: Term -> ([Id],Term)
  stripLambda = go []
   where
    go acc (Lam b body) = go (b:acc) body
    go acc body         = (reverse acc, body)

etaExpansionTL (TransformContext is0 ctx) e = do
  tcm <- Lens.view tcCache
  if not (isFun tcm e)
    then return e
    else do
      let argTy = maybe (error $ $(curLoc) ++ "etaExpansion splitFunTy") fst
                        (splitFunTy tcm (termType tcm e))
      newId <- mkInternalVar is0 "arg" argTy
      e' <- etaExpansionTL
              (TransformContext (extendInScopeSet is0 newId) (LamBody newId:ctx))
              (App e (Var newId))
      changed (Lam newId e')

-- | Eta-expand functions with a Synthesize annotation, needed to allow such
-- functions to appear as arguments to higher-order primitives.
--
-- An application headed by a variable that is one of the 'topEntities' is
-- wrapped in one lambda over a fresh "arg" binder. This happens only when its
-- type is a function type and it is not itself in function position of an
-- application (ticks are looked through when checking this). Every other term
-- is returned unchanged.
etaExpandSyn :: HasCallStack => NormRewrite
etaExpandSyn (TransformContext is0 ctx) e@(collectArgs -> (Var f, _)) = do
  topEnts <- Lens.view topEntities
  tcm <- Lens.view tcCache
  case fst <$> splitFunTy tcm (termType tcm e) of
    Just argTy
      | f `elemVarSet` topEnts
      , not (inAppFunPosition ctx)
      -> do
        newId <- mkInternalVar is0 "arg" argTy
        changed (Lam newId (App e (Var newId)))
    _ -> return e
 where
  inAppFunPosition (AppFun:_)     = True
  inAppFunPosition (TickC _:rest) = inAppFunPosition rest
  inAppFunPosition _              = False

etaExpandSyn _ e = return e

-- | Decide whether a type is a class constraint. This holds for a type
-- constructor application when either:
--
--   * the constructor's name contains @GHC.Classes.(%@ (a constraint tuple), or
--   * the part of the name after its last @.@ contains @C:@ (a class).
--
-- Every other type is not a class constraint.
isClassConstraint :: Type -> Bool
isClassConstraint (tyView -> TyConApp tcNm _) =
  isConstraintTuple || isConstraintClass
 where
  occ               = nameOcc tcNm
  isConstraintTuple = "GHC.Classes.(%" `Text.isInfixOf` occ
  isConstraintClass = "C:" `Text.isInfixOf` snd (Text.breakOnEnd "." occ)

isClassConstraint _ = False


-- | Turn a normalized recursive function, where the recursive calls only pass
-- along the unchanged original arguments, into a let-recursive function. This
-- means that all recursive calls are replaced by the same variable reference as
-- found in the body of the top-level let-expression.
--
-- The rewrite only fires when:
--
--   * it is applied with an empty 'Context';
--
--   * the term has the shape accepted by 'splitNormalized';
--
--   * the top-level let contains at least one binding that is such a recursive
--     call, and at least one binding that is not; and
--
--   * the variable returned by the let-body is not itself bound to a recursive
--     call. Dropping that binding would leave the returned variable unbound and
--     produce a term with a free variable, so such terms are left untouched.
recToLetRec :: HasCallStack => NormRewrite
recToLetRec (TransformContext is0 []) e = do
  (fn,_) <- Lens.use curFun
  tcm    <- Lens.view tcCache
  case splitNormalized tcm e of
    Right (args,bndrs,res) -> do
      let args'             = map Var args
          -- Bindings whose right-hand side calls the current function with
          -- (semantically) the original arguments end up in 'toInline'
          (toInline,others) = List.partition (eqApp tcm fn args' . snd) bndrs
          resV              = Var res
          -- 'toInline' bindings are removed below, so the let-body variable
          -- must not be one of them, otherwise it would become unbound
          resStaysBound     = res `notElem` map fst toInline
      case (toInline,others) of
        (_:_,_:_) | resStaysBound -> do
          let is1          = extendInScopeSetList is0 (args ++ map fst bndrs)
          -- Every recursive-call binder is replaced by a reference to the
          -- let-body variable
          let substsInline = extendIdSubstList (mkSubst is1)
                           $ map (second (const resV)) toInline
              others'      = map (second (substTm "recToLetRec" substsInline))
                                 others
          changed $ mkLams (Letrec others' resV) args
        _ -> return e
    _ -> return e
  where
    -- This checks whether things are semantically equal. For example, say we
    -- have:
    --
    --   x :: (a, (b, c))
    --
    -- and
    --
    --   y :: (a, (b, c))
    --
    -- If we can determine that 'y' is constructed solely using the
    -- corresponding fields in 'x', then we can say they are semantically
    -- equal. The algorithm below keeps track of what (sub)field it is
    -- constructing, and checks if the field-expression projects the
    -- corresponding (sub)field from the target variable.
    --
    -- 'eqApp' holds when the term applies the global variable @v@ to exactly as
    -- many term arguments as @args@ (type arguments are ignored), each of
    -- which is equal to the corresponding element of @args@ per 'eqArg'.
    --
    -- TODO: See [Note: Breaks on constants and predetermined equality]
    eqApp :: TyConMap -> Id -> [Term] -> Term -> Bool
    eqApp tcm v args (collectArgs -> (Var v',args'))
      | isGlobalId v'
      , v == v'
      , let args2 = Either.lefts args'
      , length args == length args2
      = and (zipWith (eqArg tcm) args args2)
    eqApp _ _ _ _ = False

    -- Compare an original argument with the argument passed in a call:
    -- variables must be syntactically equal; a data-constructor application
    -- of the same type must be built out of projections of the original.
    eqArg :: TyConMap -> Term -> Term -> Bool
    eqArg _ v1 v2@(Var {})
      = v1 == v2
    eqArg tcm v1 v2@(collectArgs -> (Data _, args'))
      | let t1 = termType tcm v1
      , let t2 = termType tcm v2
      , t1 == t2
      = if isClassConstraint t1 then
          -- Class constraints are equal if their types are equal, so we can
          -- take a shortcut here.
          True
        else
          -- Check whether all arguments to the data constructor are projections
          --
          and (zipWith (eqDat v1) (map pure [0..]) (Either.lefts args'))
    eqArg _ _ _
      = False

    -- Recursively check whether a term /e/ is semantically equal to some variable /v/.
    -- Currently it can only assert equality when /e/ is syntactically equal
    -- to /v/, or is constructed out of projections of /v/, importantly:
    --
    -- [Note: Breaks on constants and predetermined equality]
    -- This function currently breaks if:
    --
    --   * One or more subfields are constants. Constants might have been
    --     inlined for the construction, instead of being a projection of the
    --     target variable.
    --
    --   * One or more subfields are determined to be equal and one is simply
    --     swapped / replaced by the other. For example, say we have
    --     `x :: (a, a)`. If GHC determines that both elements of the tuple will
    --     always be the same, it might replace the (semantically equal to 'x')
    --     construction of `y` with `(fst x, fst x)`.
    --
    -- The @[Int]@ argument is the field path built so far, innermost field
    -- first; it is reversed before being handed to 'stripProjection'.
    eqDat :: Term -> [Int] -> Term -> Bool
    eqDat v fTrace (collectArgs -> (Data _, args)) =
      and (zipWith (eqDat v) (map (:fTrace) [0..]) (Either.lefts args))
    eqDat v1 fTrace v2 =
      case stripProjection (reverse fTrace) v1 v2 of
        Just [] -> True
        _ -> False

    -- Follow a chain of single-alternative case projections in the third
    -- argument, consuming one field index from the path per projection.
    -- Returns the remaining path when the chain bottoms out at the target
    -- variable, and 'Nothing' otherwise.
    stripProjection :: [Int] -> Term -> Term -> Maybe [Int]
    stripProjection fTrace0 vTarget0 (Case v _ [(DataPat _ _ xs, r)]) = do
      -- Get projection made in subject of case:
      fTrace1 <- stripProjection fTrace0 vTarget0 v

      -- Extract projection of this case statement. Subsequent calls to
      -- 'stripProjection' will check if new target is actually used.
      n <- headMaybe fTrace1
      vTarget1 <- indexMaybe xs n
      fTrace2 <- tailMaybe fTrace1

      stripProjection fTrace2 (Var vTarget1) r

    stripProjection fTrace (Var sTarget) (Var s) =
      if sTarget == s then Just fTrace else Nothing

    stripProjection _fTrace _vTarget _v =
      Nothing

recToLetRec _ e = return e

-- | Inline a function with functional arguments
--
-- Fires on an application whose head is a variable @f@ where at least one term
-- argument satisfies 'isPolyFun'. The body of @f@ is taken from 'bindings'.
-- Inlining is refused, with a debug trace, when the count returned by
-- 'alreadyInlined' for @f@ in the current function exceeds 'inlineLimit'.
-- Terms of any other shape are returned unchanged.
inlineHO :: HasCallStack => NormRewrite
inlineHO (TransformContext is0 _) e@(App _ _)
  | (Var f, args, ticks) <- collectArgsTicks e
  = do
    tcm <- Lens.view tcCache
    let isPolyFunArg = either (isPolyFun tcm) (const False)
    if not (any isPolyFunArg args)
      then return e
      else do
        (cf,_)  <- Lens.use curFun
        inlined <- zoomExtra (alreadyInlined f cf)
        limit   <- Lens.use (extra.inlineLimit)
        let timesInlined = Maybe.fromMaybe 0 inlined
        if timesInlined <= limit
          then do
            fBinding <- lookupVarEnv f <$> Lens.use bindings
            case fBinding of
              Nothing -> return e
              Just (_,_,_,body) -> do
                zoomExtra (addNewInline f cf)
                -- See Note [AppProp no-shadow invariant]
                let fresh = mkTicks (deShadowTerm is0 body) ticks
                changed (mkApps fresh args)
          else do
            lvl <- Lens.view dbgLevel
            let msg = $(curLoc) ++ "InlineHO: " ++ show f ++ " already inlined "
                        ++ show limit ++ " times in:" ++ show cf
            traceIf (lvl > DebugNone) msg (return e)

inlineHO _ e = return e

-- | Simplified CSE, only works on let-bindings, works from top to bottom
--
-- Let-bindings with syntactically equal right-hand sides are merged by
-- 'reduceBindersFix'. A change is reported only when the number of bindings
-- went down; otherwise the original term is returned as-is.
simpleCSE :: HasCallStack => NormRewrite
simpleCSE (TransformContext is0 _) e@(Letrec binders body) =
  let is1                     = extendInScopeSetList is0 (map fst binders)
      (reducedBindings,body') = reduceBindersFix is1 binders body
      shrunk                  = length reducedBindings /= length binders
  in  if shrunk
        then changed (Letrec reducedBindings body')
        else return e

simpleCSE _ e = return e

-- | Run 'reduceBinders' repeatedly until a pass no longer removes any binding.
-- The bindings and body passed to that final, unproductive pass are returned.
reduceBindersFix
  :: InScopeSet
  -> [LetBinding]
  -> Term
  -> ([LetBinding],Term)
reduceBindersFix is binders body
  | length reduced == length binders = (binders,body)
  | otherwise                        = reduceBindersFix is reduced body'
 where
  -- One CSE pass, starting with nothing processed
  (reduced,body') = reduceBinders is [] body binders

-- | A single CSE pass over let-bindings.
--
-- The remaining bindings (last argument) are walked front to back. Kept
-- bindings are accumulated in the second argument, most recently kept first.
-- When a binding's right-hand side is equal ('==') to that of an already kept
-- binding @id2@, the new binding is dropped. Its binder is then substituted by
-- @id2@ in the kept bindings, in the remaining bindings and in the body.
reduceBinders
  :: InScopeSet
  -> [LetBinding]
  -> Term
  -> [LetBinding]
  -> ([LetBinding],Term)
reduceBinders _  processed body [] = (processed,body)
reduceBinders is processed body ((id_,expr):binders) =
  maybe keep dedup (List.find ((== expr) . snd) processed)
 where
  -- No kept binding computes the same expression: keep this one
  keep = reduceBinders is ((id_,expr):processed) body binders

  -- A kept binding @id2@ computes the same expression: replace @id_@ by @id2@
  -- everywhere and drop this binding
  dedup (id2,_) =
    let subst = extendIdSubst (mkSubst is) id_ (Var id2)
    in  reduceBinders is
          (map (second (substTm "reduceBinders.processed" subst)) processed)
          (substTm "reduceBinders.body" subst body)
          (map (second (substTm "reduceBinders.binders"   subst)) binders)

-- | Try to evaluate an application of a primitive with the primitive
-- 'evaluator' (via 'whnf'').
--
-- A fresh split of the unique supply is used for evaluation. The evaluator's
-- updated global heap is stored back, and its pure heap is bound around the
-- result with 'bindPureHeap'. If the result is still an application of the
-- same primitive, the original term is returned without reporting a change.
-- Otherwise the evaluated term is reported as 'changed'.
reduceConst :: HasCallStack => NormRewrite
reduceConst ctx@(TransformContext is0 _) e@(App _ _)
  | (Prim nm0 _, _) <- collectArgs e
  = do
    tcm      <- Lens.view tcCache
    bndrs    <- Lens.use bindings
    primEval <- Lens.view evaluator

    -- Keep one half of the supply for the evaluator, store the other
    supply <- Lens.use uniqSupply
    let (evalIds,restIds) = splitSupply supply
    uniqSupply Lens..= restIds

    heap <- Lens.use globalHeap
    -- The bang forces the result tuple, just like a 'case' would
    let !(heap',pureHeap,reduced) =
          whnf' primEval bndrs tcm heap evalIds is0 False e
    globalHeap Lens..= heap'

    bindPureHeap ctx tcm pureHeap $ \_ctx' ->
      case collectArgs reduced of
        (Prim nm1 _, _) | nm0 == nm1 -> return e
        _                            -> changed reduced

reduceConst _ e = return e

-- | Replace primitives by their "definition" if they would lead to let-bindings
-- with a non-representable type when a function is in ANF. This happens for
-- example when Clash.Sized.Vector.map consumes or produces a vector of
-- non-representable elements.
--
-- Basically, this transformation replaces a primitive by the completely
-- unrolled recursive definition that it represents, e.g.
--
-- > zipWith ($) (xs :: Vec 2 (Int -> Int)) (ys :: Vec 2 Int)
--
-- is replaced by:
--
-- > let (x0  :: (Int -> Int))       = case xs  of (:>) _ x xr -> x
-- >     (xr0 :: Vec 1 (Int -> Int)) = case xs  of (:>) _ x xr -> xr
-- >     (x1  :: (Int -> Int))       = case xr0 of (:>) _ x xr -> x
-- >     (y0  :: Int)                = case ys  of (:>) _ y yr -> y
-- >     (yr0 :: Vec 1 Int)          = case ys  of (:>) _ y yr -> yr
-- >     (y1  :: Int)                = case yr0 of (:>) _ y yr -> y
-- > in  (($) x0 y0 :> ($) x1 y1 :> Nil)
--
-- Currently, it only handles the following functions:
--
-- * Clash.Sized.Vector.zipWith
-- * Clash.Sized.Vector.map
-- * Clash.Sized.Vector.traverse#
-- * Clash.Sized.Vector.fold
-- * Clash.Sized.Vector.foldr
-- * Clash.Sized.Vector.dfold
-- * Clash.Sized.Vector.(++)
-- * Clash.Sized.Vector.head
-- * Clash.Sized.Vector.tail
-- * Clash.Sized.Vector.last
-- * Clash.Sized.Vector.init
-- * Clash.Sized.Vector.unconcat
-- * Clash.Sized.Vector.transpose
-- * Clash.Sized.Vector.replicate
-- * Clash.Sized.Vector.replace_int
-- * Clash.Sized.Vector.imap
-- * Clash.Sized.Vector.dtfold
-- * Clash.Sized.RTree.tdfold
-- * Clash.Sized.RTree.treplicate
-- * Clash.Sized.Internal.BitVector.split#
-- * Clash.Sized.Internal.BitVector.eq#
reduceNonRepPrim :: HasCallStack => NormRewrite
reduceNonRepPrim :: NormRewrite
reduceNonRepPrim c :: TransformContext
c@(TransformContext is0 :: InScopeSet
is0 ctx :: Context
ctx) e :: Term
e@(App _ _) | (Prim nm :: Text
nm _, args :: [Either Term Kind]
args, ticks :: [TickInfo]
ticks) <- Term -> (Term, [Either Term Kind], [TickInfo])
collectArgsTicks Term
e = do
  TyConMap
tcm <- Getting TyConMap RewriteEnv TyConMap
-> RewriteMonad NormalizeState TyConMap
forall s (m :: * -> *) a. MonadReader s m => Getting a s a -> m a
Lens.view Getting TyConMap RewriteEnv TyConMap
Lens' RewriteEnv TyConMap
tcCache
  Bool
shouldReduce1 <- Context -> RewriteMonad NormalizeState Bool
shouldReduce Context
ctx
  Bool
ultra <- Getting Bool (RewriteState NormalizeState) Bool
-> RewriteMonad NormalizeState Bool
forall s (m :: * -> *) a. MonadState s m => Getting a s a -> m a
Lens.use ((NormalizeState -> Const Bool NormalizeState)
-> RewriteState NormalizeState
-> Const Bool (RewriteState NormalizeState)
forall extra extra2.
Lens (RewriteState extra) (RewriteState extra2) extra extra2
extra((NormalizeState -> Const Bool NormalizeState)
 -> RewriteState NormalizeState
 -> Const Bool (RewriteState NormalizeState))
-> ((Bool -> Const Bool Bool)
    -> NormalizeState -> Const Bool NormalizeState)
-> Getting Bool (RewriteState NormalizeState) Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(Bool -> Const Bool Bool)
-> NormalizeState -> Const Bool NormalizeState
Lens' NormalizeState Bool
normalizeUltra)
  let eTy :: Kind
eTy = TyConMap -> Term -> Kind
termType TyConMap
tcm Term
e
  case Kind -> TypeView
tyView Kind
eTy of
    (TyConApp vecTcNm :: TyConName
vecTcNm@(TyConName -> Text
forall a. Name a -> Text
nameOcc -> Text
"Clash.Sized.Vector.Vec")
              [Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (Except String Integer -> Either String Integer)
-> (Kind -> Except String Integer) -> Kind -> Either String Integer
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm -> Right 0, aTy :: Kind
aTy]) -> do
      let (Just vecTc :: TyCon
vecTc) = TyConName -> TyConMap -> Maybe TyCon
forall a b. Uniquable a => a -> UniqMap b -> Maybe b
lookupUniqMap TyConName
vecTcNm TyConMap
tcm
          [nilCon :: DataCon
nilCon,consCon :: DataCon
consCon] = TyCon -> [DataCon]
tyConDataCons TyCon
vecTc
          nilE :: Term
nilE = DataCon -> DataCon -> Kind -> Integer -> [Term] -> Term
mkVec DataCon
nilCon DataCon
consCon Kind
aTy 0 []
      Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed (Term -> [TickInfo] -> Term
mkTicks Term
nilE [TickInfo]
ticks)
    tv :: TypeView
tv -> case Text
nm of
      "Clash.Sized.Vector.zipWith" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 7 -> do
        let [lhsElTy :: Kind
lhsElTy,rhsElty :: Kind
rhsElty,resElTy :: Kind
resElTy,nTy :: Kind
nTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            [Bool]
untranslatableTys <- (Kind -> RewriteMonad NormalizeState Bool)
-> [Kind] -> RewriteMonad NormalizeState [Bool]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly [Kind
lhsElTy,Kind
rhsElty,Kind
resElTy]
            if [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or [Bool]
untranslatableTys Bool -> Bool -> Bool
|| Bool
shouldReduce1 Bool -> Bool -> Bool
|| Bool
ultra Bool -> Bool -> Bool
|| Integer
n Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
< 2
               then let [fun :: Term
fun,lhsArg :: Term
lhsArg,rhsArg :: Term
rhsArg] = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
                    in  (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
                        TransformContext
-> Integer
-> Kind
-> Kind
-> Kind
-> Term
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceZipWith TransformContext
c Integer
n Kind
lhsElTy Kind
rhsElty Kind
resElTy Term
fun Term
lhsArg Term
rhsArg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.map" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 5 -> do
        let [argElTy :: Kind
argElTy,resElTy :: Kind
resElTy,nTy :: Kind
nTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            [Bool]
untranslatableTys <- (Kind -> RewriteMonad NormalizeState Bool)
-> [Kind] -> RewriteMonad NormalizeState [Bool]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly [Kind
argElTy,Kind
resElTy]
            if [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or [Bool]
untranslatableTys Bool -> Bool -> Bool
|| Bool
shouldReduce1 Bool -> Bool -> Bool
|| Bool
ultra Bool -> Bool -> Bool
|| Integer
n Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
< 2
               then let [fun :: Term
fun,arg :: Term
arg] = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
                    in  (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TransformContext
-> Integer
-> Kind
-> Kind
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceMap TransformContext
c Integer
n Kind
argElTy Kind
resElTy Term
fun Term
arg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.traverse#" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 7 ->
        let [aTy :: Kind
aTy,fTy :: Kind
fTy,bTy :: Kind
bTy,nTy :: Kind
nTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
        in  case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n ->
            let [dict :: Term
dict,fun :: Term
fun,arg :: Term
arg] = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
            in  (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TransformContext
-> Integer
-> Kind
-> Kind
-> Kind
-> Term
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceTraverse TransformContext
c Integer
n Kind
aTy Kind
fTy Kind
bTy Term
dict Term
fun Term
arg
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.fold" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 4 -> do
        let [aTy :: Kind
aTy,nTy :: Kind
nTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
            isPow2 :: a -> Bool
isPow2 x :: a
x  = a
x a -> a -> Bool
forall a. Eq a => a -> a -> Bool
/= 0 Bool -> Bool -> Bool
&& (a
x a -> a -> a
forall a. Bits a => a -> a -> a
.&. (a -> a
forall a. Bits a => a -> a
complement a
x a -> a -> a
forall a. Num a => a -> a -> a
+ 1)) a -> a -> Bool
forall a. Eq a => a -> a -> Bool
== a
x
        Bool
untranslatableTy <- Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly Kind
aTy
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n | Bool -> Bool
not (Integer -> Bool
forall a. (Num a, Bits a) => a -> Bool
isPow2 (Integer
n Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ 1)) Bool -> Bool -> Bool
|| Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1 Bool -> Bool -> Bool
|| Bool
ultra Bool -> Bool -> Bool
|| Integer
n Integer -> Integer -> Bool
forall a. Eq a => a -> a -> Bool
== 0 ->
            let [fun :: Term
fun,arg :: Term
arg] = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
            in  (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TransformContext
-> Integer
-> Kind
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceFold TransformContext
c (Integer
n Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ 1) Kind
aTy Term
fun Term
arg
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.foldr" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 6 ->
        let [aTy :: Kind
aTy,bTy :: Kind
bTy,nTy :: Kind
nTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
        in  case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            [Bool]
untranslatableTys <- (Kind -> RewriteMonad NormalizeState Bool)
-> [Kind] -> RewriteMonad NormalizeState [Bool]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly [Kind
aTy,Kind
bTy]
            if [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or [Bool]
untranslatableTys Bool -> Bool -> Bool
|| Bool
shouldReduce1 Bool -> Bool -> Bool
|| Bool
ultra
              then let [fun :: Term
fun,start :: Term
start,arg :: Term
arg] = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
                   in  (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TransformContext
-> Integer
-> Kind
-> Term
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceFoldr TransformContext
c Integer
n Kind
aTy Term
fun Term
start Term
arg
              else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.dfold" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 8 ->
        let ([_kn :: Term
_kn,_motive :: Term
_motive,fun :: Term
fun,start :: Term
start,arg :: Term
arg],[_mTy :: Kind
_mTy,nTy :: Kind
nTy,aTy :: Kind
aTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        in  case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer
-> Kind
-> Term
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceDFold InScopeSet
is0 Integer
n Kind
aTy Term
fun Term
start Term
arg
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.++" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 5 ->
        let [nTy :: Kind
nTy,aTy :: Kind
aTy,mTy :: Kind
mTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
            [lArg :: Term
lArg,rArg :: Term
rArg]   = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
        in case (Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy), Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
mTy)) of
              (Right n :: Integer
n, Right m :: Integer
m)
                | Integer
n Integer -> Integer -> Bool
forall a. Eq a => a -> a -> Bool
== 0 -> Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed Term
rArg
                | Integer
m Integer -> Integer -> Bool
forall a. Eq a => a -> a -> Bool
== 0 -> Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed Term
lArg
                | Bool
otherwise -> do
                    Bool
untranslatableTy <- Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly Kind
aTy
                    if Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1
                       then (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer
-> Integer
-> Kind
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceAppend InScopeSet
is0 Integer
n Integer
m Kind
aTy Term
lArg Term
rArg
                       else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
              _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.head" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 3 -> do
        let [nTy :: Kind
nTy,aTy :: Kind
aTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
            [vArg :: Term
vArg]    = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            Bool
untranslatableTy <- Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly Kind
aTy
            if Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1
               then (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer -> Kind -> Term -> RewriteMonad NormalizeState Term
reduceHead InScopeSet
is0 (Integer
nInteger -> Integer -> Integer
forall a. Num a => a -> a -> a
+1) Kind
aTy Term
vArg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.tail" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 3 -> do
        let [nTy :: Kind
nTy,aTy :: Kind
aTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
            [vArg :: Term
vArg]    = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            Bool
untranslatableTy <- Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly Kind
aTy
            if Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1
               then (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer -> Kind -> Term -> RewriteMonad NormalizeState Term
reduceTail InScopeSet
is0 (Integer
nInteger -> Integer -> Integer
forall a. Num a => a -> a -> a
+1) Kind
aTy Term
vArg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.last" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 3 -> do
        let [nTy :: Kind
nTy,aTy :: Kind
aTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
            [vArg :: Term
vArg]    = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            Bool
untranslatableTy <- Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly Kind
aTy
            if Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1
               then (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer -> Kind -> Term -> RewriteMonad NormalizeState Term
reduceLast InScopeSet
is0 (Integer
nInteger -> Integer -> Integer
forall a. Num a => a -> a -> a
+1) Kind
aTy Term
vArg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.init" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 3 -> do
        let [nTy :: Kind
nTy,aTy :: Kind
aTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
            [vArg :: Term
vArg]    = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            Bool
untranslatableTy <- Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly Kind
aTy
            if Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1
               then (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer -> Kind -> Term -> RewriteMonad NormalizeState Term
reduceInit InScopeSet
is0 (Integer
nInteger -> Integer -> Integer
forall a. Num a => a -> a -> a
+1) Kind
aTy Term
vArg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.unconcat" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 6 -> do
        let ([_knN :: Term
_knN,_sm :: Term
_sm,arg :: Term
arg],[mTy :: Kind
mTy,nTy :: Kind
nTy,aTy :: Kind
aTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        case (Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy), Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
mTy)) of
          (Right n :: Integer
n, Right 0) -> (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Integer
-> Integer -> Kind -> Term -> RewriteMonad NormalizeState Term
reduceUnconcat Integer
n 0 Kind
aTy Term
arg
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.transpose" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 5 -> do
        let ([_knN :: Term
_knN,arg :: Term
arg],[mTy :: Kind
mTy,nTy :: Kind
nTy,aTy :: Kind
aTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        case (Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy), Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
mTy)) of
          (Right n :: Integer
n, Right 0) -> (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Integer
-> Integer -> Kind -> Term -> RewriteMonad NormalizeState Term
reduceTranspose Integer
n 0 Kind
aTy Term
arg
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.replicate" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 4 -> do
        let ([_sArg :: Term
_sArg,vArg :: Term
vArg],[nTy :: Kind
nTy,aTy :: Kind
aTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            Bool
untranslatableTy <- Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly Kind
aTy
            if Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1
               then (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Integer -> Kind -> Kind -> Term -> RewriteMonad NormalizeState Term
reduceReplicate Integer
n Kind
aTy Kind
eTy Term
vArg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
       -- replace_int :: KnownNat n => Vec n a -> Int -> a -> Vec n a
      "Clash.Sized.Vector.replace_int" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 6 -> do
        let ([_knArg :: Term
_knArg,vArg :: Term
vArg,iArg :: Term
iArg,aArg :: Term
aArg],[nTy :: Kind
nTy,aTy :: Kind
aTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            Bool
untranslatableTy <- Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly Kind
aTy
            if Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1 Bool -> Bool -> Bool
|| Bool
ultra
               then (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer
-> Kind
-> Kind
-> Term
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceReplace_int InScopeSet
is0 Integer
n Kind
aTy Kind
eTy Term
vArg Term
iArg Term
aArg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e

      "Clash.Sized.Vector.index_int" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 5 -> do
        let ([_knArg :: Term
_knArg,vArg :: Term
vArg,iArg :: Term
iArg],[nTy :: Kind
nTy,aTy :: Kind
aTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            Bool
untranslatableTy <- Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly Kind
aTy
            if Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1 Bool -> Bool -> Bool
|| Bool
ultra
               then (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer
-> Kind
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceIndex_int InScopeSet
is0 Integer
n Kind
aTy Term
vArg Term
iArg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e

      "Clash.Sized.Vector.imap" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 6 -> do
        let [nTy :: Kind
nTy,argElTy :: Kind
argElTy,resElTy :: Kind
resElTy] = [Either Term Kind] -> [Kind]
forall a b. [Either a b] -> [b]
Either.rights [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            [Bool]
untranslatableTys <- (Kind -> RewriteMonad NormalizeState Bool)
-> [Kind] -> RewriteMonad NormalizeState [Bool]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Kind -> RewriteMonad NormalizeState Bool
forall extra. Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly [Kind
argElTy,Kind
resElTy]
            if [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or [Bool]
untranslatableTys Bool -> Bool -> Bool
|| Bool
shouldReduce1 Bool -> Bool -> Bool
|| Bool
ultra Bool -> Bool -> Bool
|| Integer
n Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
< 2
               then let [_,fun :: Term
fun,arg :: Term
arg] = [Either Term Kind] -> [Term]
forall a b. [Either a b] -> [a]
Either.lefts [Either Term Kind]
args
                    in  (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TransformContext
-> Integer
-> Kind
-> Kind
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceImap TransformContext
c Integer
n Kind
argElTy Kind
resElTy Term
fun Term
arg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Vector.dtfold" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 8 ->
        let ([_kn :: Term
_kn,_motive :: Term
_motive,lrFun :: Term
lrFun,brFun :: Term
brFun,arg :: Term
arg],[_mTy :: Kind
_mTy,nTy :: Kind
nTy,aTy :: Kind
aTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        in  case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer
-> Kind
-> Term
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceDTFold InScopeSet
is0 Integer
n Kind
aTy Term
lrFun Term
brFun Term
arg
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e

      "Clash.Sized.Vector.reverse"
        | Bool
ultra
        , ([vArg :: Term
vArg],[nTy :: Kind
nTy,aTy :: Kind
aTy]) <- [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        , Right n :: Integer
n <- Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy)
        -> (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer -> Kind -> Term -> RewriteMonad NormalizeState Term
reduceReverse InScopeSet
is0 Integer
n Kind
aTy Term
vArg

      "Clash.Sized.RTree.tdfold" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 8 ->
        let ([_kn :: Term
_kn,_motive :: Term
_motive,lrFun :: Term
lrFun,brFun :: Term
brFun,arg :: Term
arg],[_mTy :: Kind
_mTy,nTy :: Kind
nTy,aTy :: Kind
aTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        in  case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> InScopeSet
-> Integer
-> Kind
-> Term
-> Term
-> Term
-> RewriteMonad NormalizeState Term
reduceTFold InScopeSet
is0 Integer
n Kind
aTy Term
lrFun Term
brFun Term
arg
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.RTree.treplicate" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 4 -> do
        let ([_sArg :: Term
_sArg,vArg :: Term
vArg],[nTy :: Kind
nTy,aTy :: Kind
aTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        case Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy) of
          Right n :: Integer
n -> do
            Bool
untranslatableTy <- Bool -> Kind -> RewriteMonad NormalizeState Bool
forall extra. Bool -> Kind -> RewriteMonad extra Bool
isUntranslatableType Bool
False Kind
aTy
            if Bool
untranslatableTy Bool -> Bool -> Bool
|| Bool
shouldReduce1
               then (Term -> [TickInfo] -> Term
`mkTicks` [TickInfo]
ticks) (Term -> Term)
-> RewriteMonad NormalizeState Term
-> RewriteMonad NormalizeState Term
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Integer -> Kind -> Kind -> Term -> RewriteMonad NormalizeState Term
reduceTReplicate Integer
n Kind
aTy Kind
eTy Term
vArg
               else Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Internal.BitVector.split#" | [Either Term Kind] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Either Term Kind]
args Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== 4 -> do
        let ([_knArg :: Term
_knArg,bvArg :: Term
bvArg],[nTy :: Kind
nTy,mTy :: Kind
mTy]) = [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        case (Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy), Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
mTy), TypeView
tv) of
          (Right n :: Integer
n, Right m :: Integer
m, TyConApp tupTcNm :: TyConName
tupTcNm [lTy :: Kind
lTy,rTy :: Kind
rTy])
            | Integer
n Integer -> Integer -> Bool
forall a. Eq a => a -> a -> Bool
== 0 -> do
              let (Just tupTc :: TyCon
tupTc) = TyConName -> TyConMap -> Maybe TyCon
forall a b. Uniquable a => a -> UniqMap b -> Maybe b
lookupUniqMap TyConName
tupTcNm TyConMap
tcm
                  [tupDc :: DataCon
tupDc]      = TyCon -> [DataCon]
tyConDataCons TyCon
tupTc
                  tup :: Term
tup          = Term -> [Either Term Kind] -> Term
mkApps (DataCon -> Term
Data DataCon
tupDc)
                                    [Kind -> Either Term Kind
forall a b. b -> Either a b
Right Kind
lTy
                                    ,Kind -> Either Term Kind
forall a b. b -> Either a b
Right Kind
rTy
                                    ,Term -> Either Term Kind
forall a b. a -> Either a b
Left  Term
bvArg
                                    ,Term -> Either Term Kind
forall a b. a -> Either a b
Left  (Kind -> Term
removedTm Kind
rTy)
                                    ]

              Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed (Term -> [TickInfo] -> Term
mkTicks Term
tup [TickInfo]
ticks)
            | Integer
m Integer -> Integer -> Bool
forall a. Eq a => a -> a -> Bool
== 0 -> do
              let (Just tupTc :: TyCon
tupTc) = TyConName -> TyConMap -> Maybe TyCon
forall a b. Uniquable a => a -> UniqMap b -> Maybe b
lookupUniqMap TyConName
tupTcNm TyConMap
tcm
                  [tupDc :: DataCon
tupDc]      = TyCon -> [DataCon]
tyConDataCons TyCon
tupTc
                  tup :: Term
tup          = Term -> [Either Term Kind] -> Term
mkApps (DataCon -> Term
Data DataCon
tupDc)
                                    [Kind -> Either Term Kind
forall a b. b -> Either a b
Right Kind
lTy
                                    ,Kind -> Either Term Kind
forall a b. b -> Either a b
Right Kind
rTy
                                    ,Term -> Either Term Kind
forall a b. a -> Either a b
Left  (Kind -> Term
removedTm Kind
lTy)
                                    ,Term -> Either Term Kind
forall a b. a -> Either a b
Left  Term
bvArg
                                    ]

              Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed (Term -> [TickInfo] -> Term
mkTicks Term
tup [TickInfo]
ticks)
          _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
      "Clash.Sized.Internal.BitVector.eq#"
        | ([_,_],[nTy :: Kind
nTy]) <- [Either Term Kind] -> ([Term], [Kind])
forall a b. [Either a b] -> ([a], [b])
Either.partitionEithers [Either Term Kind]
args
        , Right 0 <- Except String Integer -> Either String Integer
forall e a. Except e a -> Either e a
runExcept (TyConMap -> Kind -> Except String Integer
tyNatSize TyConMap
tcm Kind
nTy)
        , TyConApp boolTcNm :: TyConName
boolTcNm [] <- TypeView
tv
        -> let (Just boolTc :: TyCon
boolTc) = TyConName -> TyConMap -> Maybe TyCon
forall a b. Uniquable a => a -> UniqMap b -> Maybe b
lookupUniqMap TyConName
boolTcNm TyConMap
tcm
               [_falseDc :: DataCon
_falseDc,trueDc :: DataCon
trueDc] = TyCon -> [DataCon]
tyConDataCons TyCon
boolTc
           in  Term -> RewriteMonad NormalizeState Term
forall a extra. a -> RewriteMonad extra a
changed (Term -> [TickInfo] -> Term
mkTicks (DataCon -> Term
Data DataCon
trueDc) [TickInfo]
ticks)
      _ -> Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e
  where
    isUntranslatableType_not_poly :: Kind -> RewriteMonad extra Bool
isUntranslatableType_not_poly t :: Kind
t = do
      Bool
u <- Bool -> Kind -> RewriteMonad extra Bool
forall extra. Bool -> Kind -> RewriteMonad extra Bool
isUntranslatableType Bool
False Kind
t
      if Bool
u
         then Bool -> RewriteMonad extra Bool
forall (m :: * -> *) a. Monad m => a -> m a
return ([TyVar] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([TyVar] -> Bool) -> [TyVar] -> Bool
forall a b. (a -> b) -> a -> b
$ Getting (Endo [TyVar]) Kind TyVar -> Kind -> [TyVar]
forall a s. Getting (Endo [a]) s a -> s -> [a]
Lens.toListOf Getting (Endo [TyVar]) Kind TyVar
Fold Kind TyVar
typeFreeVars Kind
t)
         else Bool -> RewriteMonad extra Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
False

reduceNonRepPrim _ e :: Term
e = Term -> RewriteMonad NormalizeState Term
forall (m :: * -> *) a. Monad m => a -> m a
return Term
e

-- | Consolidate applications of global binders that occur in multiple
-- alternatives of a case-expression: the application is lifted out of the
-- alternatives into a single let-binding that those alternatives share.
--
-- For example, this rewrites:
--
-- @
-- case x of
--   A -> f 3 y
--   B -> f x x
--   C -> h x
-- @
--
-- to:
--
-- @
-- let f_arg0 = case x of {A -> 3; B -> x}
--     f_arg1 = case x of {A -> y; B -> x}
--     f_out  = f f_arg0 f_arg1
-- in  case x of
--       A -> f_out
--       B -> f_out
--       C -> h x
-- @
--
-- The rewrite only fires on case-expressions with at least two alternatives,
-- and only the collected expressions whose case-tree satisfies 'isDisjoint'
-- are lifted. When nothing qualifies, the term is returned unchanged.
disjointExpressionConsolidation :: HasCallStack => NormRewrite
disjointExpressionConsolidation ctx@(TransformContext is0 _) e@(Case _scrut _ty _alts@(_:_:_)) = do
    (_, found) <- collectGlobals is0 [] [] e
    let candidates = filter (isDisjoint . snd . snd) found
    case candidates of
      [] -> return e
      _  -> do
        grouped <- mapM (mkDisjointGroup is0) candidates
        tcm     <- Lens.view tcCache
        outVars <- Monad.zipWithM (mkFunOut is0 tcm) candidates grouped
        let replacements = zip (map fst candidates) (map Var outVars)
            -- One entry per lifted expression, holding the replacements of
            -- all the /other/ lifted expressions.
            othersPer    = dropEach replacements
        liftedPairs <- Monad.zipWithM
                         (\others (grp, seen) -> collectGlobals is0 others seen grp)
                         othersPer
                         grouped
        (newBody, _) <- collectGlobals is0 replacements [] e
        let lifted = map fst liftedPairs
        result <- bottomupR deadCode ctx (Letrec (zip outVars lifted) newBody)
        changed result
  where
    -- Make a fresh internal variable for the result of a lifted application.
    -- Its name is the last dot-separated component of the applied
    -- function/primitive name, suffixed with "Out".
    mkFunOut isN tcm (fun, _) (grp, _) =
      let outTy  = termType tcm grp
          baseNm = case collectArgs fun of
                     (Var v, _)      -> nameOcc (varName v)
                     (Prim pNm _, _) -> pNm
                     _               -> "complex_expression_"
          outNm  = Text.append (last (Text.splitOn "." baseNm)) "Out"
      in  mkInternalVar isN outNm outTy

    -- Every way of leaving out exactly one element, in order:
    -- dropEach [a,b,c] == [[b,c],[a,c],[a,b]]
    dropEach xs =
      [ before ++ after
      | (before, _:after) <- map (`splitAt` xs) [0 .. length xs - 1]
      ]

disjointExpressionConsolidation _ e = return e

-- | Given a function in the desired normal form, inline some of its
-- let-bindings into the remaining let-bindings. A let-binding is never
-- inlined when it is referenced by the body of the let-expression.
--
-- Let-bindings with an internal name are inlined when they bind:
--
--   * a primitive that will be translated to an HDL expression (as opposed to
--     a HDL declaration), but only when it is referenced exactly once by the
--     right-hand sides of the let-bindings (otherwise we would lose sharing)
--   * a projection case-expression (1 alternative)
--   * a data constructor (possibly applied to arguments)
--
-- Let-bindings with a user-given name are only inlined when they bind a
-- projection case-expression (1 alternative) that matches on the @BV@ or
-- @Bit@ constructor of "Clash.Sized.Internal.BitVector", or on a constructor
-- from "GHC.Classes" (e.g. constraint tuples).
--
-- A binding that (after earlier inlining) refers to itself is kept instead.
inlineCleanup :: HasCallStack => NormRewrite
inlineCleanup (TransformContext is0 _) (Letrec binds body) = do
  prims <- Lens.use (extra.primitives)
  -- In-scope set extended with the let-binders; used by the substitutions
  -- performed in 'inlineBndrs'.
  let is1 = extendInScopeSetList is0 (map fst binds)
      -- For all let-bindings, count the number of times they are referenced
      -- by the right-hand sides of the let-bindings. Primitives are only
      -- inlined when referenced once, otherwise we would lose sharing.
      allOccs       = List.foldl' (unionVarEnvWith (+)) emptyVarEnv
                    $ map (Lens.foldMapByOf freeLocalIds (unionVarEnvWith (+))
                            emptyVarEnv (`unitVarEnv` 1) . snd)
                          binds
      bodyFVs       = Lens.foldMapOf freeLocalIds unitVarSet body
      (il,keep)     = List.partition (isInteresting allOccs prims bodyFVs)
                                     binds
      keep'         = inlineBndrs is1 keep il
  if null il then return  (Letrec binds body)
             else changed (Letrec keep' body)
  where
    -- Determine whether a let-binding is interesting to inline
    isInteresting
      :: VarEnv Int
      -> CompiledPrimMap
      -> VarSet
      -> (Id, Term)
      -> Bool
    isInteresting allOccs prims bodyFVs (id_,(fst.collectArgs) -> tm)
      | nameSort (varName id_) /= User
      , id_ `notElemVarSet` bodyFVs
      = case tm of
          Prim nm _
            | Just (extractPrim -> Just p@(BlackBox {})) <- HashMap.lookup nm prims
            , TExpr <- kind p
            , Just occ <- lookupVarEnv id_ allOccs
            , occ < 2
            -> True
          Case _ _ [_] -> True
          Data _ -> True
          _ -> False
      -- Only reached for binders with a 'User' name
      | id_ `notElemVarSet` bodyFVs
      = case tm of
          Case _ _ [(DataPat dcE _ _,_)]
            -> let nm = (nameOcc (dcName dcE))
               in -- Inlines WW projection that exposes internals of the BitVector types
                  nm == "Clash.Sized.Internal.BitVector.BV"  ||
                  nm == "Clash.Sized.Internal.BitVector.Bit" ||
                  -- Inlines projections out of constraint-tuples (e.g. HiddenClockReset)
                  "GHC.Classes" `Text.isPrefixOf` nm
          _ -> False

    isInteresting _ _ _ _ = False

    -- Inline let-bindings we want to inline into let-bindings we want to keep.
    inlineBndrs
      :: InScopeSet
      -> [(Id, Term)]
      -- let-bindings we keep
      -> [(Id, Term)]
      -- let-bindings we want to inline
      -> [(Id, Term)]
    inlineBndrs _   keep [] = keep
    inlineBndrs isN keep ((v,e):il) =
      let subst = extendIdSubst (mkSubst isN) v e
      in  if v `localIdOccursIn` e -- don't inline recursive binders
          then inlineBndrs isN ((v,e):keep) il
          else inlineBndrs isN
                 (map (second (substTm "inlineCleanup.inlineBndrs" subst)) keep)
                 (map (second (substTm "inlineCleanup.inlineBndrs" subst)) il)
      -- We must not forget to inline the /current/ @to-inline@ let-binding into
      -- the list of /remaining/ @to-inline@ let-bindings, because it might
      -- only occur in /remaining/ @to-inline@ bindings. If we don't, we would
      -- introduce free variables, because the @to-inline@ bindings are removed.

inlineCleanup _ e = return e

-- | Flattens letrecs after 'inlineCleanup'
--
-- 'inlineCleanup' sometimes exposes additional possibilities for 'caseCon',
-- which then introduces let-bindings in what should be ANF. This transformation
-- flattens those nested let-bindings again.
--
-- NB: must only be called in the cleaning up phase.
flattenLet :: HasCallStack => NormRewrite
flattenLet (TransformContext is0 _) letrec@(Letrec _ _) = do
  -- NOTE(review): 'freshenTm' presumably renames the binders so that hoisting
  -- nested let-bindings into this let-expression cannot cause name clashes.
  let (is2, Letrec binds body) = freshenTm is0 letrec
      bodyOccs = Lens.foldMapByOf
                   freeLocalIds (unionVarEnvWith (+))
                   emptyVarEnv (`unitVarEnv` (1 :: Int))
                   body
  binds' <- concat <$> mapM (go is2) binds
  case binds' of
    -- inline binders into the body when there's only a single binder, and only
    -- if that binder doesn't perform any work or is only used once in the body
    [(id',e')] | Just occ <- lookupVarEnv id' bodyOccs, isWorkFree e' || occ < 2 ->
      if id' `localIdOccursIn` e'
         -- Except when the binder is recursive!
         then return (Letrec binds' body)
         else let subst = extendIdSubst (mkSubst is2) id' e'
              in changed (substTm "flattenLet" subst body)
    _ -> return (Letrec binds' body)
  where
    -- Flatten a single let-binding: when it binds a (possibly ticked)
    -- let-expression, hoist the nested bindings up to this level. Any other
    -- binding is returned as-is.
    go :: InScopeSet -> LetBinding -> NormalizeSession [LetBinding]
    go isN (id_,collectTicks -> (Letrec binds' body',ticks)) = do
      let bodyOccs = Lens.foldMapByOf
                       freeLocalIds (unionVarEnvWith (+))
                       emptyVarEnv (`unitVarEnv` (1 :: Int))
                       body'
          (srcTicks,nmTicks) = partitionTicks ticks
      -- Distribute the name ticks of the let-expression over all the bindings
      map (second (`mkTicks` nmTicks)) <$> case binds' of
        -- inline binders into the body when there's only a single binder, and
        -- only if that binder doesn't perform any work or is only used once in
        -- the body
        [(id',e')] | Just occ <- lookupVarEnv id' bodyOccs, isWorkFree e' || occ < 2 ->
          if id' `localIdOccursIn` e'
             -- Except when the binder is recursive! Keep both bindings, but
             -- (as in the other branches) do not lose the srcTicks: they are
             -- only applied to the body.
             then changed [(id',e'),(id_, mkTicks body' srcTicks)]
             else let subst = extendIdSubst (mkSubst isN) id' e'
                  in  changed [(id_
                               -- Only apply srcTicks to the body
                               ,mkTicks (substTm "flattenLetGo" subst body')
                                        srcTicks)]
        bs -> changed (bs ++ [(id_
                               -- Only apply srcTicks to the body
                              ,mkTicks body' srcTicks)])
    go _ b = return [b]

flattenLet _ e = return e