{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Monad
( ADM,
RState (..),
runADM,
Adj (..),
InBounds (..),
Sparse (..),
adjFromParam,
adjFromVar,
lookupAdj,
lookupAdjVal,
adjVal,
updateAdj,
updateSubExpAdj,
updateAdjSlice,
updateAdjIndex,
setAdj,
insAdj,
adjsReps,
copyConsumedArrsInStm,
copyConsumedArrsInBody,
addSubstitution,
returnSweepCode,
adjVName,
subAD,
noAdjsFor,
subSubsts,
isActive,
tabNest,
oneExp,
zeroExp,
unitAdjOfType,
addLambda,
VjpOps (..),
setLoopTape,
lookupLoopTape,
substLoopTape,
renameLoopTape,
)
where
import Control.Monad
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import Data.List (foldl')
import Data.Map qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Substitute
import Futhark.Util (chunks)
-- | An expression that evaluates to the blank value (as given by
-- 'blankPrimValue') of the given type: a constant for a primitive
-- type, and a 'Replicate' of the blank element over the whole shape
-- for an array type.  Any other type is reported as an internal
-- error.
zeroExp :: Type -> Exp rep
zeroExp (Prim pt) =
  BasicOp $ SubExp $ Constant $ blankPrimValue pt
zeroExp (Array pt shape _) =
  BasicOp $ Replicate shape $ Constant $ blankPrimValue pt
zeroExp t = error $ "zeroExp: " ++ prettyString t
-- | The value "one" of a primitive type: the integer or float @1@ of
-- the requested width, 'True' for booleans, and the unit value for
-- the unit type.
onePrim :: PrimType -> PrimValue
onePrim (IntType it) = IntValue $ intValue it (1 :: Int)
onePrim (FloatType ft) = FloatValue $ floatValue ft (1 :: Double)
onePrim Bool = BoolValue True
onePrim Unit = UnitValue
-- | An expression that evaluates to 'onePrim' of the given type: a
-- constant for a primitive type, and a 'Replicate' of that constant
-- over the whole shape for an array type.  Any other type is
-- reported as an internal error.
oneExp :: Type -> Exp rep
oneExp (Prim t) = BasicOp $ SubExp $ constant $ onePrim t
oneExp (Array pt shape _) =
  BasicOp $ Replicate shape $ Constant $ onePrim pt
oneExp t = error $ "oneExp: " ++ prettyString t
-- | How the write for one entry of a 'Sparse' array is treated when
-- the array is materialised.
data InBounds
  = -- | Emit a bounds-checked ('Safe') update.
    -- NOTE(review): the optional 'SubExp' is presumably a boolean
    -- that is true when the index is in bounds; it is not inspected
    -- when the update is generated - confirm against the producers.
    CheckBounds (Maybe SubExp)
  | -- | Emit an unchecked ('Unsafe') update.
    AssumeBounds
  | -- | Emit no update at all for this entry.
    OutOfBounds
  deriving (Eq, Ord, Show)
-- | A symbolic array that is zero everywhere except at an explicit
-- list of positions.
data Sparse = Sparse
  { -- | Shape of the full array.
    sparseShape :: Shape,
    -- | Element type of the array.
    sparseType :: PrimType,
    -- | For each entry: how to treat bounds, the index written to
    -- (used to fix the outermost dimension), and the value written.
    sparseIdxVals :: [(InBounds, SubExp, SubExp)]
  }
  deriving (Eq, Ord, Show)
-- | The adjoint of a variable, kept symbolic until it is manifested
-- with 'adjVal'.
data Adj
  = -- | Zero except at a few known positions.
    AdjSparse Sparse
  | -- | An ordinary value.
    AdjVal SubExp
  | -- | All zeroes, of the given shape and element type.
    AdjZero Shape PrimType
  deriving (Eq, Ord, Show)
-- | Only an 'AdjVal' that holds a variable is renamed; every other
-- adjoint (constants, sparse and zero adjoints) is returned
-- unchanged.
instance Substitute Adj where
  substituteNames m (AdjVal (Var v)) = AdjVal $ Var $ substituteNames m v
  substituteNames _ adj = adj
-- | Bind a fresh variable to @'zeroExp' t@ replicated over @shape@.
-- When @shape@ has rank zero no replication is emitted and the
-- variable is bound directly to @'zeroExp' t@.
zeroArray :: (MonadBuilder m) => Shape -> Type -> m VName
zeroArray shape t
  | shapeRank shape == 0 =
      letExp "zero" $ zeroExp t
  | otherwise = do
      zero <- letSubExp "zero" $ zeroExp t
      -- The replicate is tagged with the "sequential" attribute.
      attributing (oneAttr "sequential") $
        letExp "zeroes_" . BasicOp $
          Replicate shape zero
-- | Manifest a 'Sparse' as an actual array: start from an all-zero
-- array of the right shape and apply one 'Update' per entry, in
-- list order.
sparseArray :: (MonadBuilder m, Rep m ~ SOACS) => Sparse -> m VName
sparseArray (Sparse shape t ivs) = do
  flip (foldM f) ivs =<< zeroArray shape (Prim t)
  where
    arr_t = Prim t `arrayOfShape` shape
    f arr (check, i, se) = do
      -- Write 'se' at index 'i' of the outermost dimension.
      let stm s =
            letExp "sparse" . BasicOp $
              Update s arr (fullSlice arr_t [DimFix i]) se
      case check of
        AssumeBounds -> stm Unsafe
        CheckBounds _ -> stm Safe
        -- No code is generated for this entry; the array passes
        -- through untouched.
        OutOfBounds -> pure arr
-- | The 'AdjVal' adjoint that refers to the given variable.
adjFromVar :: VName -> Adj
adjFromVar = AdjVal . Var
-- | The 'AdjVal' adjoint that refers to the name of the given
-- parameter.
adjFromParam :: Param t -> Adj
adjFromParam = adjFromVar . paramName
-- | Bind @'oneExp' t@ under the name hint @"adj_unit"@ and return it
-- as an 'AdjVal'.
unitAdjOfType :: Type -> ADM Adj
unitAdjOfType t = AdjVal <$> letSubExp "adj_unit" (oneExp t)
-- | Flatten an adjoint into the list of 'SubExp's that represent it,
-- together with a function that rebuilds an adjoint of the same form
-- from a replacement list.
--
-- The rebuilding function expects a list of the same length as the
-- one returned: exactly one element for 'AdjVal' and none for
-- 'AdjZero' (anything else is a pattern-match failure).  For
-- 'AdjSparse' it consumes index/value pairs and stops silently as
-- soon as either the original entries or the replacement pairs run
-- out.
adjRep :: Adj -> ([SubExp], [SubExp] -> Adj)
adjRep (AdjVal se) = ([se], \[se'] -> AdjVal se')
adjRep (AdjZero shape pt) = ([], \[] -> AdjZero shape pt)
adjRep (AdjSparse (Sparse shape pt ivs)) =
  (concatMap ivRep ivs, AdjSparse . Sparse shape pt . repIvs ivs)
  where
    -- Each entry contributes its index and its value; the bounds
    -- mode is not part of the flattened representation.
    ivRep (_, i, v) = [i, v]
    repIvs ((check, _, _) : ivs') (i : v : ses) =
      (check', i, v) : repIvs ivs' ses
      where
        check' = case check of
          AssumeBounds -> AssumeBounds
          CheckBounds b -> CheckBounds b
          -- Deliberate: an out-of-bounds entry is rebuilt as a
          -- checked one whose condition is the constant False,
          -- rather than as 'OutOfBounds'.
          OutOfBounds -> CheckBounds (Just (constant False))
    repIvs _ _ = []
-- | Flatten several adjoints at once with 'adjRep'.  The returned
-- function splits a replacement list into chunks of the original
-- per-adjoint lengths and rebuilds each adjoint from its own chunk.
adjsReps :: [Adj] -> ([SubExp], [SubExp] -> [Adj])
adjsReps adjs =
  let (reps, fs) = unzip $ map adjRep adjs
   in (concat reps, zipWith ($) fs . chunks (map length reps))
-- | The state threaded through the 'ADM' monad.
data RState = RState
  { -- | The adjoint currently recorded for each variable.
    stateAdjs :: M.Map VName Adj,
    -- | NOTE(review): presumably the name mapping maintained by the
    -- loop-tape operations exported from this module - confirm.
    stateLoopTape :: Substitutions,
    -- | NOTE(review): presumably the name mapping maintained by
    -- 'addSubstitution' - confirm.
    stateSubsts :: Substitutions,
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }
-- | The monad in which the reverse-mode transformation runs: a SOACS
-- statement builder layered over a state monad carrying 'RState'.
newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | Every operation simply delegates to the underlying 'BuilderT',
-- wrapping and unwrapping the 'ADM' newtype as needed.
instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat e = ADM $ mkExpDecM pat e
  mkBodyM bnds res = ADM $ mkBodyM bnds res
  mkLetNamesM pat e = ADM $ mkLetNamesM pat e

  addStms = ADM . addStms
  collectStms (ADM m) = ADM $ collectStms m
-- | Fresh names are drawn from, and written back to, the
-- 'stateNameSource' field of the state.
instance MonadFreshNames (State RState) where
  getNameSource = gets stateNameSource
  putNameSource src = modify (\env -> env {stateNameSource = src})
-- | Run an 'ADM' action inside any monad that can supply fresh
-- names.  The action starts with an empty scope and an 'RState'
-- whose maps are all empty and whose name source is the caller's;
-- the final name source is handed back to the caller.  Statements
-- accumulated at the top level of the builder are discarded - only
-- the action's result is returned.
runADM :: (MonadFreshNames m) => ADM a -> m a
runADM (ADM m) =
  modifyNameSource $ \vn ->
    second stateNameSource $
      runState
        (fst <$> runBuilderT m mempty)
        (RState mempty mempty mempty vn)
-- | Manifest an adjoint as a variable: an 'AdjVal' is bound to a
-- fresh name, an 'AdjSparse' is built with 'sparseArray', and an
-- 'AdjZero' with 'zeroArray'.
adjVal :: Adj -> ADM VName
adjVal (AdjVal se) = letExp "const_adj" $ BasicOp $ SubExp se
adjVal (AdjSparse sparse) = sparseArray sparse
adjVal (AdjZero shape t) = zeroArray shape $ Prim t
-- | Record @v_adj@ as the adjoint of @v@, replacing any adjoint
-- previously recorded for @v@.
setAdj :: VName -> Adj -> ADM ()
setAdj v v_adj = modify $ \env ->
  env {stateAdjs = M.insert v v_adj $ stateAdjs env}
-- | Record a variable as the adjoint of @v@.  A thin wrapper around
-- 'setAdj' that wraps the second name in 'AdjVal'.
insAdj :: VName -> VName -> ADM ()
insAdj v = setAdj v . AdjVal . Var
-- | A fresh name whose base string is that of @v@ followed by
-- @"_adj"@.
adjVName :: VName -> ADM VName
adjVName v = newVName (baseString v <> "_adj")
-- | For every array consumed by the statement (as reported by alias
-- analysis of the statement on its own), emit a copy named
-- @<base>_ad_copy@ and register a substitution for it with
-- 'addSubstitution'.
--
-- Returns a map from each consumed array to its copy, together with
-- the statements that produce the copies.  Consumed names that are not
-- of array type are ignored.
copyConsumedArrsInStm :: Stm SOACS -> ADM (Substitutions, Stms SOACS)
copyConsumedArrsInStm s =
  inScopeOf s . collectStms $ do
    copies <- mapM copyIfArray consumed
    pure $ M.fromList $ concat copies
  where
    -- Names consumed by the statement, found by running alias analysis
    -- on a singleton statement sequence with an empty alias table.
    consumed :: [VName]
    consumed =
      namesToList . consumedInStms . fst $
        Alias.analyseStms mempty (oneStm s)

    -- Copy a single consumed name if it is an array; the copy is an
    -- empty-shape 'Replicate' of the original.
    copyIfArray :: VName -> ADM [(VName, VName)]
    copyIfArray v = inScopeOf s $ do
      v_t <- lookupType v
      case v_t of
        Array {} -> do
          v' <-
            letExp (baseString v <> "_ad_copy") $
              BasicOp $
                Replicate mempty (Var v)
          addSubstitution v' v
          pure [(v, v')]
        _ -> pure []
-- | Emit a copy (named @<base>_ad_copy@) of every array consumed in
-- the body, except for those listed in @dontCopy@, and return a map
-- from each copied array to the name of its copy.
--
-- Calls 'error' if a consumed name (not in @dontCopy@) is an
-- accumulator.  Consumed names of any other non-array type are
-- ignored.
copyConsumedArrsInBody :: [VName] -> Body SOACS -> ADM Substitutions
copyConsumedArrsInBody dontCopy b = do
  substs <- mapM copyOf toCopy
  pure $ mconcat substs
  where
    -- Consumed names of the body (alias analysis with an empty alias
    -- table), minus the ones the caller asked us to leave alone.
    toCopy :: [VName]
    toCopy =
      [ v
        | v <- namesToList . consumedInBody $ Alias.analyseBody mempty b,
          v `notElem` dontCopy
      ]

    copyOf :: VName -> ADM Substitutions
    copyOf v = do
      v_t <- lookupType v
      case v_t of
        Acc {} ->
          error $ "copyConsumedArrsInBody: Acc " <> prettyString v
        Array {} -> do
          v_copy <-
            letExp (baseString v <> "_ad_copy") $
              BasicOp $
                Replicate mempty (Var v)
          pure $ M.singleton v v_copy
        _ -> pure mempty
-- | Run the action while collecting the statements it emits, then
-- re-emit those statements after applying the substitutions currently
-- stored in 'stateSubsts'.  The action's result is passed through.
returnSweepCode :: ADM a -> ADM a
returnSweepCode m = do
  (res, sweep_stms) <- collectStms m
  substs <- gets stateSubsts
  res <$ addStms (substituteNames substs sweep_stms)
-- | Insert the mapping @v -> v'@ into the substitution table
-- ('stateSubsts'), replacing any existing entry for @v@.
addSubstitution :: VName -> VName -> ADM ()
addSubstitution v v' = modify extend
  where
    extend env = env {stateSubsts = M.insert v v' (stateSubsts env)}
-- | Run the action with the adjoints of the given names removed from
-- 'stateAdjs', and put the saved adjoints back afterwards.
--
-- On restore, a saved adjoint takes priority over any adjoint the
-- action registered for the same name.  Names that had no adjoint
-- beforehand are not restored; whatever the action left for them is
-- kept.
--
-- The saved adjoints are kept as a map keyed by name.  Saving only the
-- adjoint values (dropping names without one) and zipping them back
-- against the full name list would pair adjoints with the wrong
-- variables whenever some name had no adjoint.
noAdjsFor :: Names -> ADM a -> ADM a
noAdjsFor names m = do
  old <- gets $ \env -> M.fromList $ mapMaybe (savedAdj env) names'
  modify $ \env ->
    env {stateAdjs = foldl' (flip M.delete) (stateAdjs env) names'}
  x <- m
  -- Map union is left-biased, so the saved adjoints win.
  modify $ \env -> env {stateAdjs = old <> stateAdjs env}
  pure x
  where
    names' = namesToList names

    -- Pair a name with its current adjoint, if it has one.
    savedAdj env v = (,) v <$> M.lookup v (stateAdjs env)
-- | The binary operator used to combine two values of the given
-- primitive type: wrapping 'Add' for integers, 'FAdd' for floats, and
-- 'LogAnd' for both 'Bool' and 'Unit'.
addBinOp :: PrimType -> BinOp
addBinOp pt = case pt of
  IntType it -> Add it OverflowWrap
  FloatType ft -> FAdd ft
  Bool -> LogAnd
  Unit -> LogAnd
-- | Build a nest of @n@ map-form 'Screma's over the given arrays.
--
-- At each level the arrays are mapped together with an 'Iota' of the
-- outer size, so that the innermost function receives the index
-- variables of all levels (outermost first) along with the names
-- bound to the array elements at that depth.  With @n == 0@ the
-- function is applied directly to the given names and an empty index
-- list.
tabNest :: Int -> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
tabNest depth arrs f = go [] depth arrs
  where
    -- @is@ accumulates the index variables innermost-first; it is
    -- reversed when finally handed to @f@.
    go :: [VName] -> Int -> [VName] -> ADM [VName]
    go is n vs
      | n == 0 = f (reverse is) vs
      | otherwise = do
          vs_ts <- mapM lookupType vs
          let w = arraysSize 0 vs_ts
          iota <-
            letExp "tab_iota" $
              BasicOp $
                Iota w (intConst Int64 0) (intConst Int64 1) Int64
          iparam <- newParam "i" (Prim int64)
          params <- forM vs $ \v -> do
            v_t <- lookupType v
            newParam (baseString v <> "_p") (rowType v_t)
          let lam_params = iparam : params
          -- Build the lambda body one level down, with the parameters
          -- in scope, keeping its statements out of the outer builder.
          ((ret, res), stms) <-
            collectStms . localScope (scopeOfLParams lam_params) $ do
              inner <- go (paramName iparam : is) (n - 1) (map paramName params)
              inner_ts <- mapM lookupType inner
              pure (inner_ts, varsRes inner)
          let lam = Lambda lam_params ret (Body () stms res)
          letTupExp "tab" $ Op $ Screma w (iota : vs) (mapSOAC lam)
-- | Construct a two-parameter lambda that combines two values of the
-- given type with 'addBinOp'.
--
-- For a primitive type this is a plain binary-operator lambda.  For an
-- array type it is a lambda whose body maps the lambda for the row
-- type across the two arrays.  Any other type is a call to 'error'.
addLambda :: Type -> ADM (Lambda SOACS)
addLambda t = case t of
  Prim pt -> binOpLambda (addBinOp pt) pt
  Array {} -> do
    xs_p <- newParam "xs" t
    ys_p <- newParam "ys" t
    row_lam <- addLambda (rowType t)
    let width = arraySize 0 t
        arrs = map paramName [xs_p, ys_p]
    body <- insertStmsM $ do
      summed <-
        letSubExp "lam_map" $
          Op $
            Screma width arrs (mapSOAC row_lam)
      pure $ resultBody [summed]
    pure $ Lambda [xs_p, ys_p] [t] body
  _ -> error $ "addLambda: " ++ show t
-- | An expression combining the two variables with 'addBinOp',
-- selected by the type of the first: a 'BinOp' for primitives, or a
-- map-form 'Screma' of the row-type 'addLambda' for arrays.  Any other
-- type is a call to 'error'.
addExp :: VName -> VName -> ADM (Exp SOACS)
addExp x y = addAt =<< lookupType x
  where
    addAt :: Type -> ADM (Exp SOACS)
    addAt (Prim pt) =
      pure . BasicOp $ BinOp (addBinOp pt) (Var x) (Var y)
    addAt x_t@Array {} = do
      row_lam <- addLambda (rowType x_t)
      pure . Op $ Screma (arraySize 0 x_t) [x, y] (mapSOAC row_lam)
    addAt x_t =
      error $ "addExp: unexpected type: " ++ prettyString x_t
-- | Look up the adjoint registered for a variable in 'stateAdjs'.
--
-- If there is none, an 'AdjZero' is returned instead, with shape and
-- element type taken from the variable's type (for an accumulator
-- with a single primitive element type, from the accumulator's own
-- shape and that element type).
lookupAdj :: VName -> ADM Adj
lookupAdj v =
  gets (M.lookup v . stateAdjs)
    >>= maybe (lookupType v >>= zeroAdjFor) pure
  where
    zeroAdjFor :: Type -> ADM Adj
    zeroAdjFor (Acc _ shape [Prim t] _) = pure $ AdjZero shape t
    zeroAdjFor v_t = pure $ AdjZero (arrayShape v_t) (elemType v_t)
-- | Look up the adjoint of a variable with 'lookupAdj' and convert it
-- to a variable name with 'adjVal'.
lookupAdjVal :: VName -> ADM VName
lookupAdjVal v = do
  adj <- lookupAdj v
  adjVal adj
-- | Add the contribution @d@ to the adjoint of @v@.
--
-- If @v@ has no adjoint yet, @d@ simply becomes its adjoint.
-- Otherwise the existing adjoint is combined with @d@ and the result
-- is registered as the new adjoint: when the existing adjoint is an
-- accumulator, every element of @d@ is pushed into it with
-- 'UpdateAcc' inside a 'tabNest' as deep as the rank of @d@;
-- otherwise the two are combined with 'addExp'.
updateAdj :: VName -> VName -> ADM ()
updateAdj v d =
  gets (M.lookup v . stateAdjs) >>= maybe (insAdj v d) accumulate
  where
    accumulate :: Adj -> ADM ()
    accumulate adj = do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      v_adj' <- case v_adj_t of
        Acc {} -> do
          rank <- length . arrayDims <$> lookupType d
          -- The lazy pattern avoids a 'MonadFail' constraint; the
          -- nest is given two arrays and its body yields one result.
          ~[acc'] <-
            tabNest rank [d, v_adj] $ \is [d_elem, acc] ->
              letTupExp "acc" $
                BasicOp $
                  UpdateAcc acc (map Var is) [Var d_elem]
          pure acc'
        _ -> do
          sum_exp <- addExp v_adj d
          letExp (baseString v <> "_adj") sum_exp
      insAdj v v_adj'
-- | Add the contribution @d@ to the part of the adjoint of @v@ selected by
-- @slice@, and register the resulting adjoint for @v@ with 'insAdj'.
--
-- A slice consisting of a single fixed index is handed to 'updateAdjIndex'
-- with 'AssumeBounds'.  Otherwise the current adjoint value of @v@ is looked
-- up and updated according to its type:
--
-- * accumulator: inside a 'tabNest' whose depth is the number of dimensions
--   of @slice@, an 'UpdateAcc' is emitted at the index obtained by passing
--   the nest indices through 'fixSlice';
--
-- * anything else: the sliced part of the adjoint is read (or the adjoint
--   itself is used when @v@ has primitive type), @d@ is added to it, and the
--   sum is written back in place.
updateAdjSlice :: Slice SubExp -> VName -> VName -> ADM ()
updateAdjSlice (Slice [DimFix i]) v d =
  updateAdjIndex v (AssumeBounds, i) (Var d)
updateAdjSlice slice v d = do
  t <- lookupType v
  v_adj <- lookupAdjVal v
  v_adj_t <- lookupType v_adj
  v_adj' <- case v_adj_t of
    Acc {} -> do
      let rank = length $ sliceDims slice
      -- The pattern is lazy (irrefutable) so that binding the single result
      -- does not demand a 'MonadFail' instance.
      ~[acc_res] <-
        tabNest rank [d, v_adj] $ \is [d_elem, acc] -> do
          -- Translate the nest indices into indices of the full adjoint.
          idxs <-
            traverse (toSubExp "index") $
              fixSlice (fmap pe64 slice) (map le64 is)
          letTupExp (baseString acc) . BasicOp $
            UpdateAcc acc idxs [Var d_elem]
      pure acc_res
    _ -> do
      v_adjslice <-
        if primType t
          then pure v_adj
          else
            letExp (baseString v <> "_slice") . BasicOp $
              Index v_adj slice
      letInPlace "updated_adj" v_adj slice =<< addExp v_adjslice d
  insAdj v v_adj'
-- | Add the contribution @d@ to the adjoint of a 'SubExp'.
--
-- Constants are ignored (nothing is recorded for them); for a variable the
-- work is delegated to 'updateAdj'.
updateSubExpAdj :: SubExp -> VName -> ADM ()
updateSubExpAdj Constant {} _ = pure ()
-- 'updateAdj' already returns @ADM ()@, so no 'void' is needed here.
updateSubExpAdj (Var v) d = updateAdj v d
-- | Add the contribution @se@ to the adjoint of @v@ at index @i@, where @i@
-- is used as the leading index of the adjoint.
--
-- What happens depends on the adjoint currently stored for @v@ in the state:
--
-- * none, or 'AdjZero': a new sparse adjoint holding only this contribution
--   is stored, shaped after the type of @v@;
--
-- * 'AdjSparse': the contribution is prepended to the existing list;
--
-- * 'AdjVal': code is emitted that updates the adjoint value.  For an
--   accumulator this is an 'UpdateAcc' (nested with 'tabNest' over the
--   dimensions of @se@); otherwise element @i@ is read, @se@ is added, and
--   the result is written back with an 'Update'.
--
-- The 'InBounds' tag selects the safety of that 'Update': 'CheckBounds'
-- gives 'Safe', 'AssumeBounds' gives 'Unsafe'.  With 'OutOfBounds' the
-- existing adjoint value is re-registered unchanged.
updateAdjIndex :: VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex v (check, i) se = do
  maybeAdj <- gets $ M.lookup v . stateAdjs
  t <- lookupType v
  let iv = (check, i, se)
      -- Sparse adjoint that contains nothing but this contribution.
      freshSparse = AdjSparse $ Sparse (arrayShape t) (elemType t) [iv]
  case maybeAdj of
    Nothing -> setAdj v freshSparse
    Just AdjZero {} -> setAdj v freshSparse
    Just (AdjSparse (Sparse shape pt ivs)) ->
      setAdj v $ AdjSparse $ Sparse shape pt (iv : ivs)
    Just adj@AdjVal {} -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      -- Bind the contribution to a name so it can be passed to 'tabNest'
      -- and 'addExp', both of which take variables.
      se_v <- letExp "se_v" $ BasicOp $ SubExp se
      insAdj v
        =<< case v_adj_t of
          Acc {}
            | check == OutOfBounds ->
                pure v_adj
            | otherwise -> do
                rank <- length . arrayDims <$> lookupType se_v
                -- Lazy pattern: binding the single result must not demand
                -- a 'MonadFail' instance.
                ~[acc_res] <-
                  tabNest rank [se_v, v_adj] $ \is [se_elem, acc] ->
                    letTupExp "acc" . BasicOp $
                      UpdateAcc acc (i : map Var is) [Var se_elem]
                pure acc_res
          _ -> do
            let -- Slice that fixes the leading dimension to @i@.
                row = fullSlice v_adj_t [DimFix i]
                -- Read element @i@, add the contribution, write it back.
                stms :: Safety -> ADM VName
                stms s = do
                  v_adj_i <-
                    letExp (baseString v_adj <> "_i") . BasicOp $
                      Index v_adj row
                  se_update <-
                    letSubExp "updated_adj_i" =<< addExp se_v v_adj_i
                  letExp (baseString v_adj) . BasicOp $
                    Update s v_adj row se_update
            case check of
              CheckBounds _ -> stms Safe
              AssumeBounds -> stms Unsafe
              OutOfBounds -> pure v_adj
-- | Whether a variable is considered active: true for every variable whose
-- type (as reported by 'lookupType') is anything other than the primitive
-- 'Unit' type.
isActive :: VName -> ADM Bool
isActive v = (/= Prim Unit) <$> lookupType v
-- | Run an action and afterwards restore the adjoint map ('stateAdjs') to
-- what it was before the action started.  Only that field is reset; any
-- other state changes made by the action are kept.  The action's result is
-- returned.
subAD :: ADM a -> ADM a
subAD m = do
  saved_adjs <- gets stateAdjs
  result <- m
  modify $ \s -> s {stateAdjs = saved_adjs}
  pure result
-- | Run an action and afterwards restore the substitution map
-- ('stateSubsts') to what it was before the action started.  Only that
-- field is reset; any other state changes made by the action are kept.  The
-- action's result is returned.
subSubsts :: ADM a -> ADM a
subSubsts m = do
  saved_substs <- gets stateSubsts
  result <- m
  modify $ \s -> s {stateSubsts = saved_substs}
  pure result
-- | A record of two callbacks operating in 'ADM'.
--
-- NOTE(review): judging by the names these are the reverse-mode
-- (vector-Jacobian product) transformations for lambdas and statements;
-- confirm against the code that constructs a 'VjpOps'.
data VjpOps = VjpOps
  { -- | Takes a list of adjoints, a list of names and a lambda, and
    -- produces a new lambda.
    vjpLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS),
    -- | Takes a statement and an 'ADM' action and produces an 'ADM' action.
    vjpStm :: Stm SOACS -> ADM () -> ADM ()
  }
-- | @setLoopTape v vs@ records the mapping @v |-> vs@ in the loop tape
-- ('stateLoopTape'), replacing any earlier entry for @v@.
setLoopTape :: VName -> VName -> ADM ()
setLoopTape v vs =
  modify $ \env ->
    env {stateLoopTape = M.insert v vs (stateLoopTape env)}
-- | Look up what @v@ maps to in the loop tape ('stateLoopTape'), or
-- 'Nothing' when there is no entry for it.
lookupLoopTape :: VName -> ADM (Maybe VName)
lookupLoopTape v = gets (M.lookup v . stateLoopTape)
-- | @substLoopTape v v'@ makes @v'@ map to whatever @v@ currently maps to
-- in the loop tape.  If @v@ has no entry, nothing happens.
--
-- Note that the entry for @v@ itself is not removed by this function.
substLoopTape :: VName -> VName -> ADM ()
substLoopTape v v' = do
  entry <- lookupLoopTape v
  mapM_ (setLoopTape v') entry
-- | Apply 'substLoopTape' to every @(old, new)@ pair of the given
-- substitution map, in the key order produced by 'M.toList'.
renameLoopTape :: Substitutions -> ADM ()
renameLoopTape substs = mapM_ (uncurry substLoopTape) (M.toList substs)