{-# LANGUAGE TypeFamilies #-}

-- | High-level representation of SOACs.  When performing
-- SOAC-transformations, operating on normal 'Exp' values is somewhat
-- of a nuisance, as they can represent terms that are not proper
-- SOACs.  In contrast, this module exposes a SOAC representation that
-- does not enable invalid representations (except for type errors).
--
-- Furthermore, while standard normalised Futhark requires that the inputs
-- to a SOAC are variables or constants, the representation in this
-- module also supports various index-space transformations, like
-- @replicate@ or @rearrange@.  This is also very convenient when
-- implementing transformations.
--
-- The names exported by this module conflict with the standard Futhark
-- syntax tree constructors, so you are advised to use a qualified
-- import:
--
-- @
-- import Futhark.Analysis.HORep.SOAC (SOAC)
-- import qualified Futhark.Analysis.HORep.SOAC as SOAC
-- @
module Futhark.Analysis.HORep.SOAC
  ( -- * SOACs
    SOAC (..),
    Futhark.ScremaForm (..),
    inputs,
    setInputs,
    lambda,
    setLambda,
    typeOf,
    width,

    -- ** Converting to and from expressions
    NotSOAC (..),
    fromExp,
    toExp,
    toSOAC,

    -- * SOAC inputs
    Input (..),
    varInput,
    inputTransforms,
    identInput,
    isVarInput,
    isVarishInput,
    addTransform,
    addInitialTransforms,
    inputArray,
    inputRank,
    inputType,
    inputRowType,
    transformRows,
    transposeInput,
    applyTransforms,

    -- ** Input transformations
    ArrayTransforms,
    noTransforms,
    nullTransforms,
    (|>),
    (<|),
    viewf,
    ViewF (..),
    viewl,
    ViewL (..),
    ArrayTransform (..),
    transformFromExp,
    transformToExp,
    soacToStream,
  )
where

import Data.Foldable as Foldable
import Data.Maybe
import Data.Sequence qualified as Seq
import Futhark.Construct hiding (toExp)
import Futhark.IR hiding
  ( Iota,
    Rearrange,
    Replicate,
    Reshape,
    typeOf,
  )
import Futhark.IR qualified as Futhark
import Futhark.IR.SOACS.SOAC
  ( HistOp (..),
    ScremaForm (..),
    scremaType,
  )
import Futhark.IR.SOACS.SOAC qualified as Futhark
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util.Pretty (pretty)
import Futhark.Util.Pretty qualified as PP

-- | A single, simple transformation.  If you want several, don't just
-- create a list, use 'ArrayTransforms' instead, which keeps the
-- sequence normalised as it grows.
data ArrayTransform
  = -- | A permutation of an otherwise valid input.
    Rearrange Certs [Int]
  | -- | A reshaping of an otherwise valid input.
    Reshape Certs ReshapeKind Shape
  | -- | A reshaping of the outer dimension.
    ReshapeOuter Certs ReshapeKind Shape
  | -- | A reshaping of everything but the outer dimension.
    ReshapeInner Certs ReshapeKind Shape
  | -- | Replicate the whole array, adding the given shape as new
    -- outer dimensions (corresponds to a @replicate@ of the array).
    Replicate Certs Shape
  deriving (Show, Eq, Ord)

-- | Renaming rewrites the certificates and any shape of a transform;
-- permutations and reshape kinds are left untouched.
instance Substitute ArrayTransform where
  substituteNames substs tr =
    case tr of
      Rearrange cs perm -> Rearrange (sub cs) perm
      Reshape cs k shape -> Reshape (sub cs) k (sub shape)
      ReshapeOuter cs k shape -> ReshapeOuter (sub cs) k (sub shape)
      ReshapeInner cs k shape -> ReshapeInner (sub cs) k (sub shape)
      Replicate cs shape -> Replicate (sub cs) (sub shape)
    where
      sub :: (Substitute a) => a -> a
      sub = substituteNames substs

-- | A sequence of array transformations, heavily inspired by
-- "Data.Sequence".  Decompose it with 'viewf' and 'viewl', and extend
-- it with '|>' and '<|'.  These behave much like their sequence
-- counterparts, except that adding an element also tries to normalise
-- and simplify the transformation sequence.
--
-- The type is opaque so that its normalisation invariants can be
-- enforced: whenever the sequence grows, neighbouring permutations are
-- composed, and a composition that turns out to be the identity is
-- removed altogether.
newtype ArrayTransforms = ArrayTransforms (Seq.Seq ArrayTransform)
  deriving (Eq, Ord, Show)

-- | Appending pushes every transform of the right operand, first to
-- last, onto the left operand with '|>', so that transforms meeting at
-- the seam get a chance to be combined.
instance Semigroup ArrayTransforms where
  front <> ArrayTransforms back = Foldable.foldl (|>) front back

-- | The identity is the empty transformation sequence.
instance Monoid ArrayTransforms where
  mempty = ArrayTransforms mempty

-- | Substitution is applied to each transformation of the sequence in
-- place; the sequence itself is not re-normalised.
instance Substitute ArrayTransforms where
  substituteNames substs (ArrayTransforms trs) =
    ArrayTransforms (fmap (substituteNames substs) trs)

-- | The empty transformation list.
noTransforms :: ArrayTransforms
noTransforms = ArrayTransforms mempty

-- | Does the transformation list contain no transformations at all?
nullTransforms :: ArrayTransforms -> Bool
nullTransforms (ArrayTransforms trs) = Foldable.null trs

-- | Decompose the input-end of the transformation sequence, i.e. split
-- off the transformation that is applied first.
viewf :: ArrayTransforms -> ViewF
viewf (ArrayTransforms trs) =
  case Seq.viewl trs of
    Seq.EmptyL -> EmptyF
    first Seq.:< rest -> first :< ArrayTransforms rest

-- | A view of the first transformation to be applied, as produced by
-- 'viewf'.
data ViewF
  = EmptyF -- ^ The sequence holds no transformations.
  | ArrayTransform :< ArrayTransforms -- ^ The first transformation, followed by the sequence of those applied after it.

-- | Decompose the output-end of the transformation sequence, i.e. split
-- off the transformation that is applied last.
viewl :: ArrayTransforms -> ViewL
viewl (ArrayTransforms trs) =
  case Seq.viewr trs of
    Seq.EmptyR -> EmptyL
    rest Seq.:> final -> ArrayTransforms rest :> final

-- | A view of the last transformation to be applied, as produced by
-- 'viewl'.
data ViewL
  = EmptyL -- ^ The sequence holds no transformations.
  | ArrayTransforms :> ArrayTransform -- ^ The transformations applied before the last one, followed by the last one.

-- | Add a transform to the end of the transformation list, combining it
-- with the current last transform where possible.
(|>) :: ArrayTransforms -> ArrayTransform -> ArrayTransforms
trs |> new = addTransform' lastOf snoc flipPair new trs
  where
    -- The existing transform adjacent to the new one is the last one.
    lastOf existing = case viewl existing of
      EmptyL -> Nothing
      rest :> final -> Just (final, rest)
    -- Plain append, without any simplification.
    snoc tr (ArrayTransforms s) = ArrayTransforms (s Seq.|> tr)
    -- Turn (existing, new) into (new, existing) for 'combineTransforms'.
    flipPair (a, b) = (b, a)

-- | Add a transform at the beginning of the transformation list,
-- combining it with the current first transform where possible.
(<|) :: ArrayTransform -> ArrayTransforms -> ArrayTransforms
new <| trs = addTransform' firstOf cons id new trs
  where
    -- The existing transform adjacent to the new one is the first one;
    -- the pair (existing, new) is passed to 'combineTransforms' as is.
    firstOf existing = case viewf existing of
      EmptyF -> Nothing
      first :< rest -> Just (first, rest)
    -- Plain prepend, without any simplification.
    cons tr (ArrayTransforms s) = ArrayTransforms (tr Seq.<| s)

-- | Shared worker behind '|>' and '<|'.
--
-- The first argument detaches the existing transform adjacent to the
-- insertion point, the second inserts a transform without any
-- simplification, and the third orders the pair (adjacent, new) before
-- it is handed to 'combineTransforms'.  If the two transforms combine,
-- the result is dropped when it is the identity, and otherwise inserted
-- recursively so that it may also merge with its next neighbour.  If
-- they do not combine, the new transform is simply inserted.
addTransform' ::
  (ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms)) ->
  (ArrayTransform -> ArrayTransforms -> ArrayTransforms) ->
  ((ArrayTransform, ArrayTransform) -> (ArrayTransform, ArrayTransform)) ->
  ArrayTransform ->
  ArrayTransforms ->
  ArrayTransforms
