-- | Facilities for composing SOAC functions.  Mostly intended for use
-- by the fusion module, but factored into a separate module for ease
-- of testing, debugging and development.  Of course, there is nothing
-- preventing you from using the exported functions wherever you
-- want.
--
-- Important: this module is \"dumb\" in the sense that it does not
-- check the validity of its inputs, and does not have any
-- functionality for massaging SOACs to be fusible.  It is assumed
-- that the given SOACs are immediately compatible.
--
-- The module will, however, remove duplicate inputs after fusion.
module Futhark.Optimise.Fusion.Composing
  ( fuseMaps,
    fuseRedomap,
    mergeReduceOps,
  )
where

import Data.List (mapAccumL)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Builder (Buildable (..), insertStm, insertStms, mkLet)
import Futhark.Construct (mapResult)
import Futhark.IR
import Futhark.Util (dropLast, splitAt3, takeLast)

-- | @fuseMaps unfus_nms lam1 inp1 out1 lam2 inp2@ fuses the function
-- @lam1@ into @lam2@.  Both functions must be mapping functions,
-- although @lam2@ may have leading reduction parameters.  @inp1@ and
-- @inp2@ are the array inputs to the SOACs containing @lam1@ and
-- @lam2@ respectively.  @out1@ are the identifiers to which the output
-- of the SOAC containing @lam1@ is bound.  It is nonsensical to call
-- this function unless the intersection of @out1@ and @inp2@ is
-- non-empty.
--
-- If @lam2@ accepts more parameters than there are elements in
-- @inp2@, it is assumed that the surplus (which are positioned at the
-- beginning of the parameter list) are reduction (accumulator)
-- parameters, that do not correspond to array elements, and they are
-- thus not modified.
--
-- @unfus_nms@ names the outputs of the producer that must still be
-- returned after fusion; the element-level identifiers selected for
-- them are appended to the end of the fused body's result.  Note that
-- only the parameters and the body of @lam2@ are updated here: the
-- return type of the resulting lambda is left as that of @lam2@, so a
-- caller that passes a non-empty @unfus_nms@ has to extend the return
-- type itself.
--
-- The result is the fused function, and a list of the array inputs
-- expected by the SOAC containing the fused function.
fuseMaps ::
  Buildable rep =>
  -- | The producer var names that still need to be returned
  Names ->
  -- | Function of SOAC to be fused.
  Lambda rep ->
  -- | Input of SOAC to be fused.
  [SOAC.Input] ->
  -- | Output of SOAC to be fused.  The
  -- first identifier is the name of the
  -- actual output, where the second output
  -- is an identifier that can be used to
  -- bind a single element of that output.
  [(VName, Ident)] ->
  -- | Function to be fused with.
  Lambda rep ->
  -- | Input of SOAC to be fused with.
  [SOAC.Input] ->
  -- | The fused lambda and the inputs of
  -- the resulting SOAC.
  (Lambda rep, [SOAC.Input])
