-- | This module implements facilities for determining whether a
-- reduction or fold can be expressed in a closed form (i.e. not as a
-- SOAC).
--
-- Right now, the module can detect only trivial cases.  In the
-- future, we would like to make it more powerful, and possibly also
-- able to analyse sequential loops.
module Futhark.Optimise.Simplify.Rules.ClosedForm
  ( foldClosedForm,
    loopClosedForm,
  )
where

import Control.Monad
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Simple (VarLookup)
import Futhark.Transform.Rename

{-
Motivation:

  let {*[int,x_size_27] map_computed_shape_1286} = replicate(x_size_27,
                                                             all_equal_shape_1044) in
  let {*[bool,x_size_27] map_size_checks_1292} = replicate(x_size_27, x_1291) in
  let {bool all_equal_checked_1298, int all_equal_shape_1299} =
    reduceT(fn {bool, int} (bool bacc_1293, int nacc_1294, bool belm_1295,
                            int nelm_1296) =>
              let {bool tuplit_elems_1297} = bacc_1293 && belm_1295 in
              {tuplit_elems_1297, nelm_1296},
            {True, 0}, map_size_checks_1292, map_computed_shape_1286)
-}

-- | @foldClosedForm look pat foldfun accargs arrargs@ tries to replace
-- a fold of @foldfun@ over the arrays @arrargs@, with initial
-- accumulators @accargs@, by a closed-form expression bound to @pat@.
--
-- The rule only applies when @pat@ binds exactly one value of
-- primitive type; otherwise it fails with 'cannotSimplify'.  It also
-- fails whenever 'checkResults' cannot find a closed form for the
-- lambda body.  On success, @pat@ is bound to a 'Match' that yields
-- @accargs@ unchanged when the input is empty and the closed-form body
-- otherwise.
foldClosedForm ::
  (BuilderOps rep) =>
  VarLookup rep ->
  Pat (LetDec rep) ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  RuleM rep ()
foldClosedForm look pat lam accs arrs = do
  -- Size of dimension 0 of the input arrays.
  inputsize <- arraysSize 0 <$> mapM lookupType arrs

  -- Only a single primitive-typed result is supported.
  t <- case patTypes pat of
    [Prim t] -> pure t
    _ -> cannotSimplify

  closedBody <-
    checkResults
      (patNames pat)
      inputsize
      mempty
      Int64
      knownBnds
      (map paramName (lambdaParams lam))
      (lambdaBody lam)
      accs

  -- The closed form is only used for non-empty input; for empty input
  -- the result is just the initial accumulator.
  isEmpty <- newVName "fold_input_is_empty"
  letBindNames [isEmpty] $
    BasicOp $
      CmpOp (CmpEq int64) inputsize (intConst Int64 0)
  letBind pat
    =<< ( Match [Var isEmpty]
            <$> (pure . Case [Just $ BoolValue True] <$> resultBodyM accs)
            <*> renameBody closedBody
            <*> pure (MatchDec [primBodyType t] MatchNormal)
        )
  where
    -- Lambda parameters whose values are known outside the lambda.
    knownBnds = determineKnownBindings look lam accs arrs

-- | @loopClosedForm pat merge i it bound body@ determines whether the
-- do-loop can be expressed in a closed form, and if so binds @pat@ to
-- that closed form.
--
-- @merge@ pairs each loop parameter with its initial value, @i@ is a
-- set of names that the closed form must not depend on (it is passed
-- to 'checkResults' as the untouchable names), @it@ is the integer
-- type of @bound@, and @body@ is the loop body.
--
-- The rule only applies when @pat@ binds exactly one value of
-- primitive type; otherwise it fails with 'cannotSimplify'.  On
-- success, @pat@ is bound to a 'Match' that yields the initial merge
-- values when @bound@ is negative and the closed-form body otherwise.
loopClosedForm ::
  (BuilderOps rep) =>
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  Names ->
  IntType ->
  SubExp ->
  Body rep ->
  RuleM rep ()
loopClosedForm pat merge i it bound body = do
  -- Only a single primitive-typed result is supported.
  t <- case patTypes pat of
    [Prim t] -> pure t
    _ -> cannotSimplify

  closedBody <-
    checkResults
      mergenames
      bound
      i
      it
      knownBnds
      (map identName mergeidents)
      body
      mergeexp

  -- NOTE(review): the name hint says "is zero", but the comparison is
  -- a signed @bound < 0@.  The hint is kept as-is because it is a
  -- runtime string.
  isEmpty <- newVName "bound_is_zero"
  letBindNames [isEmpty] $
    BasicOp $
      CmpOp (CmpSlt it) bound (intConst it 0)

  letBind pat
    =<< ( Match [Var isEmpty]
            <$> (pure . Case [Just (BoolValue True)] <$> resultBodyM mergeexp)
            <*> renameBody closedBody
            <*> pure (MatchDec [primBodyType t] MatchNormal)
        )
  where
    (mergepat, mergeexp) = unzip merge
    mergeidents = map paramIdent mergepat
    mergenames = map paramName mergepat
    -- Each loop parameter is known to start out as its initial value.
    knownBnds = M.fromList $ zip mergenames mergeexp