addTransform' detach insert order new trs =
  case detach trs >>= merge of
    Nothing -> new `insert` trs
    Just (merged, remaining)
      | identityTransform merged -> remaining
      | otherwise -> addTransform' detach insert order merged remaining
  where
    merge (neighbour, others) = do
      fused <- uncurry combineTransforms (order (neighbour, new))
      pure (fused, others)

-- | Is the transform known to leave its input unchanged?  Only a
-- permutation mapping every position to itself is recognised; every
-- other transform is reported as non-identity.
identityTransform :: ArrayTransform -> Bool
identityTransform tr = case tr of
  Rearrange _ perm -> all (uncurry (==)) (zip perm [0 ..])
  _ -> False

-- | Try to fuse two adjacent transforms into a single one.  Only two
-- permutations can currently be fused: their certificates are
-- concatenated and the permutations composed with 'rearrangeCompose'.
combineTransforms :: ArrayTransform -> ArrayTransform -> Maybe ArrayTransform
combineTransforms tr2 tr1 = case (tr2, tr1) of
  (Rearrange cs2 perm2, Rearrange cs1 perm1) ->
    Just (Rearrange (cs1 <> cs2) (rearrangeCompose perm2 perm1))
  _ -> Nothing

-- | Given an expression, determine whether the expression represents
-- an input transformation of an array variable.  If so, return the
-- variable and the transformation, carrying the given certificates.
-- Only 'Rearrange', 'Reshape', and 'Replicate' (of a variable, not of
-- a constant) are possible to express this way.
transformFromExp :: Certs -> Exp rep -> Maybe (VName, ArrayTransform)
transformFromExp cs (BasicOp (Futhark.Rearrange perm v)) =
  Just (v, Rearrange cs perm)
transformFromExp cs (BasicOp (Futhark.Reshape k shape v)) =
  Just (v, Reshape cs k shape)
transformFromExp cs (BasicOp (Futhark.Replicate shape (Var v))) =
  Just (v, Replicate cs shape)
transformFromExp _ _ = Nothing

