{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE Safe #-}
-- | This module implements facilities for determining whether a
-- reduction or fold can be expressed in a closed form (i.e. not as a
-- SOAC).
--
-- Right now, the module can detect only trivial cases.  In the
-- future, we would like to make it more powerful, as well as possibly
-- also being able to analyse sequential loops.
module Futhark.Optimise.Simplify.ClosedForm
  ( foldClosedForm
  , loopClosedForm
  , VarLookup
  )
where

import Control.Monad
import Data.Maybe
import qualified Data.Map.Strict as M

import Futhark.Construct
import Futhark.IR
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Rule

-- | Lookup function for variable definitions: maps a variable name to
-- the expression that defines it, paired with the certificates of that
-- definition, or 'Nothing' when no definition is known.
type VarLookup lore =
  VName -> Maybe (Exp lore, Certificates)

{-
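Motivation: the reduction below only consumes arrays produced by
replicate, so its results could in principle be computed without a SOAC: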

  let {*[int,x_size_27] map_computed_shape_1286} = replicate(x_size_27,
                                                             all_equal_shape_1044) in
  let {*[bool,x_size_27] map_size_checks_1292} = replicate(x_size_27, x_1291) in
  let {bool all_equal_checked_1298, int all_equal_shape_1299} =
    reduceT(fn {bool, int} (bool bacc_1293, int nacc_1294, bool belm_1295,
                            int nelm_1296) =>
              let {bool tuplit_elems_1297} = bacc_1293 && belm_1295 in
              {tuplit_elems_1297, nelm_1296},
            {True, 0}, map_size_checks_1292, map_computed_shape_1286)
-}

-- | @foldClosedForm look pat lam accs arrs@ tries to replace a fold or
-- reduction (with operator @lam@, initial accumulators @accs@ and input
-- arrays @arrs@) by code that computes the result without iterating,
-- binding it to @pat@.
--
-- Only the case where @pat@ binds exactly one primitive value is
-- handled.  Otherwise, or when a result has no closed form, the rule
-- fails with 'cannotSimplify'.  The generated code yields @accs@
-- directly when the input arrays are empty.
foldClosedForm :: (ASTLore lore, BinderOps lore) =>
                  VarLookup lore
               -> Pattern lore
               -> Lambda lore
               -> [SubExp] -> [VName]
               -> RuleM lore ()
foldClosedForm look pat lam accs arrs = do
  inputTypes <- mapM lookupType arrs
  let inputsize = arraysSize 0 inputTypes

  resultType <- case patternTypes pat of
    [Prim pt] -> return pt
    _         -> cannotSimplify

  closedBody <-
    checkResults (patternNames pat) inputsize mempty knownBnds
      (map paramName (lambdaParams lam)) (lambdaBody lam) accs

  isEmpty <- newVName "fold_input_is_empty"
  letBindNames [isEmpty] $
    BasicOp $ CmpOp (CmpEq int32) inputsize (intConst Int32 0)

  -- An empty input leaves the accumulators untouched; otherwise use a
  -- freshly renamed copy of the closed-form body.
  emptyBranch <- resultBodyM accs
  closedBranch <- renameBody closedBody
  letBind pat $
    If (Var isEmpty) emptyBranch closedBranch $
      IfDec [primBodyType resultType] IfNormal
  where
    knownBnds = determineKnownBindings look lam accs arrs

-- | @loopClosedForm pat merge i bound body@ determines whether the loop
-- with merge parameters @merge@ (each paired with its initial value),
-- iteration bound @bound@ and body @body@ can be expressed in closed
-- form, and if so binds the result to @pat@.  The names in @i@ are never
-- treated as loop-invariant (presumably the loop counter).
--
-- Only a single primitive-typed result is handled.  Otherwise the rule
-- fails with 'cannotSimplify'.  The generated code yields the initial
-- merge values when @bound <= 0@.
loopClosedForm :: (ASTLore lore, BinderOps lore) =>
                  Pattern lore
               -> [(FParam lore,SubExp)]
               -> Names -> SubExp -> Body lore
               -> RuleM lore ()
loopClosedForm pat merge i bound body = do
  t <- case patternTypes pat of
         [Prim pt] -> return pt
         _         -> cannotSimplify

  closedBody <- checkResults mergenames bound i knownBnds
                mergenames body mergeexp
  isEmpty <- newVName "bound_is_zero"
  -- The loop performs no iterations whenever bound <= 0, so the
  -- result is then just the initial merge values.  The test must
  -- include bound == 0: e.g. the closed form of a 'LogAnd' result is
  -- @acc && el@, which differs from @acc@ when no iteration happens.
  letBindNames [isEmpty] $
    BasicOp $ CmpOp (CmpSle Int32) bound (intConst Int32 0)

  letBind pat =<< (If (Var isEmpty)
                    <$> resultBodyM mergeexp
                    <*> renameBody closedBody
                    <*> pure (IfDec [primBodyType t] IfNormal))
  where (mergepat, mergeexp) = unzip merge
        mergenames = map paramName mergepat
        -- Initial value of every merge parameter.
        knownBnds = M.fromList $ zip mergenames mergeexp

-- | @checkResults pat size untouchable knownBnds params body accs@
-- builds a body that computes, without iterating, the values bound to
-- @pat@ after @size@ applications of @body@ starting from @accs@.
--
-- Each result of @body@ must be a binary operation combining its own
-- accumulator parameter with a value that is invariant across
-- iterations.  Only 'LogAnd', integer 'Add' and 'FAdd' are supported.
-- Anything else fails with 'cannotSimplify'.
--
-- Names in @untouchable@, in @params@, or bound inside @body@ are not
-- invariant by themselves.  Such a name may still be used as the
-- invariant operand if @knownBnds@ gives it a value, except for the
-- accumulator parameters, which change every iteration.
checkResults :: BinderOps lore =>
                [VName]
             -> SubExp
             -> Names
             -> M.Map VName SubExp
             -> [VName] -- ^ Lambda-bound; accumulator parameters first.
             -> Body lore
             -> [SubExp]
             -> RuleM lore (Body lore)
checkResults pat size untouchable knownBnds params body accs = do
  ((), bnds) <- collectStms $
    zipWithM_ checkResult (zip pat res) (zip accparams accs)
  mkBodyM bnds $ map Var pat
  where
    bndMap = makeBindMap body
    (accparams, _) = splitAt (length accs) params
    res = bodyResult body

    nonFree = boundInBody body <> namesFromList params <> untouchable

    -- Accumulator parameters are updated on every iteration, so the
    -- initial value that knownBnds may record for them only holds for
    -- the first one.  Using it as the invariant operand would turn
    -- e.g. @acc + acc@ into @acc0 + acc0*size@ instead of
    -- @acc0 * 2^size@.
    accNames = namesFromList accparams

    checkResult (p, Var v) (accparam, acc)
      | Just (BasicOp (BinOp bop x y)) <- M.lookup v bndMap = do
          -- One of x,y must be *this* accumulator, and the other must
          -- be something that is free in the body.
          let isThisAccum = (== Var accparam)
          (this, el) <- liftMaybe $
            case ((asFreeSubExp x, isThisAccum y),
                  (asFreeSubExp y, isThisAccum x)) of
              ((Just free, True), _) -> Just (acc, free)
              (_, (Just free, True)) -> Just (acc, free)
              _                      -> Nothing

          case bop of
            LogAnd ->
              letBindNames [p] $ BasicOp $ BinOp LogAnd this el
            Add t w | Just properly_typed_size <- properIntSize t -> do
              size' <- properly_typed_size
              letBindNames [p] =<<
                eBinOp (Add t w) (eSubExp this)
                  (pure $ BasicOp $ BinOp (Mul t w) el size')
            FAdd t | Just properly_typed_size <- properFloatSize t -> do
              size' <- properly_typed_size
              letBindNames [p] =<<
                eBinOp (FAdd t) (eSubExp this)
                  (pure $ BasicOp $ BinOp (FMul t) el size')
            _ -> cannotSimplify -- Um... sorry.
    checkResult _ _ = cannotSimplify

    -- The loop-invariant value of an operand, if it has one.
    asFreeSubExp :: SubExp -> Maybe SubExp
    asFreeSubExp (Var v)
      | v `nameIn` accNames = Nothing
      | v `nameIn` nonFree = M.lookup v knownBnds
    asFreeSubExp se = Just se

    -- The iteration count converted to the operand type; it is an
    -- Int32 to begin with.
    properIntSize Int32 = Just $ return size
    properIntSize t = Just $ letSubExp "converted_size" $
                      BasicOp $ ConvOp (SExt Int32 t) size

    properFloatSize t =
      Just $ letSubExp "converted_size" $
      BasicOp $ ConvOp (SIToFP Int32 t) size

-- | Collect values known for lambda parameters before the fold starts.
--
-- Each accumulator parameter is paired with its initial accumulator
-- value.  Each array parameter whose input array is defined (according
-- to @look@) by a 'Replicate' with no certificates is paired with the
-- replicated value.
determineKnownBindings :: VarLookup lore -> Lambda lore -> [SubExp] -> [VName]
                       -> M.Map VName SubExp
determineKnownBindings look lam accs arrs =
  M.union accBnds arrBnds
  where
    (accparams, arrparams) = splitAt (length accs) (lambdaParams lam)

    accBnds = M.fromList
      [ (paramName param, acc) | (param, acc) <- zip accparams accs ]

    -- Replicates carrying certificates are skipped; the map records
    -- only the replicated value, not any certificates.
    arrBnds = M.fromList
      [ (paramName param, ve)
      | (param, arr) <- zip arrparams arrs
      , Just (BasicOp (Replicate _ ve), cs) <- [look arr]
      , cs == mempty
      ]

-- | Map each variable bound by a single-name statement in the body to
-- its defining expression.  Statements binding several names are
-- skipped.
makeBindMap :: Body lore -> M.Map VName (Exp lore)
makeBindMap body =
  M.fromList
    [ (v, e)
    | Let pat _ e <- stmsToList (bodyStms body)
    , [v] <- [patternNames pat]
    ]