-- | @checkResults pat size untouchable it knownBnds params body accs@
-- tries to build a body computing, in closed form, the result of
-- applying @body@ @size@ times starting from the accumulators @accs@.
--
-- Each result of @body@ is paired positionally with a name from @pat@
-- and with an accumulator parameter (the first @length accs@ elements
-- of @params@).  A result is accepted only if it is a variable bound
-- in @body@ by a singleton statement of the form @x `bop` y@, where
-- @x@ and @y@ are distinct, one of them is the corresponding
-- accumulator parameter, and the other is free in the body (or has a
-- known outside value in @knownBnds@).  The supported operators are
-- 'LogAnd', 'Add' and 'FAdd'.  Anything else fails with
-- 'cannotSimplify'.
--
-- The returned body binds the names in @pat@ and returns them.  @it@
-- is the integer type of @size@, used when converting @size@ to a
-- floating-point type.
checkResults ::
  BuilderOps rep =>
  [VName] ->
  SubExp ->
  Names ->
  IntType ->
  M.Map VName SubExp ->
  -- | Lambda-bound
  [VName] ->
  Body rep ->
  [SubExp] ->
  RuleM rep (Body rep)
checkResults pat size untouchable it knownBnds params body accs = do
  ((), stms) <-
    collectStms $
      zipWithM_ checkResult (zip pat res) (zip accparams accs)
  mkBodyM stms $ varsRes pat
  where
    -- Expressions of the body, keyed by the single name they bind.
    stmMap = makeBindMap body
    -- The accumulator parameters come first among the parameters.
    accparams = take (length accs) params
    res = bodyResult body

    -- Names that cannot be referenced directly from the closed form.
    nonFree = boundInBody body <> namesFromList params <> untouchable

    checkResult (p, SubExpRes _ (Var v)) (accparam, acc)
      | Just (BasicOp (BinOp bop x y)) <- M.lookup v stmMap,
        x /= y = do
          -- One of x,y must be *this* accumulator, and the other must
          -- be something that is free in the body.
          let isThisAccum = (== Var accparam)
          (this, el) <- liftMaybe $
            case ( (asFreeSubExp x, isThisAccum y),
                   (asFreeSubExp y, isThisAccum x)
                 ) of
              ((Just free, True), _) -> Just (acc, free)
              (_, (Just free, True)) -> Just (acc, free)
              _ -> Nothing

          case bop of
            -- acc && el
            LogAnd ->
              letBindNames [p] $ BasicOp $ BinOp LogAnd this el
            -- acc + el * size, with size converted to type t
            Add t w -> do
              size' <- asIntS t size
              letBindNames [p]
                =<< eBinOp
                  (Add t w)
                  (eSubExp this)
                  (pure $ BasicOp $ BinOp (Mul t w) el size')
            -- acc + el * size, with size converted to a float of type t
            FAdd t | Just properly_typed_size <- properFloatSize t -> do
              size' <- properly_typed_size
              letBindNames [p]
                =<< eBinOp
                  (FAdd t)
                  (eSubExp this)
                  (pure $ BasicOp $ BinOp (FMul t) el size')
            _ -> cannotSimplify -- Um... sorry.
    checkResult _ _ = cannotSimplify

    -- A variable that is not free in the body is usable only if its
    -- value is known outside the body.  Everything else is used as-is.
    asFreeSubExp :: SubExp -> Maybe SubExp
    asFreeSubExp (Var v)
      | v `nameIn` nonFree = M.lookup v knownBnds
    asFreeSubExp se = Just se

    -- Action that binds @size@ converted (signed) to float type @t@.
    -- Currently always returns 'Just'.
    properFloatSize t =
      Just $
        letSubExp "converted_size" $
          BasicOp $
            ConvOp (SIToFP it t) size

-- | @determineKnownBindings look lam accs arrs@ computes the lambda
-- parameters of @lam@ whose values are known outside the lambda.
--
-- The first @length accs@ parameters are the accumulator parameters;
-- each is mapped to the corresponding initial value in @accs@.  The
-- remaining parameters are paired positionally with @arrs@; such a
-- parameter is included only if @look@ shows that its array is
-- produced by a 'Replicate' with no certificates, in which case it is
-- mapped to the replicated value.
determineKnownBindings ::
  VarLookup rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  M.Map VName SubExp
determineKnownBindings look lam accs arrs =
  accBnds <> arrBnds
  where
    (accparams, arrparams) =
      splitAt (length accs) $ lambdaParams lam
    accBnds =
      M.fromList $
        zip (map paramName accparams) accs
    arrBnds =
      M.fromList $
        mapMaybe isReplicate $
          zip (map paramName arrparams) arrs

    -- Keep the pairing only if the array is an uncertified replicate,
    -- swapping the array name for the replicated element.
    isReplicate (p, v)
      | Just (BasicOp (Replicate _ ve), cs) <- look v,
        cs == mempty =
          Just (p, ve)
    isReplicate _ = Nothing

-- | Map each name that is bound alone by a statement of the body to
-- the expression of that statement.  Statements whose pattern does
-- not bind exactly one name are ignored.
makeBindMap :: Body rep -> M.Map VName (Exp rep)
makeBindMap = M.fromList . mapMaybe isSingletonStm . stmsToList . bodyStms
  where
    isSingletonStm (Let pat _ e) = case patNames pat of
      [v] -> Just (v, e)
      _ -> Nothing