-- | Turn an array transform on an array back into an expression.  The
-- transform's certificates are returned alongside the expression.  A
-- permutation that only covers the outer dimensions is padded with the
-- remaining dimensions in order.  Outer and inner reshapes are expanded
-- to full shapes using the array's type, which is looked up in scope.
transformToExp :: (Monad m, HasScope rep m) => ArrayTransform -> VName -> m (Certs, Exp rep)
transformToExp tr arr =
  case tr of
    Replicate cs shape ->
      pure (cs, BasicOp $ Futhark.Replicate shape $ Var arr)
    Rearrange cs perm -> do
      rank <- arrayRank <$> lookupType arr
      let fullPerm = perm ++ [length perm .. rank - 1]
      pure (cs, BasicOp $ Futhark.Rearrange fullPerm arr)
    Reshape cs k shape ->
      pure (cs, BasicOp $ Futhark.Reshape k shape arr)
    ReshapeOuter cs k shape -> do
      fullShape <- reshapeOuter shape 1 . arrayShape <$> lookupType arr
      pure (cs, BasicOp $ Futhark.Reshape k fullShape arr)
    ReshapeInner cs k shape -> do
      fullShape <- reshapeInner shape 1 . arrayShape <$> lookupType arr
      pure (cs, BasicOp $ Futhark.Reshape k fullShape arr)

-- | One array input to a SOAC - a SOAC may have multiple inputs, but
-- all are of this form.  Only the array inputs are expressed with
-- this type; other arguments, such as initial accumulator values, are
-- plain expressions.  The transforms are applied left-to-right, that
-- is, the first element of the 'ArrayTransforms' sequence is applied
-- first.  The 'Type' is that of the array variable before any
-- transform is applied.
data Input = Input ArrayTransforms VName Type
  deriving (Eq, Ord, Show)

-- | Substitution is applied to the transforms, the array name, and the
-- recorded type alike.
instance Substitute Input where
  substituteNames substs (Input trs arr ty) =
    Input (sub trs) (sub arr) (sub ty)
    where
      sub :: (Substitute a) => a -> a
      sub = substituteNames substs

-- | Create a plain array variable input with no transformations.  The
-- variable's type is looked up in scope.
varInput :: (HasScope t f) => VName -> f Input
varInput arr = Input noTransforms arr <$> lookupType arr

-- | Create a plain array variable input with no transformations, taking
-- both the name and the type from an 'Ident'.
identInput :: Ident -> Input
identInput ident = Input noTransforms (identName ident) (identType ident)

-- | If the given input is a plain variable input, with no transforms,
-- return the variable.
isVarInput :: Input -> Maybe VName
isVarInput (Input trs arr _)
  | nullTransforms trs = Just arr
  | otherwise = Nothing

-- | If the given input is a plain variable input, with no non-vacuous
-- transforms, return the variable.  The only transforms treated as
-- vacuous are certificate-free 'ReshapeCoerce' reshapes to a
-- one-dimensional shape; these are peeled off the front one at a time.
isVarishInput :: Input -> Maybe VName
isVarishInput (Input trs arr ty)
  | nullTransforms trs = Just arr
  | otherwise =
      case viewf trs of
        Reshape cs ReshapeCoerce (Shape [_]) :< rest
          | cs == mempty -> isVarishInput $ Input rest arr ty
        _ -> Nothing

-- | Add a transformation to the end of the transformation list, using
-- '|>' so that it may merge with the current last transformation.
addTransform :: ArrayTransform -> Input -> Input
addTransform new (Input existing arr ty) = Input (existing |> new) arr ty

-- | Add several transformations to the start of the transformation
-- list, so they are applied before the existing ones.
addInitialTransforms :: ArrayTransforms -> Input -> Input
addInitialTransforms new (Input existing arr ty) = Input (new <> existing) arr ty

-- | Bind the result of applying one transform to an array variable,
-- under the transform's certificates, and return the new name.  The
-- name hint reflects the kind of transform.
applyTransform :: (MonadBuilder m) => ArrayTransform -> VName -> m VName
applyTransform tr arr = do
  (cs, e) <- transformToExp tr arr
  certifying cs $ letExp (nameHint tr) e
  where
    nameHint :: ArrayTransform -> String
    nameHint Replicate {} = "replicate"
    nameHint Rearrange {} = "rearrange"
    nameHint Reshape {} = "reshape"
    nameHint ReshapeOuter {} = "reshape_outer"
    nameHint ReshapeInner {} = "reshape_inner"

-- | Apply a whole sequence of transforms to an array variable, first
-- to last, binding every intermediate result.  Returns the name of the
-- final array, or the given name itself if there are no transforms.
applyTransforms :: (MonadBuilder m) => ArrayTransforms -> VName -> m VName
applyTransforms (ArrayTransforms trs) arr = foldlM step arr trs
  where
    step current tr = applyTransform tr current

-- | Convert SOAC inputs to the corresponding expressions: each input's
-- transforms are applied (with the necessary bindings), giving one
-- array variable per input.
inputsToSubExps ::
  (MonadBuilder m) =>
  [Input] ->
  m [VName]
inputsToSubExps = traverse $ \(Input trs arr _) -> applyTransforms trs arr

-- | Return the name of the (untransformed) array underlying the input.
inputArray :: Input -> VName
inputArray (Input _ arr _) = arr

-- | The transformations applied to an input, in application order.
inputTransforms :: Input -> ArrayTransforms
inputTransforms (Input trs _ _) = trs

