{-|
  Copyright  :  (C) 2015-2016, University of Twente
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Helper functions for the 'disjointExpressionConsolidation' transformation

  The 'disjointExpressionConsolidation' transformation lifts applications of
  global binders out of alternatives of case-statements.

  e.g. It converts:

  > case x of
  >   A -> f 3 y
  >   B -> f x x
  >   C -> h x

  into:

  > let f_arg0 = case x of {A -> 3; B -> x}
  >     f_arg1 = case x of {A -> y; B -> x}
  >     f_out  = f f_arg0 f_arg1
  > in  case x of
  >       A -> f_out
  >       B -> f_out
  >       C -> h x
-}

{-# LANGUAGE DeriveFoldable    #-}
{-# LANGUAGE DeriveFunctor     #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo       #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TupleSections     #-}
{-# LANGUAGE ViewPatterns      #-}

module Clash.Normalize.DEC
  (collectGlobals
  ,isDisjoint
  ,mkDisjointGroup
  )
where

-- external
import           Control.Concurrent.Supply        (splitSupply)
import qualified Control.Lens                     as Lens
import           Data.Bits                        ((.&.),complement)
import           Data.Coerce                      (coerce)
import qualified Data.Either                      as Either
import qualified Data.Foldable                    as Foldable
import qualified Data.IntMap.Strict               as IM
import qualified Data.IntSet                      as IntSet
import qualified Data.List                        as List
import qualified Data.Map.Strict                  as Map
import qualified Data.Maybe                       as Maybe
import           Data.Monoid                      (All (..))

-- internal
import Clash.Core.DataCon    (DataCon, dcTag)
import Clash.Core.Evaluator  (whnf')
import Clash.Core.FreeVars
  (termFreeVars', typeFreeVars', localVarsDoNotOccurIn)
import Clash.Core.Literal    (Literal (..))
import Clash.Core.Term
  (LetBinding, Pat (..), PrimInfo (..), Term (..), collectArgs, collectArgsTicks)
import Clash.Core.TyCon      (tyConDataCons)
import Clash.Core.Type       (Type, isPolyFunTy, mkTyConApp, splitFunForallTy)
import Clash.Core.Util       (mkApps, mkTicks, patIds, termType)
import Clash.Core.Var        (isGlobalId)
import Clash.Core.VarEnv
  (InScopeSet, elemInScopeSet, notElemInScopeSet)
import Clash.Normalize.Types (NormalizeState)
import Clash.Rewrite.Types
  (RewriteMonad, bindings, evaluator, globalHeap, tcCache, tupleTcCache, uniqSupply)
import Clash.Rewrite.Util    (mkInternalVar, mkSelectorCase,
                              isUntranslatableType, isConstant)
import Clash.Unique          (lookupUniqMap)
import Clash.Util

-- | Tree that records where a collected value of type @a@ sits relative to the
-- let-bindings and case-alternatives surrounding it.
data CaseTree a
  = Leaf a
  -- ^ A collected value
  | LB [LetBinding] (CaseTree a)
  -- ^ A sub-tree wrapped in a group of let-bindings
  | Branch Term [(Pat,CaseTree a)]
  -- ^ One sub-tree per pattern; the 'Term' is presumably the scrutinee of the
  -- originating case-expression (TODO confirm against 'collectGlobalsAlts')
  deriving (Eq,Show,Functor,Foldable)

-- | Test if a 'CaseTree' collected from an expression indicates that the
-- applications of a global binder are disjoint, i.e. that they occur in
-- separate branches of a case-expression.
--
-- A tree whose root is a 'Branch' with exactly one alternative is never
-- considered disjoint. Otherwise the tree is descended through let-bindings
-- and single-alternative branches until a 'Branch' with at least two
-- alternatives is found; the result is then 'True' only when every leaf below
-- that branch carries the same list of type arguments.
isDisjoint :: CaseTree ([Either Term Type])
           -> Bool
isDisjoint (Branch _ [_]) = False
isDisjoint ct = go ct
  where
    go (Leaf _)             = False
    go (LB _ ct')           = go ct'
    go (Branch _ [])        = False
    go (Branch _ [(_,x)])   = go x
    -- Only the type arguments ('Right's) of the collected applications are
    -- compared; the term arguments are allowed to differ.
    go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b))

-- | Remove empty branches from a 'CaseTree'.
--
-- A sub-tree counts as empty when it reduces to @Leaf []@. Let-bindings around
-- an empty sub-tree are dropped, and a 'Branch' whose alternatives are all
-- empty collapses to @Leaf []@ itself.
removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a]
removeEmpty l@(Leaf _) = l
removeEmpty (LB lb ct) =
  case removeEmpty ct of
    Leaf [] -> Leaf []
    ct'     -> LB lb ct'
removeEmpty (Branch s bs) =
  -- Clean up every alternative first, then discard the ones that became empty
  case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of
    []  -> Leaf []
    bs' -> Branch s bs'

-- | Test if all elements in a list are equal to each other.
--
-- The empty list and singleton lists are trivially all-equal.
allEqual :: Eq a => [a] -> Bool
allEqual []     = True
allEqual (x:xs) = all (== x) xs

-- | Collect 'CaseTree's for (potentially) disjoint applications of globals out
-- of an expression. Also substitute truly disjoint applications of globals by a
-- reference to a lifted out application.
--
-- The result pairs the (possibly rewritten) expression with, for every
-- collected function, the list of binders that had been seen when it was
-- collected and the 'CaseTree' of its arguments.
collectGlobals'
  :: InScopeSet
  -> [(Term,Term)]
  -- ^ Substitution of (applications of) a global binder by a reference to a
  -- lifted term.
  -> [Term]
  -- ^ List of already seen global binders
  -> Term
  -- ^ The expression
  -> Bool
  -- ^ Whether expression is constant
  -> RewriteMonad
      NormalizeState
      (Term, [(Term, ([Term], CaseTree [Either Term Type]))])
collectGlobals' inScope substitution seen (Case scrut ty alts) _eIsConstant = do
  -- The two traversals are mutually dependent: the alternatives are given the
  -- rewritten scrutinee, while the scrutinee is traversed with the binders
  -- collected from the alternatives marked as seen. Hence the recursive do.
  rec (alts' ,collected)  <- collectGlobalsAlts inScope substitution seen scrut'
                                                alts
      (scrut',collected') <- collectGlobals inScope substitution
                                            (map fst collected ++ seen) scrut
  return (Case scrut' ty alts',collected ++ collected')

collectGlobals' inScope substitution seen e@(collectArgsTicks -> (fun, args@(_:_), ticks)) eIsconstant
  | not eIsconstant = do
    tcm <- Lens.view tcCache
    bndrs <- Lens.use bindings
    primEval <- Lens.view evaluator
    gh <- Lens.use globalHeap
    ids <- Lens.use uniqSupply
    -- One half of the supply goes to the evaluator, the other half is put back
    -- in the state.
    let (ids1,ids2) = splitSupply ids
    uniqSupply Lens..= ids2
    let eval = (Lens.view Lens._3) . whnf' primEval bndrs tcm gh ids1 inScope False
        eTy  = termType tcm e
    untran <- isUntranslatableType False eTy
    case untran of
      -- Don't lift out non-representable values, because they cannot be let-bound
      -- in our desired normal form.
      False -> do
        isInteresting <- interestingToLift inScope eval fun args
        case isInteresting of
          Just fun' | fun' `notElem` seen -> do
            (args',collected) <- collectGlobalsArgs inScope substitution
                                                    (fun':seen) args
            let e' = Maybe.fromMaybe (mkApps fun' args') (List.lookup fun' substitution)
            -- This function is lifted out an environment with the currently 'seen'
            -- binders. When we later apply substitution, we need to start with this
            -- environment, otherwise we perform incorrect substitutions in the
            -- arguments.
            return (e',(fun',(seen,Leaf args')):collected)
          _ -> do (args',collected) <- collectGlobalsArgs inScope substitution
                                                          seen args
                  return (mkApps (mkTicks fun ticks) args',collected)
      _ -> return (e,[])

-- FIXME: This duplicates A LOT of let-bindings, where I just pray that after
-- the ANF, CSE, and DeadCodeRemoval pass all duplicates are removed.
--
-- I think we should be able to do better, but perhaps we cannot fix it here.
collectGlobals' inScope substitution seen (Letrec lbs body) _eIsConstant = do
  (body',collected)   <- collectGlobals    inScope substitution seen body
  (lbs',collected')   <- collectGlobalsLbs inScope substitution
                                           (map fst collected ++ seen)
                                           lbs
  -- Every collected tree is wrapped in the (rewritten) let-bindings
  return (Letrec lbs' body'
         ,map (second (second (LB lbs'))) (collected ++ collected')
         )

-- Anything else (including constant applications) is returned untouched
collectGlobals' _ _ _ e _ = return (e,[])

-- | Collect 'CaseTree's for (potentially) disjoint applications of globals out
-- of an expression. Also substitute truly disjoint applications of globals by a
-- reference to a lifted out application.
--
-- Determines whether the expression is constant (using 'isConstant') and
-- delegates to @collectGlobals'@.
collectGlobals
  :: InScopeSet
  -> [(Term,Term)]
  -- ^ Substitution of (applications of) a global binder by a reference to
  -- a lifted term.
  -> [Term]
  -- ^ List of already seen global binders
  -> Term
  -- ^ The expression
  -> RewriteMonad
      NormalizeState
      (Term, [(Term, ([Term], CaseTree [Either Term Type]))])
collectGlobals inScope substitution seen e =
  collectGlobals' inScope substitution seen e (isConstant e)

-- | Collect 'CaseTree's for (potentially) disjoint applications of globals out
-- of a list of application arguments. Also substitute truly disjoint
-- applications of globals by a reference to a lifted out application.
--
-- The arguments are processed from left to right; the binders collected from
-- an argument are added to the list of seen binders for the arguments that
-- follow it. Type arguments are passed through unchanged.
collectGlobalsArgs ::
     InScopeSet
  -> [(Term,Term)] -- ^ Substitution of (applications of) a global
                   -- binder by a reference to a lifted term.
  -> [Term] -- ^ List of already seen global binders
  -> [Either Term Type] -- ^ The list of arguments
  -> RewriteMonad NormalizeState
                  ([Either Term Type]
                  ,[(Term,([Term],CaseTree [(Either Term Type)]))]
                  )
collectGlobalsArgs inScope substitution seen args = do
    (_,(args',collected)) <- second unzip <$> mapAccumLM go seen args
    return (args',concat collected)
  where
    -- The accumulator is the list of seen binders
    go s (Left tm) = do
      (tm',collected) <- collectGlobals inScope substitution s tm
      return (map fst collected ++ s,(Left tm',collected))
    go s (Right ty) = return (s,(Right ty,[]))

-- | Collect 'CaseTree's for (potentially) disjoint applications of globals out
-- of a list of alternatives. Also substitute truly disjoint applications of
-- globals by a reference to a lifted out application.
--
-- Every alternative is traversed with 'collectGlobals', each starting from the
-- same list of seen binders. The per-alternative results are then merged per
-- applied term: the lists of seen binders are appended and de-duplicated, and
-- the case-trees, each paired with the pattern of the alternative it was found
-- in, become the branches of a single 'Branch' on the subject term.
collectGlobalsAlts ::
     InScopeSet
  -> [(Term,Term)] -- ^ Substitution of (applications of) a global
                   -- binder by a reference to a lifted term.
  -> [Term] -- ^ List of already seen global binders
  -> Term -- ^ The subject term
  -> [(Pat,Term)] -- ^ The list of alternatives
  -> RewriteMonad NormalizeState
                  ([(Pat,Term)]
                  ,[(Term,([Term],CaseTree [(Either Term Type)]))]
                  )
collectGlobalsAlts inScope substitution seen scrut alts = do
    (alts',collected) <- unzip <$> mapM go alts
    let -- One map per alternative; every case-tree is wrapped in a singleton
        -- list so that the maps can be merged by plain list concatenation.
        collectedM  = map (Map.fromList . map (second (second (:[])))) collected
        -- Merge the maps of all alternatives, keyed on the applied term.
        collectedUN = Map.unionsWith (\(l1,r1) (l2,r2) -> (List.nub (l1 ++ l2),r1 ++ r2)) collectedM
        -- Turn the (pattern, case-tree) pairs into branches on the subject.
        collected'  = map (second (second (Branch scrut))) (Map.toList collectedUN)
    return (alts',collected')
  where
    -- Process the expression of a single alternative, and tag every collected
    -- case-tree with the pattern of that alternative.
    go (p,e) = do
      (e',collected) <- collectGlobals inScope substitution seen e
      return ((p,e'),map (second (second (p,))) collected)

-- | Collect 'CaseTree's for (potentially) disjoint applications of globals out
-- of a list of let-bindings. Also substitute truly disjoint applications of
-- globals by a reference to a lifted out application.
--
-- The let-bindings are processed in order. The list of seen binders is
-- threaded through: the terms collected from one binding are added to the
-- list used for the bindings that follow it. The final list of seen binders
-- is discarded; the collected case-trees of all bindings are concatenated.
collectGlobalsLbs ::
     InScopeSet
  -> [(Term,Term)] -- ^ Substitution of (applications of) a global
                   -- binder by a reference to a lifted term.
  -> [Term] -- ^ List of already seen global binders
  -> [LetBinding] -- ^ The list of let-bindings
  -> RewriteMonad NormalizeState
                  ([LetBinding]
                  ,[(Term,([Term],CaseTree [(Either Term Type)]))]
                  )
collectGlobalsLbs inScope substitution seen lbs = do
    (_,(lbs',collected)) <- second unzip <$> mapAccumLM go seen lbs
    return (lbs',concat collected)
  where
    -- Process the body of a single let-binding. The accumulator is the list
    -- of seen binders, extended with the terms collected from this binding.
    go :: [Term] -> LetBinding
       -> RewriteMonad NormalizeState
                  ([Term]
                  ,(LetBinding
                   ,[(Term,([Term],CaseTree [(Either Term Type)]))]
                   )
                  )
    go s (id_, e) = do
      (e',collected) <- collectGlobals inScope substitution s e
      return (map fst collected ++ s,((id_,e'),collected))

-- | Given a case-tree corresponding to a disjoint interesting \"term-in-a-
-- function-position\", return a let-expression: where the let-binding holds
-- a case-expression selecting between the distinct arguments of the case-tree,
-- and the body is an application of the term applied to the shared arguments of
-- the case tree, and projections of let-binding corresponding to the distinct
-- argument positions.
--
-- When there are no distinct arguments, or when 'disJointSelProj' does not
-- produce a let-binding, the result is the bare application without a
-- surrounding let-expression. The list of seen terms is passed through
-- unchanged as the second component of the result.
mkDisjointGroup
  :: InScopeSet
  -- ^ Variables in scope at the very top of the case-tree, i.e., the original
  -- expression
  -> (Term,([Term],CaseTree [(Either Term Type)]))
  -- ^ Case-tree of arguments belonging to the applied term.
  -> RewriteMonad NormalizeState (Term,[Term])
mkDisjointGroup inScope (fun,(seen,cs)) = do
    let -- One argument list per leaf of the case-tree
        argss    = Foldable.toList cs
        -- One list per argument position, tagged with that position
        argssT   = zip [0..] (List.transpose argss)
        (sharedT,distinctT) = List.partition (areShared inScope . snd) argssT
        -- 'head' is safe here: 'List.transpose' never yields an empty list
        shared   = map (second head) sharedT
        -- Back to one list per leaf, containing only the distinct term
        -- arguments
        distinct = map (Either.lefts) (List.transpose (map snd distinctT))
        -- Tag every argument in the case-tree with its position, so that the
        -- shared (position, argument) pairs can be filtered out below
        cs'      = fmap (zip [0..]) cs
        cs''     = removeEmpty
                 $ fmap (Either.lefts . map snd)
                        (if null shared
                           then cs'
                           else fmap (filter (`notElem` shared)) cs')
    tcm <- Lens.view tcCache
    (distinctCaseM,distinctProjections) <- case distinct of
      -- only shared arguments: do nothing.
      [] -> return (Nothing,[])
      -- Create selectors and projections
      (uc:_) -> do
        let argTys = map (termType tcm) uc
        disJointSelProj inScope argTys cs''
    -- Put the shared arguments and the projections back in argument order
    let newArgs = mkDJArgs 0 shared distinctProjections
    case distinctCaseM of
      Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen)
      Nothing -> return (mkApps fun newArgs, seen)

-- | Create a single selector for all the representable distinct arguments by
-- selecting between tuples. This selector is only ('Just') created when the
-- number of representable uncommon arguments is larger than one, otherwise it
-- is not ('Nothing').
--
-- It also returns:
--
-- * For all the non-representable distinct arguments: a selector
-- * For all the representable distinct arguments: a projection out of the tuple
--   created by the larger selector. If this larger selector does not exist, a
--   single selector is created for the single representable distinct argument.
--
-- The returned terms are in the order of the original argument positions.
disJointSelProj
  :: InScopeSet
  -> [Type]
  -- ^ Types of the arguments
  -> CaseTree [Term]
  -- ^ The case-tree of arguments
  -> RewriteMonad NormalizeState (Maybe LetBinding,[Term])
disJointSelProj _ _ (Leaf []) = return (Nothing,[])
disJointSelProj inScope argTys cs = do
    let maxIndex = length argTys - 1
        -- One case-tree per argument position, every leaf holding just the
        -- argument at that position
        css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex]
    (untran,tran) <- partitionM (isUntranslatableType False . snd) (zip [0..] argTys)
    let untranCs   = map (css!!) (map fst untran)
        untranSels = zipWith (\(_,ty) cs' -> genCase ty Nothing []  cs')
                             untran untranCs
    (lbM,projs) <- case tran of
      []       -> return (Nothing,[])
      [(i,ty)] -> return (Nothing,[genCase ty Nothing [] (css!!i)])
      tys      -> do
        tcm    <- Lens.view tcCache
        tupTcm <- Lens.view tupleTcCache
        let m            = length tys
            -- NOTE(review): the partial patterns below assume that a tuple
            -- type constructor of arity 'm' is present in both caches, and
            -- that it has exactly one data constructor -- confirm.
            Just tupTcNm = IM.lookup m tupTcm
            Just tupTc   = lookupUniqMap tupTcNm tcm
            [tupDc]      = tyConDataCons tupTc
            (tyIxs,tys') = unzip tys
            tupTy        = mkTyConApp tupTcNm tys'
            -- Keep only the arguments at the translatable positions
            cs'          = fmap (\es -> map (es !!) tyIxs) cs
            djCase       = genCase tupTy (Just tupDc) tys' cs'
        scrutId <- mkInternalVar inScope "tupIn" tupTy
        projections <- mapM (mkSelectorCase ($(curLoc) ++ "disJointSelProj")
                                            inScope tcm (Var scrutId) (dcTag tupDc)) [0..m-1]
        return (Just (scrutId,djCase),projections)
    let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs

    return (lbM,selProjs)
  where
    -- Interleave the selectors, which are tagged with their argument position,
    -- with the projections: at the current position emit the selector when its
    -- tag matches, and the next projection otherwise.
    tranOrUnTran _ []       projs     = projs
    tranOrUnTran _ sels     []        = map snd sels
    tranOrUnTran n ((ut,s):uts) (p:projs)
      | n == ut   = s : tranOrUnTran (n+1) uts          (p:projs)
      | otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs

-- | Arguments are shared between invocations if:
--
-- * They contain _no_ references to locally-bound variables
-- * Are all equal
--
-- An empty list of arguments is considered shared. Only the first argument is
-- inspected for references to locally-bound variables; the other arguments
-- additionally have to pass the 'allEqual' check.
areShared :: InScopeSet -> [Either Term Type] -> Bool
areShared _       []       = True
areShared inScope xs@(x:_) = noFV1 && allEqual xs
 where
  -- 'True' when the fold visits no variable at all: every visited variable
  -- contributes @All False@.
  -- NOTE(review): presumably the folds only visit free variables that satisfy
  -- 'isLocallyBound' -- confirm against 'typeFreeVars'' and 'termFreeVars''.
  noFV1 = case x of
    Right ty -> getAll (Lens.foldMapOf (typeFreeVars' isLocallyBound IntSet.empty)
                                       (const (All False)) ty)
    Left tm  -> getAll (Lens.foldMapOf (termFreeVars' isLocallyBound)
                                       (const (All False)) tm)

  -- A variable counts as locally bound when it is not in the given in-scope
  -- set.
  isLocallyBound v = v `notElemInScopeSet` inScope

-- | Create a list of arguments given a map of positions to common arguments,
-- and a list of (projections for) distinct arguments.
--
-- The result is built position by position, starting at the given current
-- position and counting upwards:
--
-- * when the first remaining entry of the map is keyed by the current
--   position, its common argument is emitted;
-- * otherwise the next distinct argument is emitted as a term argument.
--
-- As soon as one of the two lists is exhausted, the remainder of the other
-- list is emitted in order.
--
-- NOTE(review): this presumably expects the map to be sorted by ascending
-- position; an entry keyed below the current position is only emitted after
-- all distinct arguments are used up -- confirm against the caller.
mkDJArgs :: Int -- ^ Current position
         -> [(Int,Either Term Type)] -- ^ map from position to common argument
         -> [Term] -- ^ (projections for) distinct arguments
         -> [Either Term Type]
mkDJArgs _ commons []        = map snd commons
mkDJArgs _ []      distincts = map Left distincts
mkDJArgs n commons@((m,x):commons') distincts@(y:distincts')
  | n == m    = x      : mkDJArgs (n+1) commons' distincts
  | otherwise = Left y : mkDJArgs (n+1) commons  distincts'

-- | Create a case-expression that selects between the distinct arguments given
-- a case-tree
--
-- Every node of the case-tree is translated as follows:
--
-- * A leaf becomes the given data constructor applied to the argument types
--   followed by the leaf's terms; when no data constructor is given, the leaf
--   becomes its first term.
-- * A let-binding node becomes a 'Letrec' around the translated sub-tree.
-- * A branch with exactly one alternative is replaced by that alternative
--   whenever none of the variables bound by its pattern occur in it; otherwise
--   a single-alternative 'Case' is kept.
-- * Any other branch becomes a 'Case' over the translated alternatives.
genCase :: Type -- ^ Type of the alternatives
        -> Maybe DataCon -- ^ DataCon to pack multiple arguments
        -> [Type] -- ^ Types of the arguments
        -> CaseTree [Term] -- ^ CaseTree of arguments
        -> Term
genCase ty dcM argTys = go
  where
    go (Leaf tms) =
      case dcM of
        Just dc -> mkApps (Data dc) (map Right argTys ++ map Left tms)
        -- Without a constructor to pack the terms, the leaf is represented by
        -- its first term. Matching explicitly instead of calling 'head' gives
        -- a traceable message should a leaf ever be empty.
        Nothing -> case tms of
          (tm:_) -> tm
          []     -> error "Clash.Normalize.DEC.genCase: leaf without terms"

    go (LB lb ct) =
      Letrec lb (go ct)

    -- Must precede the general 'Branch' equation: a lone alternative that does
    -- not mention its pattern variables needs no case-expression at all.
    go (Branch scrut [(p,ct)]) =
      let ct' = go ct
          (ptvs,pids) = patIds p
      in  if (coerce ptvs ++ coerce pids) `localVarsDoNotOccurIn` ct'
             then ct'
             else Case scrut ty [(p,ct')]

    go (Branch scrut pats) =
      Case scrut ty (map (second go) pats)

-- | Determine if a term in a function position is interesting to lift out
-- of a case-expression.
--
-- This holds for all variables that are either global or an element of the
-- given in-scope set, and for certain primitives. Currently those primitives
-- are:
--
-- * All non-power-of-two multiplications
-- * All division-like operations with a non-power-of-two divisor
--
-- A primitive from that list is also reported as interesting when at least one
-- of its term arguments is not constant. Any other primitive is interesting
-- only when its type has a function-typed argument (as judged by
-- 'isPolyFunTy') and at least one of its term arguments is itself an
-- application that is interesting to lift.
--
-- Returns @Just@ the term in function position when it is interesting, and
-- @Nothing@ otherwise.
interestingToLift
  :: InScopeSet
  -- ^ in scope
  -> (Term -> Term)
  -- ^ Evaluator
  -> Term
  -- ^ Term in function position
  -> [Either Term Type]
  -- ^ Arguments
  -> RewriteMonad extra (Maybe Term)
interestingToLift inScope _ e@(Var v) _
  | isGlobalId v || v `elemInScopeSet` inScope = pure (Just e)
  | otherwise                                  = pure Nothing
interestingToLift inScope eval e@(Prim nm pInfo) args = do
  let anyArgNotConstant = any (not . isConstant) lArgs
  case List.lookup nm interestingPrims of
    Just t | t || anyArgNotConstant -> pure (Just e)
    _ -> do
      -- Split an argument into function and arguments, and ask the same
      -- question about that application.
      let isInteresting = uncurry (interestingToLift inScope eval) . collectArgs
      if isHOTy (primType pInfo) then do
        anyInteresting <- anyM (fmap Maybe.isJust . isInteresting) lArgs
        pure (if anyInteresting then Just e else Nothing)
      else
        pure Nothing

  where
    -- Primitive name paired with the verdict of the power-of-two check that
    -- applies to it, computed over the term arguments of this application.
    interestingPrims =
      [("Clash.Sized.Internal.BitVector.*#",tailNonPow2)
      ,("Clash.Sized.Internal.BitVector.times#",tailNonPow2)
      ,("Clash.Sized.Internal.BitVector.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.BitVector.rem#",lastNotPow2)
      ,("Clash.Sized.Internal.Index.*#",tailNonPow2)
      ,("Clash.Sized.Internal.Index.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.Index.rem#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.*#",tailNonPow2)
      ,("Clash.Sized.Internal.Signed.times#",tailNonPow2)
      ,("Clash.Sized.Internal.Signed.rem#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.div#",lastNotPow2)
      ,("Clash.Sized.Internal.Signed.mod#",lastNotPow2)
      ,("Clash.Sized.Internal.Unsigned.*#",tailNonPow2)
      ,("Clash.Sized.Internal.Unsigned.times#",tailNonPow2)
      ,("Clash.Sized.Internal.Unsigned.quot#",lastNotPow2)
      ,("Clash.Sized.Internal.Unsigned.rem#",lastNotPow2)
      ,("GHC.Base.quotInt",lastNotPow2)
      ,("GHC.Base.remInt",lastNotPow2)
      ,("GHC.Base.divInt",lastNotPow2)
      ,("GHC.Base.modInt",lastNotPow2)
      ,("GHC.Classes.divInt#",lastNotPow2)
      ,("GHC.Classes.modInt#",lastNotPow2)
      ,("GHC.Integer.Type.timesInteger",allNonPow2)
      ,("GHC.Integer.Type.divInteger",lastNotPow2)
      ,("GHC.Integer.Type.modInteger",lastNotPow2)
      ,("GHC.Integer.Type.quotInteger",lastNotPow2)
      ,("GHC.Integer.Type.remInteger",lastNotPow2)
      ,("GHC.Prim.*#",allNonPow2)
      ,("GHC.Prim.quotInt#",lastNotPow2)
      ,("GHC.Prim.remInt#",lastNotPow2)
      ]

    -- Term arguments only; type arguments are ignored.
    lArgs       = Either.lefts args

    -- No term argument is a power of two.
    allNonPow2  = all (not . termIsPow2) lArgs

    -- No term argument after the first is a power of two.
    -- NOTE(review): the skipped first term argument is presumably a
    -- size-constraint dictionary of the sized-type primitives -- confirm.
    tailNonPow2 = case lArgs of
                    []       -> True
                    (_:rest) -> all (not . termIsPow2) rest

    -- The last term argument (the divisor for the division-like primitives)
    -- is not a power of two.
    lastNotPow2 = case lArgs of
                    [] -> True
                    _  -> not (termIsPow2 (last lArgs))

    -- After evaluation, recognise an integer literal, or a @fromInteger@-like
    -- primitive applied to an integer literal, whose value is a power of two.
    termIsPow2 e' = case eval e' of
      Literal (IntegerLiteral n) -> isPow2 n
      a -> case collectArgs a of
        (Prim nm' _,[Right _,Left _,Left (Literal (IntegerLiteral n))])
          | isFromInteger nm' -> isPow2 n
        (Prim nm' _,[Right _,Left _,Left _,Left (Literal (IntegerLiteral n))])
          | nm' == "Clash.Sized.Internal.BitVector.fromInteger#"  -> isPow2 n
        (Prim nm' _,[Right _,       Left _,Left (Literal (IntegerLiteral n))])
          | nm' == "Clash.Sized.Internal.BitVector.fromInteger##" -> isPow2 n

        _ -> False

    -- @complement x + 1@ is the two's-complement negation of @x@, so the
    -- conjunction isolates the lowest set bit; it equals @x@ only when @x@ has
    -- exactly one bit set.
    isPow2 x = x /= 0 && (x .&. (complement x + 1)) == x

    -- The 'Index' entry previously read "Clash.Sized.Integer.Index.fromInteger",
    -- which follows neither the module path nor the @#@ suffix of any other
    -- primitive name used here, so it could never match.
    isFromInteger x = x `elem` ["Clash.Sized.Internal.BitVector.fromInteger#"
                               ,"Clash.Sized.Internal.Index.fromInteger#"
                               ,"Clash.Sized.Internal.Signed.fromInteger#"
                               ,"Clash.Sized.Internal.Unsigned.fromInteger#"
                               ]

    -- Does any value argument in the primitive's type satisfy 'isPolyFunTy'?
    isHOTy t = case splitFunForallTy t of
      (args',_) -> any isPolyFunTy (Either.rights args')

interestingToLift _ _ _ _ = pure Nothing