{-# LANGUAGE FlexibleContexts #-}
module Futhark.Internalise.Lambdas
  ( InternaliseLambda
  , internaliseMapLambda
  , internaliseStreamMapLambda
  , internaliseFoldLambda
  , internaliseStreamLambda
  , internalisePartitionLambda
  )
  where

import Control.Monad
import Data.Loc

import Language.Futhark as E
import Futhark.Representation.SOACS as I
import Futhark.MonadFreshNames

import Futhark.Internalise.Monad
import Futhark.Internalise.AccurateSizes
import Futhark.Representation.SOACS.Simplify (simplifyLambda)

-- | A function for internalising lambdas.  Given a source-language
-- lambda expression and the internal types of its parameters, it
-- yields the internal lambda parameters, the internalised body, and
-- the (possibly existentially sized) return types.
type InternaliseLambda
  =  E.Exp
  -> [I.Type]
  -> InternaliseM ([I.LParam], I.Body, [I.ExtType])

-- | Internalise a lambda that is to be mapped over the arrays @args@.
--
-- The lambda is internalised against the row types of @args@.  The
-- existential sizes in its return type are instantiated with fresh
-- names (@inner_shapes@).  'bindMapShapes' binds these names in the
-- enclosing scope.  When needed, it evaluates a separate shape function
-- (built by 'makeShapeFun') on row 0 of each argument.  The body of the
-- returned lambda is wrapped in 'ensureResultShape', so each iteration
-- is checked against those shapes.
internaliseMapLambda :: InternaliseLambda
                     -> E.Exp
                     -> [I.SubExp]
                     -> InternaliseM I.Lambda
internaliseMapLambda internaliseLambda lam args = do
  argtypes <- mapM I.subExpType args
  let rowtypes = map I.rowType argtypes
  (params, body, rettype) <- internaliseLambda lam rowtypes
  (rettype', inner_shapes) <- instantiateShapes' rettype
  -- Outer size of the arguments; 'bindMapShapes' uses it to detect an
  -- empty input.
  let outer_shape = arraysSize 0 argtypes
  shapefun <- makeShapeFun params body rettype' inner_shapes
  bindMapShapes index0 [] inner_shapes shapefun args outer_shape
  body' <- localScope (scopeOfLParams params) $
           ensureResultShape asserting
           (ErrorMsg [ErrorString "not all iterations produce same shape"])
           (srclocOf lam) rettype' body
  return $ I.Lambda params body' rettype'
  where
    -- Bind an argument to a name and index out its row 0, for use as an
    -- argument to the shape function.
    index0 :: I.SubExp -> InternaliseM I.Exp
    index0 arg = do
      arg' <- letExp "arg" $ I.BasicOp $ I.SubExp arg
      arg_t <- lookupType arg'
      return $ I.BasicOp $ I.Index arg' $ fullSlice arg_t [I.DimFix zero]

    zero :: I.SubExp
    zero = constant (0 :: I.Int32)

-- | Internalise the lambda of a streaming map over the arrays @args@.
--
-- The resulting lambda takes a fresh @i32@ chunk-size parameter.  It is
-- followed by the source lambda's remaining parameters, whose types are
-- the argument types with their outer size set to that chunk size.  The
-- return types likewise get the chunk size as their outer dimension.
-- The source lambda's own chunk-size parameter is bound to the fresh
-- one inside the body.  As in 'internaliseMapLambda', the inner result
-- shapes are bound in the enclosing scope via 'bindMapShapes', and the
-- body is checked with 'ensureResultShape'.
internaliseStreamMapLambda :: InternaliseLambda
                           -> E.Exp
                           -> [I.SubExp]
                           -> InternaliseM I.Lambda
internaliseStreamMapLambda internaliseLambda lam args = do
  chunk_size <- newVName "chunk_size"
  let chunk_param = I.Param chunk_size (I.Prim int32)
      outer = (`setOuterSize` I.Var chunk_size)
  localScope (scopeOfLParams [chunk_param]) $ do
    argtypes <- mapM I.subExpType args
    (lam_params, orig_body, rettype) <-
      internaliseLambda lam $ I.Prim int32 : map outer argtypes
    -- We asked for an i32 first parameter above, so the source lambda's
    -- first parameter is its chunk size.
    (orig_chunk_param, params) <- case lam_params of
      p : ps -> return (p, ps)
      []     -> error "internaliseStreamMapLambda: lambda lacks a chunk-size parameter"
    -- Body used for the shape function: the original body, preceded by
    -- a binding of the source chunk-size parameter to our fresh one.
    body <- runBodyBinder $ do
      letBindNames_ [paramName orig_chunk_param] $ I.BasicOp $ I.SubExp $ I.Var chunk_size
      return orig_body
    (rettype', inner_shapes) <- instantiateShapes' rettype
    let outer_shape = arraysSize 0 argtypes
    shapefun <- makeShapeFun (chunk_param:params) body rettype' inner_shapes
    -- NOTE(review): the shape function is given 0 as its chunk size, but
    -- 'slice0' slices to 'chunk_size' elements.  That name is otherwise
    -- bound only as a parameter of the returned lambda.  Confirm this is
    -- sound.
    bindMapShapes (slice0 chunk_size) [zero] inner_shapes shapefun args outer_shape
    body' <- localScope (scopeOfLParams params) $ insertStmsM $ do
      letBindNames_ [paramName orig_chunk_param] $ I.BasicOp $ I.SubExp $ I.Var chunk_size
      -- Use 'orig_body', not 'body': 'body' already contains the
      -- binding made on the line above, and using it would bind the
      -- parameter a second time.
      ensureResultShape asserting
        (ErrorMsg [ErrorString "not all iterations produce same shape"])
        (srclocOf lam) (map outer rettype') orig_body
    return $ I.Lambda (chunk_param:params) body' (map outer rettype')
  where
    -- Bind an argument to a name and slice its outer dimension with
    -- @DimSlice 0 chunk_size 1@, for use as an argument to the shape
    -- function.
    slice0 :: I.VName -> I.SubExp -> InternaliseM I.Exp
    slice0 chunk_size arg = do
      arg' <- letExp "arg" $ I.BasicOp $ I.SubExp arg
      arg_t <- lookupType arg'
      return $ I.BasicOp $ I.Index arg' $
        fullSlice arg_t [I.DimSlice zero (I.Var chunk_size) one]

    zero, one :: I.SubExp
    zero = constant (0 :: I.Int32)
    one = constant (1 :: I.Int32)

-- | Build the shape function for a mapped lambda.
--
-- The result is a lambda over non-unique copies of @params@.  Its body
-- is produced by 'shapeBody' from the value body @body@ and its return
-- types @val_rettype@.  It returns one @i32@ value per element of
-- @inner_shapes@.
makeShapeFun :: [I.LParam] -> I.Body -> [Type] -> [I.Ident]
             -> InternaliseM I.Lambda
makeShapeFun params body val_rettype inner_shapes = do
  -- Some of 'params' may be unique, which means that the shape slice
  -- would consume its input.  This is not acceptable - that input is
  -- needed for the value function!  Hence, for all unique parameters,
  -- we create a substitute non-unique parameter, and insert a
  -- copy-binding in the body of the function.
  (params', copystms) <- nonuniqueParams params
  shape_body <- runBodyBinder $ localScope (scopeOfLParams params') $ do
    mapM_ addStm copystms
    shapeBody (map I.identName inner_shapes) val_rettype body
  return $ I.Lambda params' shape_body rettype
  where
    -- One i32 result per inner shape.
    rettype :: [Type]
    rettype = replicate (length inner_shapes) $ I.Prim int32

-- | Bind @inner_shapes@ in the current scope to the results of the shape
-- function @sizefun@.
--
-- If @sizefun@ returns nothing, this does nothing.  Otherwise
-- @sizefun@ is simplified first:
--
-- * If the simplified body contains only safe expressions and does not
--   mention its parameters, that body is bound directly.
--
-- * Otherwise @sizefun@ is applied to @extra_args@ followed by
--   @indexArg@ of each of @args@.  This happens inside a branch that
--   instead yields zero for every shape when @outer_shape@ is zero.
bindMapShapes :: (SubExp -> InternaliseM I.Exp) -> [SubExp]
              -> [I.Ident] -> I.Lambda -> [I.SubExp] -> SubExp
              -> InternaliseM ()
bindMapShapes indexArg extra_args inner_shapes sizefun args outer_shape
  | null $ I.lambdaReturnType sizefun = return ()
  | otherwise = do
      let size_args = replicate (length $ lambdaParams sizefun) Nothing
      sizefun' <- simplifyLambda sizefun size_args
      let sizefun_body = I.lambdaBody sizefun'
          sizefun_safe =
            all (I.safeExp . I.stmExp) $ I.bodyStms sizefun_body
          sizefun_arg_invariant =
            not $ any ((`nameIn` freeIn sizefun_body) . I.paramName) $
            lambdaParams sizefun'
      if sizefun_safe && sizefun_arg_invariant
        then do
          -- Safe to evaluate unconditionally and independent of the
          -- arguments, so inline the simplified body directly.
          ses <- bodyBind sizefun_body
          forM_ (zip inner_shapes ses) $ \(v, se) ->
            letBind_ (basicPattern [] [v]) $ I.BasicOp $ I.SubExp se
        else
          letBind_ (basicPattern [] inner_shapes) =<<
            eIf' isnonempty nonemptybranch emptybranch IfFallback
  where
    -- Used when the outer size is zero: every shape is reported as zero.
    emptybranch :: InternaliseM I.Body
    emptybranch =
      pure $ resultBody $ map (const zero) $ I.lambdaReturnType sizefun

    -- Apply the (unsimplified) shape function to the extra arguments
    -- followed by the indexed arguments.
    nonemptybranch :: InternaliseM I.Body
    nonemptybranch = insertStmsM $
      resultBody <$>
      eLambda sizefun (map eSubExp extra_args ++ map indexArg args)

    isnonempty :: InternaliseM I.Exp
    isnonempty =
      eNot $ eCmpOp (I.CmpEq I.int32) (eSubExp outer_shape) (eSubExp zero)

    zero :: I.SubExp
    zero = constant (0 :: I.Int32)

-- | Internalise a fold lambda.
--
-- Its parameters are accumulators of types @acctypes@, followed by rows
-- of arrays of types @arrtypes@.  Each result is given the array shape
-- of the corresponding accumulator type.  The body is wrapped with
-- 'ensureResultShape' to enforce that shape.
internaliseFoldLambda :: InternaliseLambda
                      -> E.Exp
                      -> [I.Type] -> [I.Type]
                      -> InternaliseM I.Lambda
internaliseFoldLambda internaliseLambda lam acctypes arrtypes = do
  let rowtypes = map I.rowType arrtypes
  (params, body, rettype) <- internaliseLambda lam $ acctypes ++ rowtypes
  -- Replace the (possibly existential) result shapes with the
  -- accumulator shapes.
  let rettype' = [ t `I.setArrayShape` I.arrayShape acc_t
                 | (t, acc_t) <- zip rettype acctypes ]
  -- The result of the body must have the exact same shape as the
  -- initial accumulator.  We accomplish this with an assertion and
  -- reshape().
  body' <- localScope (scopeOfLParams params) $
           ensureResultShape asserting
           (ErrorMsg [ErrorString "shape of result does not match shape of initial value"])
           (srclocOf lam) rettype' body
  return $ I.Lambda params body' rettype'

-- | Internalise the lambda of a stream SOAC.
--
-- The source lambda is internalised with an @int32@ chunk-size
-- parameter followed by one chunk parameter per element of @rowts@,
-- each chunk being an array of @chunk_size@ rows of that row type.
-- A fresh @chunk_size@ name is used for the returned chunk-size
-- parameter; the lambda's own chunk-size parameter is bound to it at
-- the top of the body.  The result is the fresh chunk-size parameter
-- followed by the remaining lambda parameters, together with the body.
internaliseStreamLambda :: InternaliseLambda
                        -> E.Exp
                        -> [I.Type]
                        -> InternaliseM ([LParam], Body)
internaliseStreamLambda internaliseLambda lam rowts = do
  chunk_size <- newVName "chunk_size"
  let chunk_param = I.Param chunk_size $ I.Prim int32
      -- Each chunk has 'chunk_size' rows of the corresponding row type.
      chunktypes = map (`arrayOfRow` I.Var chunk_size) rowts
  -- The chunk types refer to 'chunk_size', so it has to be in scope
  -- while the lambda is being internalised.
  localScope (scopeOfLParams [chunk_param]) $ do
    (lam_params, orig_body, _) <-
      internaliseLambda lam $ I.Prim int32 : chunktypes
    -- The first parameter corresponds to the chunk size we asked for;
    -- fail loudly (instead of via a lazy irrefutable pattern) if the
    -- internalised lambda unexpectedly has no parameters at all.
    (orig_chunk_param, params) <- case lam_params of
      p : ps -> pure (p, ps)
      [] -> error "internaliseStreamLambda: internalised lambda has no chunk size parameter"
    body <- runBodyBinder $ do
      -- Make the lambda's original chunk-size name an alias of the
      -- fresh 'chunk_size' that the chunk types are expressed in.
      letBindNames_ [paramName orig_chunk_param] $
        I.BasicOp $ I.SubExp $ I.Var chunk_size
      pure orig_body
    pure (chunk_param : params, body)

-- | Internalise the classifier lambda of a partition into @k@ classes.
--
-- The first result of the source lambda is taken as the equivalence
-- class and is compared (as an @int32@) against @0 .. k-1@.  The
-- produced lambda returns a @(k+2)@-element tuple of @int32@s:
--
-- * the first element is the class ID in the range @[0,k]@, where a
--   class that matches none of @0 .. k-1@ is mapped to @k@;
--
-- * the remaining @k+1@ elements are all zero except for a single 1
--   at the position of that class (element @1+c@ for class @c@).
internalisePartitionLambda :: InternaliseLambda
                           -> Int
                           -> E.Exp
                           -> [I.SubExp]
                           -> InternaliseM I.Lambda
internalisePartitionLambda internaliseLambda k lam args = do
  argtypes <- mapM I.subExpType args
  let rowtypes = map I.rowType argtypes
  (params, body, _) <- internaliseLambda lam rowtypes
  body' <- localScope (scopeOfLParams params) $
           lambdaWithIncrement body
  return $ I.Lambda params body' rettype
  where
    rettype :: [I.Type]
    rettype = replicate (k + 2) $ I.Prim int32

    -- The constant result tuple for class @i@: the class ID followed by
    -- a vector of length @k+1@ that is 1 at index @i@ and 0 elsewhere.
    result :: Int -> [I.SubExp]
    result i = map constant $ (fromIntegral i :: Int32) :
               (replicate i 0 ++ [1 :: Int32] ++ replicate (k - i) 0)

    -- Emit a chain of nested branches testing @eq_class == i@ for
    -- i = 0, 1, .., k-1.  When no test succeeds we fall through to the
    -- @i >= k@ case, which yields the result for class @k@.
    mkResult _ i | i >= k = return $ result i
    mkResult eq_class i = do
      is_i <- letSubExp "is_i" $ BasicOp $
              CmpOp (CmpEq int32) eq_class (constant i)
      fmap (map I.Var) . letTupExp "part_res" =<<
        eIf (eSubExp is_i) (pure $ resultBody $ result i)
                           (resultBody <$> mkResult eq_class (i + 1))

    -- Run the original lambda body, then replace its results with the
    -- class/one-hot tuple computed from its first result.
    lambdaWithIncrement :: I.Body -> InternaliseM I.Body
    lambdaWithIncrement lam_body = runBodyBinder $ do
      lam_res <- bodyBind lam_body
      case lam_res of
        eq_class : _ -> resultBody <$> mkResult eq_class 0
        [] -> error "internalisePartitionLambda: partition lambda has no result"