-- | Return the type of an input: the type of the underlying array with
-- every transform applied to it, in order.
inputType :: Input -> Type
inputType (Input (ArrayTransforms ts) _ at) =
  -- A strict left fold avoids building a chain of thunks.
  Foldable.foldl' transformType at ts
  where
    transformType :: Type -> ArrayTransform -> Type
    transformType t (Replicate _ shape) =
      arrayOfShape t shape
    transformType t (Rearrange _ perm) =
      rearrangeType perm t
    transformType t (Reshape _ _ shape) =
      t `setArrayShape` shape
    transformType t (ReshapeOuter _ _ shape) =
      -- The new shape replaces the outermost dimension.
      let Shape oldshape = arrayShape t
       in t `setArrayShape` Shape (shapeDims shape ++ drop 1 oldshape)
    transformType t (ReshapeInner _ _ shape) =
      -- The new shape replaces everything but the outermost dimension.
      let Shape oldshape = arrayShape t
       in t `setArrayShape` Shape (take 1 oldshape ++ shapeDims shape)

-- | Return the row type of an input, i.e. the row type of its
-- transformed type.  Just a convenient alias.
inputRowType :: Input -> Type
inputRowType inp = rowType (inputType inp)

-- | Return the array rank (dimensionality) of an input after its
-- transforms are applied.  Just a convenient alias.
inputRank :: Input -> Int
inputRank inp = arrayRank (inputType inp)

-- | Apply the transformations to every row of the input.
--
-- Each transform is lifted to act on rows:
--
-- * a permutation keeps the outer dimension in place and shifts the
--   row permutation past it;
-- * a full reshape becomes a 'ReshapeInner';
-- * a 'Replicate' replicates the whole array and then uses
--   permutations to move the new dimension inside the outer one.
--
-- NOTE(review): the 'Replicate' lifting appears to assume a
-- one-dimensional replication shape; confirm before relying on it for
-- multi-dimensional shapes.
--
-- Any other transform is not supported yet and causes an 'error'.
transformRows :: ArrayTransforms -> Input -> Input
transformRows (ArrayTransforms ts) inp0 =
  -- A strict left fold avoids building a chain of thunks.
  Foldable.foldl' transformRows' inp0 ts
  where
    transformRows' :: Input -> ArrayTransform -> Input
    transformRows' inp (Rearrange cs perm) =
      addTransform (Rearrange cs (0 : map (+ 1) perm)) inp
    transformRows' inp (Reshape cs k shape) =
      addTransform (ReshapeInner cs k shape) inp
    transformRows' inp (Replicate cs n)
      | inputRank inp == 1 =
          addTransform (Rearrange mempty [1, 0]) $
            addTransform (Replicate cs n) inp
      | otherwise =
          addTransform (Rearrange mempty (2 : 0 : 1 : [3 .. inputRank inp])) $
            addTransform (Replicate cs n) $
              addTransform (Rearrange mempty (1 : 0 : [2 .. inputRank inp - 1])) inp
    transformRows' inp nts =
      error $
        "transformRows: Cannot transform this yet:\n"
          ++ show nts
          ++ "\n"
          ++ show inp

-- | Add to the input a 'Rearrange' transform that performs a @(k,n)@
-- transposition (computed with @transposeIndex@) of the dimension
-- indices @[0 .. inputRank inp - 1]@.  The new transform is added with
-- 'addTransform', i.e. at the end of the current transformation list.
transposeInput :: Int -> Int -> Input -> Input
transposeInput k n inp = addTransform (Rearrange mempty perm) inp
  where
    perm = transposeIndex k n [0 .. inputRank inp - 1]

-- | A definite representation of a SOAC expression.  In contrast to
-- 'Futhark.SOAC', the array arguments are 'Input's rather than plain
-- variable names, so they may carry pending array transformations.
data SOAC rep
  = -- | Width, stream lambda, initial accumulator values, and inputs.
    Stream SubExp (Lambda rep) [SubExp] [Input]
  | -- | Width, lambda, inputs, and the write specification as
    -- @(shape, count, array)@ triples.
    Scatter SubExp (Lambda rep) [Input] [(Shape, Int, VName)]
  | -- | Width, the scan/reduce/map form, and inputs.
    Screma SubExp (ScremaForm rep) [Input]
  | -- | Width, histogram operators, bucket function, and inputs.
    Hist SubExp [HistOp rep] (Lambda rep) [Input]
  deriving (Eq, Show)

-- | An input is printed as its array name wrapped, innermost first, in
-- one function-call-like layer per transformation.
instance PP.Pretty Input where
  pretty (Input (ArrayTransforms ts) arr _) = foldl wrap (pretty arr) ts
    where
      -- Render one transformation around the already-rendered document.
      wrap doc tr = case tr of
        Rearrange cs perm -> call "rearrange" cs $ PP.apply (map pretty perm)
        Reshape cs ReshapeArbitrary shape -> call "reshape" cs $ pretty shape
        ReshapeOuter cs ReshapeArbitrary shape -> call "reshape_outer" cs $ pretty shape
        ReshapeInner cs ReshapeArbitrary shape -> call "reshape_inner" cs $ pretty shape
        Reshape cs ReshapeCoerce shape -> call "coerce" cs $ pretty shape
        ReshapeOuter cs ReshapeCoerce shape -> call "coerce_outer" cs $ pretty shape
        ReshapeInner cs ReshapeCoerce shape -> call "coerce_inner" cs $ pretty shape
        Replicate cs shape -> call "replicate" cs $ pretty shape
        where
          -- @name<certs>(arg, doc)@
          call name cs arg = name <> pretty cs <> PP.apply [arg, doc]