fuseMaps unfus_nms lam1 inp1 out1 lam2 inp2 = (lam2', M.elems inputmap)
  where
    -- The fused lambda: the reduction parameters of @lam2@ followed by
    -- one parameter per surviving input, in the key order of
    -- @inputmap@ (the same order in which 'M.elems' lists the inputs).
    lam2' =
      lam2
        { lambdaParams =
            [ Param mempty name t
              | Ident name t <- lam2redparams ++ M.keys inputmap
            ],
          lambdaBody = new_body2'
        }
    -- The body of @lam1@ with its result replaced by: bindings of the
    -- producer results to the identifiers in @pat@ (keeping the
    -- certificates of each result), followed by the body of @lam2@.
    new_body2 =
      let stms res =
            [ certify cs $ mkLet [p] $ BasicOp $ SubExp e
              | (p, SubExpRes cs e) <- zip pat res
            ]
          bindLambda res =
            stmsFromList (stms res) `insertStms` makeCopiesInner (lambdaBody lam2)
       in makeCopies $ mapResult bindLambda (lambdaBody lam1)
    new_body2_rses = bodyResult new_body2
    new_body2' =
      new_body2 {bodyResult = new_body2_rses ++ map (varRes . identName) unfus_pat}
    -- infusible variables are added at the end of the result/pattern/type
    (lam2redparams, unfus_pat, pat, inputmap, makeCopies, makeCopiesInner) =
      fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2

-- (unfus_accpat, unfus_arrpat) = splitAt (length unfus_accs) unfus_pat

-- | Work out how the inputs of a producer lambda (@lam1@) and a
-- consumer lambda (@lam2@) are combined when the two are fused.
--
-- The arguments are the same as for 'fuseMaps'.  The result is a tuple
-- of, in order:
--
--   1. the leading parameters of @lam2@ that have no corresponding
--      entry in @inp2@ (the reduction/accumulator parameters);
--
--   2. for each name in @out1@ that is also in @unfus_nms@, the
--      identifier chosen to stand for it in the fused body;
--
--   3. for each element of @out1@, the identifier that the
--      corresponding producer result should be bound to (see
--      'filterOutParams');
--
--   4. the parameter-to-input map of the fused SOAC, i.e. the inputs of
--      both lambdas, minus those consumer inputs that are outputs of
--      the producer, with duplicates removed;
--
--   5. a body transformation re-binding the parameters dropped as
--      duplicates from (4);
--
--   6. a body transformation re-binding the parameters of @lam2@ that
--      were dropped because they duplicated another input of @lam2@.
fuseInputs ::
  Buildable rep =>
  Names ->
  Lambda rep ->
  [SOAC.Input] ->
  [(VName, Ident)] ->
  Lambda rep ->
  [SOAC.Input] ->
  ( [Ident],
    [Ident],
    [Ident],
    M.Map Ident SOAC.Input,
    Body rep -> Body rep,
    Body rep -> Body rep
  )
fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2 =
  (lam2redparams, unfus_vars, outstms, inputmap, makeCopies, makeCopiesInner)
  where
    -- Any surplus of parameters over inputs sits at the front of the
    -- parameter list of @lam2@.
    (lam2redparams, lam2arrparams) =
      splitAt (length lam2params - length inp2) lam2params
    lam1params = map paramIdent $ lambdaParams lam1
    lam2params = map paramIdent $ lambdaParams lam2
    lam1inputmap = M.fromList $ zip lam1params inp1
    lam2inputmap = M.fromList $ zip lam2arrparams inp2
    (lam2inputmap', makeCopiesInner) = removeDuplicateInputs lam2inputmap
    -- 'M.union' is left-biased, so on a key clash the entry from
    -- @lam1@ wins.
    originputmap = lam1inputmap `M.union` lam2inputmap'
    -- The consumer parameters that are fed by an output of the
    -- producer.
    outins =
      uncurry (outParams $ map fst out1) $
        unzip $ M.toList lam2inputmap'
    outstms = filterOutParams out1 outins
    (inputmap, makeCopies) =
      removeDuplicateInputs $ originputmap `M.difference` outins
    -- Cosmin: @unfus_vars@ is supposed to be the lam2 vars corresponding to unfus_nms (?)
    --
    -- Entries of @outins@ are expected to be plain variable inputs, so
    -- the 'Nothing' case should not be reached.
    getVarParPair (p, inp) = (\nm -> (nm, p)) <$> SOAC.isVarInput inp
    outinsrev = M.fromList $ mapMaybe getVarParPair $ M.toList outins
    -- Lookup table from producer output name to identifier.  The
    -- consumer parameter (if any) takes precedence over the identifier
    -- supplied in @out1@, again because 'M.union' is left-biased.
    unfus_lookup = M.union outinsrev (M.fromList out1)
    unfusible outname
      | outname `nameIn` unfus_nms = outname `M.lookup` unfus_lookup
    unfusible _ = Nothing
    unfus_vars = mapMaybe (unfusible . fst) out1

-- | @outParams out1 lam2arrparams inp2@ pairs each parameter in
-- @lam2arrparams@ with the input at the same position in @inp2@, and
-- keeps only those pairs whose input is a plain variable (according to
-- 'SOAC.isVarInput') that is an element of @out1@.  The surviving
-- pairs are returned as a map from parameter to input.
outParams ::
  [VName] ->
  [Ident] ->
  [SOAC.Input] ->
  M.Map Ident SOAC.Input
outParams out1 lam2arrparams inp2 =
  M.fromList $ mapMaybe isOutParam $ zip lam2arrparams inp2
  where
    isOutParam (p, inp)
      | Just a <- SOAC.isVarInput inp,
        a `elem` out1 =
        Just (p, inp)
    isOutParam _ = Nothing

-- | @filterOutParams out1 outins@ produces one identifier for each
-- element of @out1@, in order.  For an element @(name, fallback)@, the
-- result is a parameter from @outins@ whose input is the variable
-- @name@, if there is one that has not already been handed out for an
-- earlier element with the same name; otherwise it is @fallback@.
filterOutParams ::
  [(VName, Ident)] ->
  M.Map Ident SOAC.Input ->
  [Ident]
filterOutParams out1 outins =
  snd $ mapAccumL checkUsed outUsage out1
  where
    -- For each variable used as an input in @outins@, the parameters
    -- that are fed by it.  Inputs that are not plain variables are
    -- ignored.
    outUsage = M.foldlWithKey' add M.empty outins
      where
        add m p inp =
          case SOAC.isVarInput inp of
            Just v -> M.insertWith (++) v [p] m
            Nothing -> m

    -- Hand out the next unused parameter for @a@, if any, and remove
    -- it from the table so that it is not used twice.
    checkUsed m (a, ra) =
      case M.lookup a m of
        Just (p : ps) -> (M.insert a ps m, p)
        _ -> (m, ra)

-- | Remove parameters that are bound to an input that some other
-- parameter is already bound to.
--
-- The map is traversed in key order.  The first parameter encountered
-- for a given input is kept.  Every later parameter with an equal
-- input is left out of the resulting map, and is instead re-introduced
-- by the returned body transformation, which inserts a statement
-- binding the dropped parameter to the name of the parameter that was
-- kept.
removeDuplicateInputs ::
  Buildable rep =>
  M.Map Ident SOAC.Input ->
  (M.Map Ident SOAC.Input, Body rep -> Body rep)
removeDuplicateInputs = fst . M.foldlWithKey' comb ((M.empty, id), M.empty)
  where
    -- The accumulator is the result under construction, paired with a
    -- map from each input seen so far to the name of the parameter
    -- that was kept for it.
    comb ((parmap, inner), arrmap) par arr =
      case M.lookup arr arrmap of
        Nothing ->
          ( (M.insert par arr parmap, inner),
            M.insert arr (identName par) arrmap
          )
        Just par' ->
          ( (parmap, inner . forward par par'),
            arrmap
          )
    -- Insert the binding @to = from@ into the body.
    forward to from b =
      mkLet [to] (BasicOp $ SubExp $ Var from)
        `insertStm` b

-- | Fuse a producer whose lambda @p_lam@ returns scan results, then
-- reduction results, then map results, into a consumer with lambda
-- @c_lam@ of the same shape.  The arguments are, in order:
--
--   * the names of producer outputs that must still be returned;
--
--   * the names of all outputs of the producer;
--
--   * the producer lambda, its scan neutral elements, its reduction
--     neutral elements and its array inputs;
--
--   * the producer outputs paired with element-level identifiers, as
--     for 'fuseMaps';
--
--   * the consumer lambda, its scan neutral elements, its reduction
--     neutral elements and its array inputs.
--
-- Only the lengths of the neutral element lists are used.  The result
-- is the fused lambda and the array inputs of the fused SOAC.  The
-- results (and return types) of the fused lambda are ordered as:
-- producer scan, consumer scan, producer reduction, consumer
-- reduction, then the map results, with the return types of the
-- unfused producer arrays at the very end.
fuseRedomap ::
  Buildable rep =>
  Names ->
  [VName] ->
  Lambda rep ->
  [SubExp] ->
  [SubExp] ->
  [SOAC.Input] ->
  [(VName, Ident)] ->
  Lambda rep ->
  [SubExp] ->
  [SubExp] ->
  [SOAC.Input] ->
  (Lambda rep, [SOAC.Input])
fuseRedomap
  unfus_nms
  outVars
  p_lam
  p_scan_nes
  p_red_nes
  p_inparr
  outPairs
  c_lam
  c_scan_nes
  c_red_nes
  c_inparr =
    -- We hack the implementation of map o redomap to handle this case:
    --   (i) we remove the accumulator formal parameter and corresponding
    --       (body) result from redomap's fold-lambda body
    let p_num_nes = length p_scan_nes + length p_red_nes
        unfus_arrs = filter (`nameIn` unfus_nms) outVars
        p_lam_body = lambdaBody p_lam
        (p_lam_scan_ts, p_lam_red_ts, p_lam_map_ts) =
          splitAt3 (length p_scan_nes) (length p_red_nes) $ lambdaReturnType p_lam
        (p_lam_scan_res, p_lam_red_res, p_lam_map_res) =
          splitAt3 (length p_scan_nes) (length p_red_nes) $ bodyResult p_lam_body
        p_lam_hacked =
          p_lam
            { lambdaParams = takeLast (length p_inparr) $ lambdaParams p_lam,
              lambdaBody = p_lam_body {bodyResult = p_lam_map_res},
              lambdaReturnType = p_lam_map_ts
            }

        --  (ii) we remove the accumulator's (global) output result from
        --       @outPairs@, then ``map o redomap'' fuse the two lambdas
        --       (in the usual way), and construct the extra return types
        --       for the arrays that fall through.
        (res_lam, new_inp) =
          fuseMaps
            (namesFromList unfus_arrs)
            p_lam_hacked
            p_inparr
            (drop p_num_nes outPairs)
            c_lam
            c_inparr
        (res_lam_scan_ts, res_lam_red_ts, res_lam_map_ts) =
          splitAt3 (length c_scan_nes) (length c_red_nes) $ lambdaReturnType res_lam
        -- The return types of those map outputs of the producer that
        -- remain unfused, in the order of @outVars@.
        extra_map_ts =
          [ t
            | (nm, t) <-
                zip (drop p_num_nes outVars) $
                  drop p_num_nes $
                    lambdaReturnType p_lam,
              nm `elem` unfus_arrs
          ]

        -- (iii) Finally, we put back the accumulator's formal parameter and
        --       (body) result in the first position of the obtained lambda.
        accpars = dropLast (length p_inparr) $ lambdaParams p_lam
        res_body = lambdaBody res_lam
        (res_lam_scan_res, res_lam_red_res, res_lam_map_res) =
          splitAt3 (length c_scan_nes) (length c_red_nes) $ bodyResult res_body
        res_body' =
          res_body
            { bodyResult =
                p_lam_scan_res
                  ++ res_lam_scan_res
                  ++ p_lam_red_res
                  ++ res_lam_red_res
                  ++ res_lam_map_res
            }
        res_lam' =
          res_lam
            { lambdaParams = accpars ++ lambdaParams res_lam,
              lambdaBody = res_body',
              lambdaReturnType =
                p_lam_scan_ts
                  ++ res_lam_scan_ts
                  ++ p_lam_red_ts
                  ++ res_lam_red_ts
                  ++ res_lam_map_ts
                  ++ extra_map_ts
            }
     in (res_lam', new_inp)

-- | Combine two lambdas into a single lambda that computes the results
-- of both.
--
-- The statements of the second body are appended to those of the
-- first, and likewise for the results and the return types.  With
-- @len1@ and @len2@ being the number of return types of the first and
-- second lambda respectively, the parameters of the merged lambda are
-- the first @len1@ parameters of the first lambda, the first @len2@
-- parameters of the second, the remaining parameters of the first, and
-- finally the remaining parameters of the second.
--
-- The body decoration of the merged body is taken from the first
-- lambda; that of the second lambda is discarded.
mergeReduceOps :: Lambda rep -> Lambda rep -> Lambda rep
mergeReduceOps (Lambda par1 bdy1 rtp1) (Lambda par2 bdy2 rtp2) =
  let body' =
        Body
          (bodyDec bdy1)
          (bodyStms bdy1 <> bodyStms bdy2)
          (bodyResult bdy1 ++ bodyResult bdy2)
      (len1, len2) = (length rtp1, length rtp2)
      par' = take len1 par1 ++ take len2 par2 ++ drop len1 par1 ++ drop len2 par2
   in Lambda par' body' (rtp1 ++ rtp2)