{-# LANGUAGE FlexibleContexts #-}

-- | This module implements facilities for determining whether a
-- reduction or fold can be expressed in a closed form (i.e. not as a
-- SOAC).
--
-- Right now, the module can detect only trivial cases.  In the
-- future, we would like to make it more powerful, as well as possibly
-- also being able to analyse sequential loops.
module Futhark.Optimise.Simplify.Rules.ClosedForm
  ( foldClosedForm,
    loopClosedForm,
  )
where

import Control.Monad
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Simple (VarLookup)
import Futhark.Transform.Rename

{-
Motivation:

  let {*[int,x_size_27] map_computed_shape_1286} = replicate(x_size_27,
                                                             all_equal_shape_1044) in
  let {*[bool,x_size_27] map_size_checks_1292} = replicate(x_size_27, x_1291) in
  let {bool all_equal_checked_1298, int all_equal_shape_1299} =
    reduceT(fn {bool, int} (bool bacc_1293, int nacc_1294, bool belm_1295,
                            int nelm_1296) =>
              let {bool tuplit_elems_1297} = bacc_1293 && belm_1295 in
              {tuplit_elems_1297, nelm_1296},
            {True, 0}, map_size_checks_1292, map_computed_shape_1286)
-}

-- | @foldClosedForm look pat lam accs arrs@ tries to express a fold
-- whose single result is bound to @pat@ in closed form.
--
-- * @lam@ is the fold function.
-- * @accs@ are the initial accumulator values.
-- * @arrs@ are the input arrays.
-- * @look@ is used to discover how the input arrays were produced.
--
-- The generated code binds @pat@ to @accs@ when the input is empty, and
-- to the closed form otherwise.  The rule fails when @pat@ is not a
-- single primitive value, or when 'checkResults' finds no closed form.
foldClosedForm ::
  (ASTRep rep, BuilderOps rep) =>
  VarLookup rep ->
  Pat rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  RuleM rep ()
foldClosedForm look pat lam accs arrs = do
  arrTypes <- mapM lookupType arrs
  let inputsize = arraysSize 0 arrTypes

  resType <- case patTypes pat of
    [Prim pt] -> pure pt
    _ -> cannotSimplify

  let lamParamNames = map paramName $ lambdaParams lam
      known = determineKnownBindings look lam accs arrs
  closedBody <-
    checkResults
      (patNames pat)
      inputsize
      mempty
      Int64
      known
      lamParamNames
      (lambdaBody lam)
      accs

  -- An empty input leaves the accumulators untouched.
  isEmpty <- newVName "fold_input_is_empty"
  letBindNames [isEmpty] . BasicOp $
    CmpOp (CmpEq int64) inputsize (intConst Int64 0)

  whenEmpty <- resultBodyM accs
  whenNonEmpty <- renameBody closedBody
  letBind pat $
    If
      (Var isEmpty)
      whenEmpty
      whenNonEmpty
      (IfDec [primBodyType resType] IfNormal)

-- | @loopClosedForm pat merge i it bound body@ tries to express a
-- for-loop whose single result is bound to @pat@ in closed form.
--
-- * @merge@ pairs the loop parameters with their initial values.
-- * @bound@, of integer type @it@, is the iteration count.
-- * @body@ is the loop body.
-- * @i@ holds names that must never be treated as loop-invariant
--   (presumably the loop index; see the caller).
--
-- The generated code binds @pat@ to the initial merge values when the
-- loop performs no iterations (@bound <= 0@), and to the closed form
-- otherwise.  The rule fails when @pat@ is not a single primitive
-- value, or when 'checkResults' finds no closed form.
loopClosedForm ::
  (ASTRep rep, BuilderOps rep) =>
  Pat rep ->
  [(FParam rep, SubExp)] ->
  Names ->
  IntType ->
  SubExp ->
  Body rep ->
  RuleM rep ()
loopClosedForm pat merge i it bound body = do
  t <- case patTypes pat of
    [Prim t] -> return t
    _ -> cannotSimplify

  closedBody <-
    checkResults
      mergenames
      bound
      i
      it
      knownBnds
      mergenames
      body
      mergeexp
  -- The loop performs no iterations whenever bound <= 0, and then its
  -- result is simply the initial merge values.  A strict '<' would let
  -- bound == 0 reach the closed form, which is wrong for '&&'
  -- (acc && el differs from acc when el is false) and for float
  -- addition with a non-finite el (el * 0 is NaN).
  isEmpty <- newVName "bound_is_zero"
  letBindNames [isEmpty] $
    BasicOp $ CmpOp (CmpSle it) bound (intConst it 0)

  letBind pat
    =<< ( If (Var isEmpty)
            <$> resultBodyM mergeexp
            <*> renameBody closedBody
            <*> pure (IfDec [primBodyType t] IfNormal)
        )
  where
    (mergepat, mergeexp) = unzip merge
    mergenames = map paramName mergepat
    -- Each loop parameter is bound to its initial value.
    knownBnds = M.fromList $ zip mergenames mergeexp

-- | Build a body that computes every result of @body@ directly from the
-- initial accumulators @accs@ and the iteration count @size@, without
-- iterating.
--
-- Each name in @pat@ is paired with two things:
--
-- * the body result at the same position;
-- * one accumulator parameter (the first @length accs@ names of
--   @params@) together with its initial value.
--
-- A result is accepted only if it is a variable bound in @body@ by a
-- binary operation with distinct operands.  One operand must be that
-- accumulator parameter.  The other operand must be free: a constant, a
-- variable not bound in @body@, @params@ or @untouchable@, or a variable
-- whose value @knownBnds@ supplies.
--
-- The supported operators and their closed forms are:
--
-- * @&&@: @acc && el@
-- * integer @+@: @acc + el * size@
-- * float @+@: @acc + el * size@, with @size@ converted from type @it@
--
-- Any other shape makes the rule fail (via 'cannotSimplify' or
-- 'liftMaybe').  The returned body binds the names in @pat@ and returns
-- them.
checkResults ::
  BuilderOps rep =>
  [VName] ->
  SubExp ->
  Names ->
  IntType ->
  M.Map VName SubExp ->
  -- | Lambda-bound
  [VName] ->
  Body rep ->
  [SubExp] ->
  RuleM rep (Body rep)
checkResults pat size untouchable it knownBnds params body accs = do
  ((), stms) <-
    collectStms $ zipWithM_ closeResult (zip pat (bodyResult body)) accPairs
  mkBodyM stms $ varsRes pat
  where
    -- The leading parameters are the accumulators.
    accPairs = zip (take (length accs) params) accs

    bindings = makeBindMap body

    -- Names that may differ between iterations, unless 'knownBnds'
    -- provides a fixed value for them.
    varying = boundInBody body <> namesFromList params <> untouchable

    closeResult (p, SubExpRes _ (Var v)) (accparam, acc)
      | Just (BasicOp (BinOp bop x y)) <- M.lookup v bindings,
        x /= y = do
          (this, el) <- liftMaybe $ splitOperands accparam acc x y
          emitClosedForm p bop this el
    closeResult _ _ = cannotSimplify

    -- One operand must be *this* accumulator; the other must be free
    -- in the body.  Yields (initial accumulator, free operand).
    splitOperands accparam acc x y
      | y == Var accparam, Just free <- asFreeSubExp x = Just (acc, free)
      | x == Var accparam, Just free <- asFreeSubExp y = Just (acc, free)
      | otherwise = Nothing

    asFreeSubExp :: SubExp -> Maybe SubExp
    asFreeSubExp se = case se of
      Var v | v `nameIn` varying -> M.lookup v knownBnds
      _ -> Just se

    emitClosedForm p bop this el = case bop of
      LogAnd ->
        letBindNames [p] $ BasicOp $ BinOp LogAnd this el
      Add t w -> do
        size' <- asIntS t size
        letBindNames [p]
          =<< eBinOp
            (Add t w)
            (eSubExp this)
            (pure $ BasicOp $ BinOp (Mul t w) el size')
      FAdd t -> do
        size' <-
          letSubExp "converted_size" $
            BasicOp $ ConvOp (SIToFP it t) size
        letBindNames [p]
          =<< eBinOp
            (FAdd t)
            (eSubExp this)
            (pure $ BasicOp $ BinOp (FMul t) el size')
      _ -> cannotSimplify -- No closed form known for this operator.

-- | Find the fold-lambda parameters whose value is known before the fold
-- runs.
--
-- * Every accumulator parameter (the first @length accs@ parameters)
--   maps to its initial value in @accs@.
-- * An array-element parameter maps to @ve@ when @look@ reports that the
--   corresponding array in @arrs@ was bound by @Replicate _ ve@ with no
--   certificates.
--
-- When a name occurs in both groups, the accumulator binding wins
-- (left-biased map union).
determineKnownBindings ::
  VarLookup rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  M.Map VName SubExp
determineKnownBindings look lam accs arrs =
  M.fromList (zip (map paramName accparams) accs)
    <> M.fromList
      [ (paramName p, ve)
        | (p, arr) <- zip arrparams arrs,
          Just ve <- [replicatedValue arr]
      ]
  where
    (accparams, arrparams) = splitAt (length accs) (lambdaParams lam)

    replicatedValue v = case look v of
      Just (BasicOp (Replicate _ ve), cs) | cs == mempty -> Just ve
      _ -> Nothing

-- | Map each variable bound by a single-name statement in the body's
-- statement list to the expression that binds it.  Statements whose
-- pattern binds more or fewer than one name are skipped.
makeBindMap :: Body rep -> M.Map VName (Exp rep)
makeBindMap body =
  M.fromList
    [ (v, e)
      | Let stmPat _ e <- stmsToList (bodyStms body),
        [v] <- [patNames stmPat]
    ]