-- | SOACs are printed with the same printers as the corresponding
-- 'Futhark.SOAC' constructors, but with 'Input's as the arrays.
instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
  pretty soac = case soac of
    Screma w form arrs -> Futhark.ppScrema w arrs form
    Hist len ops bucket_fun imgs -> Futhark.ppHist len imgs ops bucket_fun
    Stream w lam nes arrs -> Futhark.ppStream w arrs nes lam
    Scatter w lam arrs dests -> Futhark.ppScatter w arrs lam dests

-- | The array inputs consumed by a SOAC.
inputs :: SOAC rep -> [Input]
inputs soac = case soac of
  Stream _ _ _ arrs -> arrs
  Scatter _ _ ivs _ -> ivs
  Screma _ _ arrs -> arrs
  Hist _ _ _ inps -> inps

-- | Replace the inputs of a SOAC.
--
-- For 'Stream' and 'Scatter' the width is recomputed from the new inputs
-- with @newWidth@; 'Screma' and 'Hist' keep their existing width.
--
-- NOTE(review): this asymmetry means a 'Screma' or 'Hist' whose new
-- inputs have a different outer size keeps a stale width; confirm that
-- callers rely on this before changing it.
setInputs :: [Input] -> SOAC rep -> SOAC rep
setInputs arrs soac = case soac of
  Stream w lam nes _ -> Stream (newWidth arrs w) lam nes arrs
  Scatter w lam _ dests -> Scatter (newWidth arrs w) lam arrs dests
  Screma w form _ -> Screma w form arrs
  Hist w ops lam _ -> Hist w ops lam arrs

-- | The width implied by a list of inputs: the outermost size of the
-- first input's 'inputType', or the given fallback width when the list
-- is empty.
newWidth :: [Input] -> SubExp -> SubExp
newWidth inps w = maybe w (arraySize 0 . inputType) (listToMaybe inps)

-- | The lambda used in a given SOAC.  For a 'Screma' this is the map
-- lambda stored in its 'ScremaForm'; for a 'Hist' it is the bucket
-- function.
lambda :: SOAC rep -> Lambda rep
lambda soac = case soac of
  Stream _ lam _ _ -> lam
  Scatter _ lam _ _ -> lam
  Screma _ (ScremaForm _ _ lam) _ -> lam
  Hist _ _ lam _ -> lam

