{-# LANGUAGE TypeFamilies #-}
module Futhark.Tools
( module Futhark.Construct,
redomapToMapAndReduce,
dissectScrema,
sequentialStreamWholeArray,
partitionChunkedFoldParameters,
module Futhark.Analysis.PrimExp.Convert,
)
where
import Control.Monad.Identity
import Futhark.Analysis.PrimExp.Convert
import Futhark.Construct
import Futhark.IR
import Futhark.IR.SOACS.SOAC
import Futhark.Util
-- | Turn a binding of a @redomap@ into two separate statements: a
-- @map@ statement and a @reduce@ statement, returned in that order.
--
-- The pattern elements that held the reduction results are reused
-- as the pattern of the @reduce@. The @map@ binds fresh names for the
-- per-element values that feed the reduction, and reuses the
-- original pattern elements for any mapped-out arrays.
redomapToMapAndReduce ::
  ( MonadFreshNames m,
    Buildable rep,
    ExpDec rep ~ (),
    Op rep ~ SOAC rep
  ) =>
  Pat (LetDec rep) ->
  ( SubExp,
    [Reduce rep],
    Lambda rep,
    [VName]
  ) ->
  m (Stm rep, Stm rep)
redomapToMapAndReduce (Pat pes) (w, reds, map_lam, arrs) = do
  (map_pat, red_pat, red_arrs) <-
    splitScanOrRedomap pes w map_lam $ map redNeutral reds
  let map_stm = mkLet map_pat $ Op $ Screma w arrs (mapSOAC map_lam)
  -- The reduction consumes the fresh arrays produced by the map.
  red_stm <-
    Let red_pat (defAux ()) . Op
      <$> (Screma w red_arrs <$> reduceSOAC reds)
  pure (map_stm, red_stm)
-- | Split the pattern of a scan/redomap-like construct into a pattern
-- for a standalone @map@ and a pattern for the accumulator results.
--
-- The first @length (concat nes)@ pattern elements are treated as
-- accumulator results. For each one, a fresh identifier is created. It
-- is named after the original element with a @_map_acc@ suffix and has
-- the lambda's corresponding return type lifted to an array of outer
-- size @w@. The remaining pattern elements are reused unchanged.
--
-- Returns the full @map@ pattern (fresh accumulator identifiers
-- followed by the reused array elements), the accumulator pattern made
-- of the original elements, and the names of the fresh identifiers.
splitScanOrRedomap ::
  (Typed dec, MonadFreshNames m) =>
  [PatElem dec] ->
  SubExp ->
  Lambda rep ->
  [[SubExp]] ->
  m ([Ident], Pat dec, [VName])
splitScanOrRedomap pes w map_lam nes = do
  let num_accs = length $ concat nes
      (acc_pes, arr_pes) = splitAt num_accs pes
      (acc_ts, _arr_ts) = splitAt num_accs $ lambdaReturnType map_lam
  map_accpat <- zipWithM accMapPatElem acc_pes acc_ts
  map_arrpat <- mapM arrMapPatElem arr_pes
  let map_pat = map_accpat ++ map_arrpat
  pure (map_pat, Pat acc_pes, map identName map_accpat)
  where
    -- Fresh array-typed identifier to hold the per-element accumulator
    -- contributions produced by the map.
    accMapPatElem pe acc_t =
      newIdent (baseString (patElemName pe) ++ "_map_acc") $
        acc_t `arrayOfRow` w
    -- Mapped-out arrays keep their original names and types.
    arrMapPatElem = pure . patElemIdent
-- | Split a @Screma@ into two statements.
--
-- The first is a scanomap: it computes the scan results, the map-out
-- results, and a set of fresh intermediate @to_red@ arrays holding the
-- values to be reduced. The second is a separate reduction @Screma@
-- over those intermediate arrays, bound to the original reduction
-- results.
--
-- The pattern is assumed to list scan results first, then reduction
-- results, then map results. This is the order used by 'splitAt3'
-- below.
dissectScrema ::
  ( MonadBuilder m,
    Op (Rep m) ~ SOAC (Rep m),
    Buildable (Rep m)
  ) =>
  Pat (LetDec (Rep m)) ->
  SubExp ->
  ScremaForm (Rep m) ->
  [VName] ->
  m ()
dissectScrema pat w (ScremaForm scans reds map_lam) arrs = do
  let num_reds = redResults reds
      num_scans = scanResults scans
      (scan_res, red_res, map_res) =
        splitAt3 num_scans num_reds $ patNames pat

  -- Fresh names for the arrays that carry values into the reduction.
  to_red <- replicateM num_reds $ newVName "to_red"

  let scanomap = scanomapSOAC scans map_lam
  letBindNames (scan_res <> to_red <> map_res) $
    Op $
      Screma w arrs scanomap

  reduce <- reduceSOAC reds
  letBindNames red_res $ Op $ Screma w to_red reduce
-- | Turn a stream into statements that apply the stream lambda to the
-- entire input as a single chunk of size @w@.
--
-- The lambda's parameters are bound directly. The chunk size becomes
-- @w@, the accumulators become @nes@, and the chunk parameters become
-- the input arrays (coerced to the parameter shapes). The lambda body
-- is then inlined, and its results are bound to @pat@.
sequentialStreamWholeArray ::
  (MonadBuilder m, Buildable (Rep m)) =>
  Pat (LetDec (Rep m)) ->
  SubExp ->
  [SubExp] ->
  Lambda (Rep m) ->
  [VName] ->
  m ()
sequentialStreamWholeArray pat w nes lam arrs = do
  let (chunk_size_param, fold_params, arr_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

  -- The chunk size is the full size of the input.
  letBindNames [paramName chunk_size_param] $ BasicOp $ SubExp w

  -- Accumulator parameters start at the neutral elements.
  forM_ (zip fold_params nes) $ \(p, ne) ->
    letBindNames [paramName p] $ BasicOp $ SubExp ne

  -- Chunk parameters are the whole input arrays. They are coerced to
  -- the parameter shapes so the types line up.
  forM_ (zip arr_params arrs) $ \(p, arr) ->
    letBindNames [paramName p] . BasicOp $
      Reshape ReshapeCoerce (arrayShape $ paramType p) arr

  -- Inline the lambda body.
  mapM_ addStm $ bodyStms $ lambdaBody lam

  -- Bind the body results to the pattern, pairwise. Array-typed
  -- results that are variables are coerced to the pattern's dimensions.
  -- All other results are bound directly. Each binding is made under
  -- the result's certificates.
  forM_ (zip (patElems pat) $ bodyResult $ lambdaBody lam) $
    \(pe, SubExpRes cs se) ->
      certifying cs $
        case (arrayDims $ patElemType pe, se) of
          (dims, Var v)
            | not $ null dims ->
                letBindNames [patElemName pe] . BasicOp $
                  Reshape ReshapeCoerce (Shape dims) v
          _ ->
            letBindNames [patElemName pe] $ BasicOp $ SubExp se
-- | Split the parameters of a chunked fold/stream lambda into three
-- parts: the chunk-size parameter (the first one), the accumulator
-- parameters (the next @num_accs@), and the remaining chunk
-- parameters.
--
-- Calls 'error' if the parameter list is empty, because the lambda
-- must have at least a chunk-size parameter. If fewer than @num_accs@
-- parameters follow the first one, they are all returned as
-- accumulators (this is 'splitAt' behaviour).
partitionChunkedFoldParameters ::
  Int ->
  [Param dec] ->
  (Param dec, [Param dec], [Param dec])
partitionChunkedFoldParameters _ [] =
  error "partitionChunkedFoldParameters: lambda takes no parameters"
partitionChunkedFoldParameters num_accs (chunk_param : params) =
  let (acc_params, arr_params) = splitAt num_accs params
   in (chunk_param, acc_params, arr_params)