-- | Replace the lambda returned by 'lambda', leaving every other part of
-- the SOAC (including a 'Screma''s scans and reductions) unchanged.
setLambda :: Lambda rep -> SOAC rep -> SOAC rep
setLambda lam soac = case soac of
  Stream w _ nes arrs -> Stream w lam nes arrs
  Scatter len _ ivs dests -> Scatter len lam ivs dests
  Screma w (ScremaForm scans reds _) arrs -> Screma w (ScremaForm scans reds lam) arrs
  Hist w ops _ inps -> Hist w ops lam inps

-- | The return type of a SOAC.
typeOf :: SOAC rep -> [Type]
typeOf (Stream w lam nes _) =
  -- The first @length nes@ lambda results are accumulators; the remaining
  -- per-chunk arrays become arrays of the full width @w@.
  accrtps ++ map toFullWidth arrtps
  where
    (accrtps, arrtps) = splitAt (length nes) $ lambdaReturnType lam
    toFullWidth t = arrayOf (stripArray 1 t) (Shape [w]) NoUniqueness
typeOf (Scatter _w lam _ivs dests) =
  zipWith arrayOfShape (destValueTypes ns val_ts) ws
  where
    (ws, ns, _) = unzip3 dests
    -- The lambda returns all index values first (@count * rank@ of them
    -- per destination), followed by the written values.
    indexes = sum $ zipWith (*) ns $ map length ws
    val_ts = drop indexes $ lambdaReturnType lam
    -- The values for one destination are contiguous, so take the type of
    -- the first value of each group.  Zipping 'val_ts' directly with the
    -- destinations would attribute the wrong element type to later
    -- destinations as soon as some destination receives more than one
    -- value per iteration.
    destValueTypes (n : ns') ts = take 1 ts ++ destValueTypes ns' (drop n ts)
    destValueTypes [] _ = []
typeOf (Screma w form _) =
  scremaType w form
typeOf (Hist _ ops _ _) = do
  -- One result per histogram operator output, shaped like the histogram.
  op <- ops
  map (`arrayOfShape` histShape op) (lambdaReturnType $ histOp op)

-- | The "width" of a SOAC is the expected outer size of its array
-- inputs _after_ input-transforms have been carried out.  It is stored
-- as the first field of every constructor.
width :: SOAC rep -> SubExp
width soac = case soac of
  Stream w _ _ _ -> w
  Scatter len _ _ _ -> len
  Screma w _ _ -> w
  Hist w _ _ _ -> w

-- | Convert a SOAC to the corresponding expression: the result of
-- 'toSOAC' wrapped in 'Op'.  Any statements needed to materialise
-- transformed inputs are added to the current builder.
toExp ::
  (MonadBuilder m, Op (Rep m) ~ Futhark.SOAC (Rep m)) =>
  SOAC (Rep m) ->
  m (Exp (Rep m))
toExp = fmap Op . toSOAC

-- | Convert a SOAC to a Futhark-level SOAC.  The 'Input's are turned
-- into plain variable names with @inputsToSubExps@ (in the builder monad);
-- all other components are copied unchanged.
toSOAC :: (MonadBuilder m) => SOAC (Rep m) -> m (Futhark.SOAC (Rep m))
toSOAC soac = case soac of
  Stream w lam nes inps -> do
    arrs <- inputsToSubExps inps
    pure $ Futhark.Stream w arrs nes lam
  Scatter w lam ivs dests -> do
    arrs <- inputsToSubExps ivs
    pure $ Futhark.Scatter w arrs lam dests
  Screma w form inps -> do
    arrs <- inputsToSubExps inps
    pure $ Futhark.Screma w arrs form
  Hist w ops lam inps -> do
    arrs <- inputsToSubExps inps
    pure $ Futhark.Hist w arrs ops lam

-- | Why 'fromExp' could not convert an expression to a 'SOAC' value.
data NotSOAC = NotSOAC
  -- ^ The expression is not a (tuple-)SOAC at all.
  deriving (Show)

-- | Either convert an expression to the normalised SOAC
-- representation, or a reason why the expression does not have the
-- valid form.  Array arguments become untransformed inputs via
-- 'varInput', which needs the scope to look up their types.
fromExp ::
  (Op rep ~ Futhark.SOAC rep, HasScope rep m) =>
  Exp rep ->
  m (Either NotSOAC (SOAC rep))
fromExp (Op op) = case op of
  Futhark.Stream w arrs nes lam ->
    fmap (Right . Stream w lam nes) (vars arrs)
  Futhark.Scatter w ivs lam dests ->
    fmap (\inps -> Right (Scatter w lam inps dests)) (vars ivs)
  Futhark.Screma w arrs form ->
    fmap (Right . Screma w form) (vars arrs)
  Futhark.Hist w arrs ops lam ->
    fmap (Right . Hist w ops lam) (vars arrs)
  _ -> pure (Left NotSOAC)
  where
    vars names = traverse varInput names
fromExp _ = pure (Left NotSOAC)

-- | To-Stream translation of SOACs.
--   Returns the Stream SOAC and the
--   extra-accumulator body-result ident if any.
soacToStream ::
  ( HasScope rep m,
    MonadFreshNames m,
    Buildable rep,
    BuilderOps rep,
    Op rep ~ Futhark.SOAC rep
  ) =>
  SOAC rep ->
  m (SOAC rep, [Ident])
soacToStream :: forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
 Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
soacToStream SOAC rep
soac = do
  Param Type
chunk_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"chunk" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let chvar :: SubExp
chvar = VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
chunk_param
      (Lambda rep
lam, [Input]
inps) = (forall rep. SOAC rep -> Lambda rep
lambda SOAC rep
soac, forall rep. SOAC rep -> [Input]
inputs SOAC rep
soac)
      w :: SubExp
w = forall rep. SOAC rep -> SubExp
width SOAC rep
soac
  Lambda rep
lam' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda rep
lam
  let arrrtps :: [Type]
arrrtps = forall rep. SubExp -> Lambda rep -> [Type]
mapType SubExp
w Lambda rep
lam
      -- the chunked-outersize of the array result and input types
      loutps :: [Type]
loutps = [forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow Type
t SubExp
chvar | Type
t <- forall a b. (a -> b) -> [a] -> [b]
map forall u. TypeBase Shape u -> TypeBase Shape u
rowType [Type]
arrrtps]
      lintps :: [Type]
lintps = [forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow Type
t SubExp
chvar | Type
t <- forall a b. (a -> b) -> [a] -> [b]
map Input -> Type
inputRowType [Input]
inps]

  [Param Type]
strm_inpids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inp") [Type]
lintps
  -- Treat each SOAC case individually:
  case SOAC rep
soac of
    Screma SubExp
_ ScremaForm rep
form [Input]
_
      | Just Lambda rep
_ <- forall rep. ScremaForm rep -> Maybe (Lambda rep)
Futhark.isMapSOAC ScremaForm rep
form -> do
          -- Map(f,a) => is translated in strem's body to:
          -- let strm_resids = map(f,a_ch) in strm_resids
          --
          -- array result and input IDs of the stream's lambda
          [Ident]
strm_resids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"res") [Type]
loutps
          let insoac :: SOAC rep
insoac =
                forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
chvar (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
strm_inpids) forall a b. (a -> b) -> a -> b
$
                  forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda rep
lam'
              insstm :: Stm rep
insstm = forall rep. Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet [Ident]
strm_resids forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op SOAC rep
insoac
              strmbdy :: Body rep
strmbdy = forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (forall rep. Stm rep -> Stms rep
oneStm Stm rep
insstm) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExpRes
subExpRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) [Ident]
strm_resids
              strmpar :: [Param Type]
strmpar = Param Type
chunk_param forall a. a -> [a] -> [a]
: [Param Type]
strm_inpids
              strmlam :: Lambda rep
strmlam = forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda [Param Type]
strmpar Body rep
strmbdy [Type]
loutps
          -- map(f,a) creates a stream with NO accumulators
          forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall rep. SubExp -> Lambda rep -> [SubExp] -> [Input] -> SOAC rep
Stream SubExp
w Lambda rep
strmlam [] [Input]
inps, [])
      | Just ([Scan rep]
scans, Lambda rep
_) <- forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
Futhark.isScanomapSOAC ScremaForm rep
form,
        Futhark.Scan Lambda rep
scan_lam [SubExp]
nes <- forall rep. Buildable rep => [Scan rep] -> Scan rep
Futhark.singleScan [Scan rep]
scans -> do
          -- scanomap(scan_lam,nes,map_lam,a) => is translated in strem's body to:
          -- 1. let (scan0_ids,map_resids)   = scanomap(scan_lam, nes, map_lam, a_ch)
          -- 2. let strm_resids = map (acc `+`,nes, scan0_ids)
          -- 3. let outerszm1id = sizeof(0,strm_resids) - 1
          -- 4. let lasteel_ids = if outerszm1id < 0
          --                      then nes
          --                      else strm_resids[outerszm1id]
          -- 5. let acc'        = acc + lasteel_ids
          --    {acc', strm_resids, map_resids}
          -- the array and accumulator result types
          let scan_arr_ts :: [Type]
scan_arr_ts = forall a b. (a -> b) -> [a] -> [b]
map (forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
chvar) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
scan_lam
              accrtps :: [Type]
accrtps = forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
scan_lam

          [Param Type]
inpacc_ids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inpacc") [Type]
accrtps
          Lambda rep
maplam <- forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep) =>
[SubExp] -> Lambda rep -> m (Lambda rep)
mkMapPlusAccLam (forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Param dec -> VName
paramName) [Param Type]
inpacc_ids) Lambda rep
scan_lam
          -- Finally, construct the stream
          let strmpar :: [Param Type]
strmpar = Param Type
chunk_param forall a. a -> [a] -> [a]
: [Param Type]
inpacc_ids forall a. [a] -> [a] -> [a]
++ [Param Type]
strm_inpids
          Lambda rep
strmlam <- forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b. (a, b) -> a
fst forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type]
strmpar forall a b. (a -> b) -> a -> b
$ do
            -- 1. let (scan0_ids,map_resids)  = scanomap(scan_lam,nes,map_lam,a_ch)
            ([VName]
scan0_ids, [VName]
map_resids) <-
              forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
scan_arr_ts)) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
                forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
chvar (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
strm_inpids) forall a b. (a -> b) -> a -> b
$
                  forall rep. [Scan rep] -> Lambda rep -> ScremaForm rep
Futhark.scanomapSOAC [forall rep. Lambda rep -> [SubExp] -> Scan rep
Futhark.Scan Lambda rep
scan_lam [SubExp]
nes] Lambda rep
lam'
            -- 2. let outerszm1id = chunksize - 1
            SubExp
outszm1id <-
              forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"outszm1" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
                BinOp -> SubExp -> SubExp -> BasicOp
BinOp
                  (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef)
                  (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
chunk_param)
                  (forall v. IsValue v => v -> SubExp
constant (Int64
1 :: Int64))
            VName
empty_arr <-
              forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"empty_arr" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
                CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp
                  (IntType -> CmpOp
CmpSlt IntType
Int64)
                  SubExp
outszm1id
                  (forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64))
            -- 3. let lasteel_ids = ...
            let indexLast :: VName -> BuilderT rep (State VNameSource) (Exp rep)
indexLast VName
arr = do
                  Type
arr_t <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
                  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
arr forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t [forall d. d -> DimIndex d
DimFix SubExp
outszm1id]
            [VName]
lastel_ids <-
              forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"lastel"
                forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                  (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
empty_arr)
                  (forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExp]
nes)
                  (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> BuilderT rep (State VNameSource) (Exp rep)
indexLast [VName]
scan0_ids)
            Body rep
addlelbdy <-
              forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep) =>
Lambda rep -> [SubExp] -> m (Body rep)
mkPlusBnds Lambda rep
scan_lam forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
inpacc_ids forall a. [a] -> [a] -> [a]
++ [VName]
lastel_ids
            let (Stms rep
addlelstm, Result
addlelres) = (forall rep. Body rep -> Stms rep
bodyStms Body rep
addlelbdy, forall rep. Body rep -> Result
bodyResult Body rep
addlelbdy)
            -- 4. let strm_resids = map (acc `+`,nes, scan0_ids)
            [VName]
strm_resids <-
              forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"strm_res" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
                forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
chvar [VName]
scan0_ids (forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda rep
maplam)
            -- 5. let acc'        = acc + lastel_ids
            forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms Stms rep
addlelstm
            forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ Result
addlelres forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExpRes
subExpRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName]
strm_resids forall a. [a] -> [a] -> [a]
++ [VName]
map_resids)
          forall (f :: * -> *) a. Applicative f => a -> f a
pure
            ( forall rep. SubExp -> Lambda rep -> [SubExp] -> [Input] -> SOAC rep
Stream SubExp
w Lambda rep
strmlam [SubExp]
nes [Input]
inps,
              forall a b. (a -> b) -> [a] -> [b]
map forall dec. Typed dec => Param dec -> Ident
paramIdent [Param Type]
inpacc_ids
            )
      | Just ([Reduce rep]
reds, Lambda rep
_) <- forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
Futhark.isRedomapSOAC ScremaForm rep
form,
        Futhark.Reduce Commutativity
comm Lambda rep
lamin [SubExp]
nes <- forall rep. Buildable rep => [Reduce rep] -> Reduce rep
Futhark.singleReduce [Reduce rep]
reds -> do
          -- Redomap(+,lam,nes,a) => is translated in stream's body to:
          -- 1. let (acc0_ids,strm_resids) = redomap(+,lam,nes,a_ch) in
          -- 2. let acc'                   = acc + acc0_ids          in
          --    {acc', strm_resids}

          let accrtps :: [Type]
accrtps = forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam
              -- the chunked-outersize of the array result and input types
              loutps' :: [Type]
loutps' = forall a. Int -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [Type]
loutps
              -- the lambda with proper index
              foldlam :: Lambda rep
foldlam = Lambda rep
lam'
          -- array result and input IDs of the stream's lambda
          [Ident]
strm_resids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"res") [Type]
loutps'
          [Param Type]
inpacc_ids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inpacc") [Type]
accrtps
          [Ident]
acc0_ids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"acc0") [Type]
accrtps
          -- 1. let (acc0_ids,strm_resids) = redomap(+,lam,nes,a_ch) in
          let insoac :: SOAC rep
insoac =
                forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma
                  SubExp
chvar
                  (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
strm_inpids)
                  forall a b. (a -> b) -> a -> b
$ forall rep. [Reduce rep] -> Lambda rep -> ScremaForm rep
Futhark.redomapSOAC [forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Futhark.Reduce Commutativity
comm Lambda rep
lamin [SubExp]
nes] Lambda rep
foldlam
              insstm :: Stm rep
insstm = forall rep. Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet ([Ident]
acc0_ids forall a. [a] -> [a] -> [a]
++ [Ident]
strm_resids) forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op SOAC rep
insoac
          -- 2. let acc'     = acc + acc0_ids    in
          Body rep
addaccbdy <-
            forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep) =>
Lambda rep -> [SubExp] -> m (Body rep)
mkPlusBnds Lambda rep
lamin forall a b. (a -> b) -> a -> b
$
              forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$
                forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
inpacc_ids forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
acc0_ids
          -- Construct the stream
          let (Stms rep
addaccstm, Result
addaccres) = (forall rep. Body rep -> Stms rep
bodyStms Body rep
addaccbdy, forall rep. Body rep -> Result
bodyResult Body rep
addaccbdy)
              strmbdy :: Body rep
strmbdy =
                forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (forall rep. Stm rep -> Stms rep
oneStm Stm rep
insstm forall a. Semigroup a => a -> a -> a
<> Stms rep
addaccstm) forall a b. (a -> b) -> a -> b
$
                  Result
addaccres forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExpRes
subExpRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) [Ident]
strm_resids
              strmpar :: [Param Type]
strmpar = Param Type
chunk_param forall a. a -> [a] -> [a]
: [Param Type]
inpacc_ids forall a. [a] -> [a] -> [a]
++ [Param Type]
strm_inpids
              strmlam :: Lambda rep
strmlam = forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda [Param Type]
strmpar Body rep
strmbdy ([Type]
accrtps forall a. [a] -> [a] -> [a]
++ [Type]
loutps')
          forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall rep. SubExp -> Lambda rep -> [SubExp] -> [Input] -> SOAC rep
Stream SubExp
w Lambda rep
strmlam [SubExp]
nes [Input]
inps, [])

    -- Otherwise it cannot become a stream.
    SOAC rep
_ -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (SOAC rep
soac, [])
  where
    -- Partially apply the operator @plus@ to @accs@.  The first
    -- @length accs@ parameters of @plus@ are fixed: each one is bound to
    -- the corresponding element of @accs@ by a @let@ statement placed in
    -- front of the operator's own statements.  The returned lambda takes
    -- only the remaining parameters, keeps the original return type, and
    -- is renamed so that it does not share names with @plus@.
    mkMapPlusAccLam ::
      (MonadFreshNames m, Buildable rep) =>
      [SubExp] ->
      Lambda rep ->
      m (Lambda rep)
    mkMapPlusAccLam accs plus =
      renameLambda (Lambda openParams specialisedBody (lambdaReturnType plus))
      where
        (fixedParams, openParams) = splitAt (length accs) (lambdaParams plus)
        bindTo p se = mkLet [paramIdent p] (BasicOp (SubExp se))
        oldBody = lambdaBody plus
        -- Bindings for the fixed parameters must precede the body's statements.
        specialisedBody =
          oldBody
            { bodyStms =
                stmsFromList (zipWith bindTo fixedParams accs) <> bodyStms oldBody
            }

    -- Inline the operator @plus@ applied to the operands @accels@.  The
    -- lambda is renamed first, so repeated inlining never reuses names.
    -- Then each of its parameters is bound, in order, to the matching
    -- element of @accels@ by a @let@ statement placed before the lambda's
    -- own statements.  The result is the renamed lambda's body, whose
    -- result is unchanged.
    mkPlusBnds ::
      (MonadFreshNames m, Buildable rep) =>
      Lambda rep ->
      [SubExp] ->
      m (Body rep)
    mkPlusBnds plus accels = instantiate <$> renameLambda plus
      where
        instantiate fresh =
          let freshBody = lambdaBody fresh
              argBinds =
                stmsFromList
                  [ mkLet [paramIdent p] (BasicOp (SubExp se))
                    | (p, se) <- zip (lambdaParams fresh) accels
                  ]
           in freshBody {bodyStms = argBinds <> bodyStms freshBody}