{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Monad
( ADM,
RState (..),
runADM,
Adj (..),
InBounds (..),
Sparse (..),
adjFromParam,
adjFromVar,
lookupAdj,
lookupAdjVal,
adjVal,
updateAdj,
updateSubExpAdj,
updateAdjSlice,
updateAdjIndex,
setAdj,
insAdj,
adjsReps,
copyConsumedArrsInStm,
copyConsumedArrsInBody,
addSubstitution,
returnSweepCode,
adjVName,
subAD,
noAdjsFor,
subSubsts,
isActive,
tabNest,
oneExp,
zeroExp,
unitAdjOfType,
addLambda,
VjpOps (..),
setLoopTape,
lookupLoopTape,
substLoopTape,
renameLoopTape,
)
where
import Control.Monad
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import Data.List (foldl')
import Data.Map qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Substitute
import Futhark.Util (chunks)
-- | An expression producing the zero value of the given type: a
-- 'blankPrimValue' constant for a primitive type, or that constant
-- replicated over the shape for an array type.  Any other type (e.g. an
-- 'Acc') is an internal error.
zeroExp :: Type -> Exp rep
zeroExp t = case t of
  Prim pt -> BasicOp . SubExp $ blank pt
  Array pt shape _ -> BasicOp . Replicate shape $ blank pt
  _ -> error $ "zeroExp: " ++ prettyString t
  where
    blank = Constant . blankPrimValue
-- | The "one" value of a primitive type: the literal 1 for integer and
-- float types, 'True' for 'Bool', and the unit value for 'Unit'.
onePrim :: PrimType -> PrimValue
onePrim pt = case pt of
  IntType it -> IntValue $ intValue it (1 :: Int)
  FloatType ft -> FloatValue $ floatValue ft (1 :: Double)
  Bool -> BoolValue True
  Unit -> UnitValue
-- | An expression producing the "one" value (see 'onePrim') of the given
-- type, replicated over the shape when the type is an array.  Any other
-- type is an internal error.
oneExp :: Type -> Exp rep
oneExp t = case t of
  Prim pt -> BasicOp . SubExp . constant $ onePrim pt
  Array pt shape _ -> BasicOp . Replicate shape . Constant $ onePrim pt
  _ -> error $ "oneExp: " ++ prettyString t
-- | How one index/value entry of a 'Sparse' array is bounds checked when
-- the array is manifested by 'sparseArray'.
data InBounds
  = -- | Emit a checked ('Safe') update.  The optional 'SubExp' is not
    -- inspected by 'sparseArray'; 'adjRep' fills it with a constant
    -- 'False', so it presumably holds a boolean in-bounds condition
    -- (TODO confirm).
    CheckBounds (Maybe SubExp)
  | -- | Emit an unchecked ('Unsafe') update.
    AssumeBounds
  | -- | The entry is out of bounds; 'sparseArray' emits no update for it.
    OutOfBounds
  deriving (Eq, Ord, Show)
-- | A symbolic array that is zero everywhere except at the listed
-- entries.  'sparseArray' manifests it by writing each entry into a
-- zero-filled array of the given shape and element type.
data Sparse = Sparse
  { -- | Shape of the array.
    sparseShape :: Shape,
    -- | Element type of the array.
    sparseType :: PrimType,
    -- | Bounds-check mode, index and value of each listed entry.  The
    -- index is used as a 'DimFix' slice when the entry is written.
    sparseIdxVals :: [(InBounds, SubExp, SubExp)]
  }
  deriving (Eq, Ord, Show)
-- | The symbolic adjoint of a variable.
data Adj
  = -- | A mostly-zero adjoint; see 'Sparse'.
    AdjSparse Sparse
  | -- | An adjoint held in a concrete 'SubExp'.
    AdjVal SubExp
  | -- | An all-zero adjoint of the given shape and element type that has
    -- not been manifested.
    AdjZero Shape PrimType
  deriving (Eq, Ord, Show)
-- | Apply a name substitution to an adjoint.  Only an 'AdjVal' that
-- refers to a variable is affected; every other adjoint is returned as
-- is.
--
-- NOTE(review): the 'SubExp's inside an 'AdjSparse' and the 'Shape' of an
-- 'AdjZero' are not substituted; confirm this is intended.
instance Substitute Adj where
  substituteNames substs adj = case adj of
    AdjVal (Var v) -> AdjVal . Var $ substituteNames substs v
    _ -> adj
-- | Bind the zero value of type @t@ replicated over @shape@ and return
-- its name.  For a rank-0 shape this binds 'zeroExp' of @t@ directly.
-- Otherwise the zero is bound first and then replicated by a statement
-- carrying the @sequential@ attribute.
zeroArray :: MonadBuilder m => Shape -> Type -> m VName
zeroArray shape t
  | shapeRank shape /= 0 = do
      z <- letSubExp "zero" $ zeroExp t
      attributing (oneAttr "sequential") . letExp "zeroes_" $
        BasicOp (Replicate shape z)
  | otherwise = letExp "zero" $ zeroExp t
-- | Manifest a 'Sparse' array.  It starts from a zero array of the sparse
-- shape and element type and applies one 'Update' per entry, in list
-- order.  'AssumeBounds' entries use an 'Unsafe' update, 'CheckBounds'
-- entries use a 'Safe' update, and 'OutOfBounds' entries are skipped.
sparseArray :: (MonadBuilder m, Rep m ~ SOACS) => Sparse -> m VName
sparseArray (Sparse shape t ivs) = do
  zeroes <- zeroArray shape (Prim t)
  foldM write zeroes ivs
  where
    arr_t = Prim t `arrayOfShape` shape

    update safety arr i se =
      letExp "sparse" . BasicOp $
        Update safety arr (fullSlice arr_t [DimFix i]) se

    write arr (check, i, se) =
      case check of
        OutOfBounds -> pure arr
        AssumeBounds -> update Unsafe arr i se
        CheckBounds _ -> update Safe arr i se
-- | The adjoint held in the given variable.
adjFromVar :: VName -> Adj
adjFromVar v = AdjVal (Var v)
-- | The adjoint held in the variable bound by the given parameter.
adjFromParam :: Param t -> Adj
adjFromParam p = adjFromVar (paramName p)
-- | Bind the "one" value of the given type (see 'oneExp') and return it
-- as an 'AdjVal'.
unitAdjOfType :: Type -> ADM Adj
unitAdjOfType t = do
  one <- letSubExp "adj_unit" $ oneExp t
  pure $ AdjVal one
-- | Flatten an adjoint into the 'SubExp's that represent it, together
-- with a function that rebuilds an adjoint of the same structure from a
-- replacement list.  The replacement list must have the same length:
-- the 'AdjVal' and 'AdjZero' rebuilders fail on any other length, and
-- the sparse rebuilder stops at the shorter input.
--
-- An 'AdjVal' contributes its value and an 'AdjZero' nothing.  Each entry
-- of an 'AdjSparse' contributes its index followed by its value.  When a
-- sparse adjoint is rebuilt, 'OutOfBounds' entries become
-- @'CheckBounds' ('Just' ('constant' 'False'))@.  This is presumably
-- deliberate; confirm before changing.
adjRep :: Adj -> ([SubExp], [SubExp] -> Adj)
adjRep adj = case adj of
  AdjVal se -> ([se], \[se'] -> AdjVal se')
  AdjZero shape pt -> ([], \[] -> AdjZero shape pt)
  AdjSparse (Sparse shape pt ivs) ->
    ( concat [[i, v] | (_, i, v) <- ivs],
      AdjSparse . Sparse shape pt . rebuild [check | (check, _, _) <- ivs]
    )
  where
    -- Consume two SubExps (index, value) per original entry.
    rebuild (check : checks) (i : v : rest) =
      (relax check, i, v) : rebuild checks rest
    rebuild _ _ = []

    relax OutOfBounds = CheckBounds (Just (constant False))
    relax check = check
-- | 'adjRep' lifted to a list of adjoints.  The per-adjoint
-- representations are concatenated.  The rebuild function splits its
-- input into chunks of the original per-adjoint lengths and rebuilds
-- each adjoint from its chunk.
adjsReps :: [Adj] -> ([SubExp], [SubExp] -> [Adj])
adjsReps adjs = (concat reps, rebuild)
  where
    (reps, rebuilders) = unzip $ map adjRep adjs
    rebuild = zipWith ($) rebuilders . chunks (map length reps)
-- | The state threaded through the 'ADM' monad.
data RState = RState
  { -- | The current adjoint of each variable that has one.
    stateAdjs :: M.Map VName Adj,
    -- | Mapping used by the loop-tape functions exported from this module
    -- ('setLoopTape', 'lookupLoopTape', ...).
    stateLoopTape :: Substitutions,
    -- | Substitutions registered by 'addSubstitution' and applied by
    -- 'returnSweepCode'.
    stateSubsts :: Substitutions,
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }
-- | The monad for reverse-mode AD: a 'BuilderT' over 'SOACS' running on
-- a 'State' of 'RState'.  All listed instances are derived from the
-- wrapped transformer stack.
newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | Every builder operation is delegated to the wrapped 'BuilderT'.
instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat e = ADM (mkExpDecM pat e)
  mkBodyM stms res = ADM (mkBodyM stms res)
  mkLetNamesM names e = ADM (mkLetNamesM names e)
  addStms stms = ADM (addStms stms)
  collectStms (ADM m) = ADM (collectStms m)
-- | Fresh names are read from and written back to the 'stateNameSource'
-- field of the state.
instance MonadFreshNames (State RState) where
  getNameSource = stateNameSource <$> get
  putNameSource src = modify $ \s -> s {stateNameSource = src}
-- | Run an 'ADM' computation in any monad that has a name source.  The
-- computation starts with no adjoints, an empty loop tape, no
-- substitutions and an empty scope.  Statements it emits at the top
-- level are discarded.  Only the updated name source is kept from the
-- final state.
runADM :: MonadFreshNames m => ADM a -> m a
runADM (ADM m) = modifyNameSource run
  where
    initial src =
      RState
        { stateAdjs = mempty,
          stateLoopTape = mempty,
          stateSubsts = mempty,
          stateNameSource = src
        }
    run src =
      let ((x, _stms), final) = runState (runBuilderT m mempty) (initial src)
       in (x, stateNameSource final)
-- | Manifest an adjoint as a variable.  An 'AdjVal' is bound to a fresh
-- name, a sparse adjoint is built by 'sparseArray', and a zero adjoint
-- is built by 'zeroArray'.
adjVal :: Adj -> ADM VName
adjVal adj = case adj of
  AdjVal se -> letExp "const_adj" . BasicOp $ SubExp se
  AdjSparse sparse -> sparseArray sparse
  AdjZero shape t -> zeroArray shape (Prim t)
-- | Record the adjoint of a variable, replacing any previous one.
setAdj :: VName -> Adj -> ADM ()
setAdj v v_adj = modify addAdj
  where
    addAdj s = s {stateAdjs = M.insert v v_adj (stateAdjs s)}
-- | Record the second variable as the adjoint ('AdjVal') of the first.
insAdj :: VName -> VName -> ADM ()
insAdj v v_adj = setAdj v (AdjVal (Var v_adj))
-- | A fresh name for the adjoint of a variable: the variable's base name
-- with an @_adj@ suffix.
adjVName :: VName -> ADM VName
adjVName = newVName . (<> "_adj") . baseString
-- | Copy every array consumed by the given statement.  Returns a map
-- from each consumed array to its copy, together with the statements
-- that perform the copies.  Consumed names of non-array type are
-- ignored.
--
-- Each copy @v'@ of @v@ is also registered with
-- @'addSubstitution' v' v@, so code later emitted under
-- 'returnSweepCode' has @v'@ rewritten back to @v@.
copyConsumedArrsInStm :: Stm SOACS -> ADM (Substitutions, Stms SOACS)
copyConsumedArrsInStm s =
  inScopeOf s . collectStms $
    M.fromList . concat <$> mapM copyIfArray consumed
  where
    -- Names consumed by the statement, per alias analysis.
    consumed =
      namesToList . consumedInStms . fst $
        Alias.analyseStms mempty (oneStm s)

    copyIfArray v = inScopeOf s $ do
      v_t <- lookupType v
      case v_t of
        Array {} -> do
          v' <- letExp (baseString v <> "_ad_copy") $ BasicOp $ Copy v
          addSubstitution v' v
          pure [(v, v')]
        _ -> pure []
-- | Copy every array consumed by the body, except those listed in
-- @dontCopy@, and return a map from each consumed array to its copy.
-- No substitution is registered for these copies.  A consumed
-- accumulator is an internal error; other non-array names are ignored.
copyConsumedArrsInBody :: [VName] -> Body SOACS -> ADM Substitutions
copyConsumedArrsInBody dontCopy b =
  M.unions <$> mapM copy toCopy
  where
    toCopy =
      [ v
        | v <- namesToList $ consumedInBody $ Alias.analyseBody mempty b,
          v `notElem` dontCopy
      ]

    copy v = do
      v_t <- lookupType v
      case v_t of
        Acc {} -> error $ "copyConsumedArrsInBody: Acc " <> prettyString v
        Array {} -> do
          v' <- letExp (baseString v <> "_ad_copy") $ BasicOp $ Copy v
          pure $ M.singleton v v'
        _ -> pure M.empty
-- | Run an action, then re-emit the statements it produced with every
-- substitution from 'stateSubsts' applied.  The substitutions are read
-- after the action, so ones it registered itself are included.
returnSweepCode :: ADM a -> ADM a
returnSweepCode m = do
  (x, stms) <- collectStms m
  substs <- gets stateSubsts
  x <$ addStms (substituteNames substs stms)
-- | Register the substitution of @v'@ for @v@ in 'stateSubsts', replacing
-- any existing substitution for @v@.  'returnSweepCode' applies these
-- substitutions.
addSubstitution :: VName -> VName -> ADM ()
addSubstitution v v' = modify extend
  where
    extend s = s {stateSubsts = M.insert v v' $ stateSubsts s}
-- | Run an action with the adjoints of the given names hidden.  They are
-- removed from 'stateAdjs' before the action and restored afterwards.
-- A restored adjoint overrides any adjoint the action set for the same
-- name.  Names that had no adjoint beforehand keep whatever the action
-- set for them.
noAdjsFor :: Names -> ADM a -> ADM a
noAdjsFor names m = do
  -- Keep each hidden name paired with its own adjoint.  Zipping all names
  -- against only the adjoints that were found would misattribute
  -- adjoints whenever some of the names have none.
  old <- gets $ \env ->
    [(v, adj) | v <- names', Just adj <- [M.lookup v (stateAdjs env)]]
  modify $ \env ->
    env {stateAdjs = foldl' (flip M.delete) (stateAdjs env) names'}
  x <- m
  -- Left-biased union: the saved adjoints win over the action's.
  modify $ \env -> env {stateAdjs = M.fromList old <> stateAdjs env}
  pure x
  where
    names' = namesToList names
-- | The binary operator used to combine two values of a primitive type
-- (as in 'addLambda'): wrapping integer addition, float addition, and
-- 'LogAnd' for both 'Bool' and 'Unit'.
addBinOp :: PrimType -> BinOp
addBinOp pt = case pt of
  IntType it -> Add it OverflowWrap
  FloatType ft -> FAdd ft
  Bool -> LogAnd
  Unit -> LogAnd
-- | @tabNest n vs f@ builds @n@ nested maps over the arrays @vs@.
--
-- At each level, the current arrays are mapped row-wise, together with
-- an @iota@ of their outer size that provides the index.  At the
-- innermost level, @f@ receives the indexes (outermost first) and the
-- current row variables.  Its results become the results of the nest.
-- With @n == 0@, this is just @f [] vs@.
tabNest ::
  Int ->
  [VName] ->
  ([VName] -> [VName] -> ADM [VName]) ->
  ADM [VName]
tabNest depth arrs0 f = go [] depth arrs0
  where
    -- @idxs@ accumulates loop indexes innermost-first.
    go idxs 0 arrs = f (reverse idxs) arrs
    go idxs k arrs = do
      w <- arraysSize 0 <$> mapM lookupType arrs
      iota <-
        letExp "tab_iota" . BasicOp $
          Iota w (intConst Int64 0) (intConst Int64 1) Int64
      iparam <- newParam "i" $ Prim int64
      params <- forM arrs $ \v ->
        newParam (baseString v <> "_p") . rowType =<< lookupType v
      ((ret, res), stms) <-
        collectStms . localScope (scopeOfLParams (iparam : params)) $ do
          inner <- go (paramName iparam : idxs) (k - 1) (map paramName params)
          inner_ts <- mapM lookupType inner
          pure (inner_ts, varsRes inner)
      let lam = Lambda (iparam : params) (Body () stms res) ret
      letTupExp "tab" . Op $ Screma w (iota : arrs) (mapSOAC lam)
-- | Construct a binary lambda that adds two values of the given type.
--
-- * For a primitive type this is the addition operator for that type
--   ('addBinOp'), wrapped by 'binOpLambda'.
-- * For an array type the lambda takes two arrays @xs@ and @ys@ of that
--   type and maps the adder for the row type over their outermost
--   dimension. Arrays of any rank are handled by recursing on 'rowType'.
--
-- Any other type is rejected with an 'error'.
addLambda :: Type -> ADM (Lambda SOACS)
addLambda (Prim pt) = binOpLambda (addBinOp pt) pt
addLambda t@Array {} = do
  xs_p <- newParam "xs" t
  ys_p <- newParam "ys" t
  lam <- addLambda $ rowType t
  body <- insertStmsM $ do
    res <-
      letSubExp "lam_map" . Op $
        Screma (arraySize 0 t) [paramName xs_p, paramName ys_p] (mapSOAC lam)
    pure $ resultBody [res]
  pure
    Lambda
      { lambdaParams = [xs_p, ys_p],
        lambdaReturnType = [t],
        lambdaBody = body
      }
addLambda t =
  -- Format with the pretty-printer, like 'addExp' and 'zeroExp' do, so
  -- the message shows Futhark type syntax rather than a derived Show dump.
  error $ "addLambda: " ++ prettyString t
-- | Build an expression computing the sum of the variables @x@ and @y@.
--
-- Only the type of @x@ is inspected; @y@ is assumed to have the same
-- type. Primitive values are added with a single 'BinOp'. Arrays are
-- added element-wise by mapping the 'addLambda' for the row type over
-- the outermost dimension. Any other type aborts with an 'error'.
addExp :: VName -> VName -> ADM (Exp SOACS)
addExp x y = lookupType x >>= addOfType
  where
    addOfType (Prim pt) =
      pure . BasicOp $ BinOp (addBinOp pt) (Var x) (Var y)
    addOfType x_t@Array {} = do
      row_add <- addLambda $ rowType x_t
      pure . Op . Screma (arraySize 0 x_t) [x, y] $ mapSOAC row_add
    addOfType x_t =
      error $ "addExp: unexpected type: " ++ prettyString x_t
-- | Fetch the adjoint currently recorded for a variable.
--
-- When no adjoint has been recorded yet, an implicit 'AdjZero' is
-- returned. For an accumulator whose element type is a single primitive
-- type, the zero uses the accumulator's shape and that primitive type.
-- Otherwise it uses the variable's array shape and element type. The
-- state is not modified.
lookupAdj :: VName -> ADM Adj
lookupAdj v = maybe implicitZero pure =<< gets (M.lookup v . stateAdjs)
  where
    implicitZero = do
      v_t <- lookupType v
      case v_t of
        Acc _ shape [Prim t] _ -> pure $ AdjZero shape t
        _ -> pure $ AdjZero (arrayShape v_t) (elemType v_t)
-- | Like 'lookupAdj', but returns the adjoint as a variable name by
-- passing it through 'adjVal'.
lookupAdjVal :: VName -> ADM VName
lookupAdjVal v = do
  adj <- lookupAdj v
  adjVal adj
-- | Add the contribution @d@ to the adjoint of @v@.
--
-- * If @v@ has no recorded adjoint, @d@ itself becomes the adjoint.
-- * If the existing adjoint (materialised via 'adjVal') is an
--   accumulator, @d@ is written into it with 'UpdateAcc' at every index
--   of @d@, nesting over @d@'s rank with 'tabNest'.
-- * Otherwise the existing adjoint and @d@ are summed with 'addExp', and
--   the sum is bound under the name @<v>_adj@.
updateAdj :: VName -> VName -> ADM ()
updateAdj v d = do
  existing <- gets $ M.lookup v . stateAdjs
  maybe (insAdj v d) accumulate existing
  where
    accumulate adj = do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      new_adj <- case v_adj_t of
        Acc {} -> do
          rank <- length . arrayDims <$> lookupType d
          ~[acc_res] <-
            tabNest rank [d, v_adj] $ \is [d_elem, acc] ->
              letTupExp "acc" . BasicOp $
                UpdateAcc acc (map Var is) [Var d_elem]
          pure acc_res
        _ ->
          letExp (baseString v <> "_adj") =<< addExp v_adj d
      insAdj v new_adj
-- | Add the contribution @d@ to the slice @slice@ of the adjoint of @v@.
--
-- A slice consisting of exactly one fixed index is delegated to
-- 'updateAdjIndex' with 'AssumeBounds'. For any other slice the adjoint
-- is materialised with 'lookupAdjVal', and then:
--
-- * if it is an accumulator, each element of @d@ is written with
--   'UpdateAcc' at the slice position obtained by 'fixSlice' from the
--   'tabNest' indices;
-- * otherwise the current slice is read (or the adjoint is used directly
--   when @v@ has a primitive type), @d@ is added with 'addExp', and the
--   result is written back with 'letInPlace'.
updateAdjSlice :: Slice SubExp -> VName -> VName -> ADM ()
updateAdjSlice (Slice [DimFix i]) v d =
  updateAdjIndex v (AssumeBounds, i) (Var d)
updateAdjSlice slice v d = do
  t <- lookupType v
  v_adj <- lookupAdjVal v
  v_adj_t <- lookupType v_adj
  new_adj <- case v_adj_t of
    Acc {} -> viaAcc v_adj
    _ -> viaUpdate (primType t) v_adj
  insAdj v new_adj
  where
    -- Scatter the elements of d into the accumulator at the slice indices.
    viaAcc acc0 = do
      let rank = length $ sliceDims slice
      ~[acc_res] <-
        tabNest rank [d, acc0] $ \is [d_elem, acc] -> do
          idxs <-
            traverse (toSubExp "index") . fixSlice (fmap pe64 slice) $
              map le64 is
          letTupExp (baseString acc) . BasicOp $
            UpdateAcc acc idxs [Var d_elem]
      pure acc_res

    -- Read-modify-write the slice of an ordinary array adjoint.
    viaUpdate is_scalar arr = do
      current <-
        if is_scalar
          then pure arr
          else letExp (baseString v ++ "_slice") . BasicOp $ Index arr slice
      letInPlace "updated_adj" arr slice =<< addExp current d
-- | Add the contribution @d@ to the adjoint of a 'SubExp'.
--
-- Constants have no adjoint, so contributions to them are dropped.
-- Variables are updated with 'updateAdj'.
updateSubExpAdj :: SubExp -> VName -> ADM ()
updateSubExpAdj Constant {} _ = pure ()
-- 'updateAdj' already returns @ADM ()@, so wrapping it in 'void' is
-- redundant.
updateSubExpAdj (Var v) d = updateAdj v d
-- | Add the contribution @se@ to element @i@ (outermost dimension) of the
-- adjoint of @v@.
--
-- * If @v@ has no adjoint yet, or its adjoint is 'AdjZero', the adjoint
--   becomes 'AdjSparse' with the single entry @(check, i, se)@. It uses
--   the array shape and element type of @v@.
-- * If the adjoint is already sparse, the entry is prepended.
-- * If the adjoint is materialised ('AdjVal'):
--
--     * an 'OutOfBounds' contribution is dropped and the adjoint is
--       re-inserted unchanged;
--     * for an accumulator, @se@ is written with 'UpdateAcc' at index @i@
--       followed by every inner index of @se@ (via 'tabNest');
--     * otherwise element @i@ is read, @se@ is added with 'addExp', and
--       the sum is written back with 'Update'. The update is 'Safe' for
--       'CheckBounds' and 'Unsafe' for 'AssumeBounds'.
updateAdjIndex :: VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex v (check, i) se = do
  maybeAdj <- gets $ M.lookup v . stateAdjs
  t <- lookupType v
  -- A missing adjoint and a zero adjoint are handled identically: start a
  -- fresh sparse adjoint holding just this entry.
  let iv = (check, i, se)
      startSparse =
        setAdj v $ AdjSparse $ Sparse (arrayShape t) (elemType t) [iv]
  case maybeAdj of
    Nothing -> startSparse
    Just AdjZero {} -> startSparse
    Just (AdjSparse (Sparse shape pt ivs)) ->
      setAdj v $ AdjSparse $ Sparse shape pt $ iv : ivs
    Just adj@AdjVal {} -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      -- 'addExp' and 'tabNest' work on names, so @se@ is bound to one. This
      -- happens only on the paths that use it, so no dead binding is
      -- emitted for out-of-bounds contributions.
      let bindContribution = letExp "se_v" $ BasicOp $ SubExp se
          viaUpdate safety = do
            se_v <- bindContribution
            v_adj_i <-
              letExp (baseString v_adj <> "_i") . BasicOp $
                Index v_adj $
                  fullSlice v_adj_t [DimFix i]
            se_update <- letSubExp "updated_adj_i" =<< addExp se_v v_adj_i
            letExp (baseString v_adj) . BasicOp $
              Update safety v_adj (fullSlice v_adj_t [DimFix i]) se_update
      insAdj v =<< case (v_adj_t, check) of
        (_, OutOfBounds) -> pure v_adj
        (Acc {}, _) -> do
          se_v <- bindContribution
          dims <- arrayDims <$> lookupType se_v
          ~[v_adj'] <-
            tabNest (length dims) [se_v, v_adj] $ \is [se_elem, acc] ->
              letTupExp "acc" . BasicOp $
                UpdateAcc acc (i : map Var is) [Var se_elem]
          pure v_adj'
        (_, CheckBounds _) -> viaUpdate Safe
        (_, AssumeBounds) -> viaUpdate Unsafe
-- | Returns 'False' exactly when the variable's type is @Prim Unit@;
-- every other variable is considered active.
isActive :: VName -> ADM Bool
isActive v = do
  v_t <- lookupType v
  pure $ v_t /= Prim Unit
-- | Run an action, then restore the adjoint map ('stateAdjs') to what it
-- was beforehand. Adjoints recorded inside the action do not escape;
-- all other state changes made by the action are kept.
subAD :: ADM a -> ADM a
subAD m = do
  saved_adjs <- gets stateAdjs
  m <* modify (\s -> s {stateAdjs = saved_adjs})
-- | Run an action, then restore the substitution map ('stateSubsts') to
-- what it was beforehand. Substitutions added inside the action do not
-- escape; all other state changes made by the action are kept.
subSubsts :: ADM a -> ADM a
subSubsts m = do
  saved_substs <- gets stateSubsts
  m <* modify (\s -> s {stateSubsts = saved_substs})
-- | A record bundling two reverse-mode (VJP) operations so they can be
-- passed around as values.
-- NOTE(review): presumably this lets other AD modules call back into the
-- main VJP pass without an import cycle -- confirm.
data VjpOps = VjpOps
  { -- | Differentiate a lambda. The roles of the @[Adj]@ and @[VName]@
    -- arguments are fixed by whoever constructs this record.
    vjpLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS),
    -- | Differentiate a statement. The @ADM ()@ argument is presumably
    -- the continuation for the code that follows it -- confirm.
    vjpStm :: Stm SOACS -> ADM () -> ADM ()
  }
-- | Record @vs@ as the loop tape of @v@ in 'stateLoopTape', replacing any
-- tape previously recorded for @v@.
setLoopTape :: VName -> VName -> ADM ()
setLoopTape v vs = modify withTape
  where
    withTape st = st {stateLoopTape = M.insert v vs $ stateLoopTape st}
-- | The loop tape recorded for a variable, if any.
lookupLoopTape :: VName -> ADM (Maybe VName)
lookupLoopTape v = M.lookup v <$> gets stateLoopTape
-- | If @v@ has a loop tape, record the same tape for @v'@. Does nothing
-- when @v@ has no tape.
substLoopTape :: VName -> VName -> ADM ()
substLoopTape v v' =
  maybe (pure ()) (setLoopTape v') =<< lookupLoopTape v
-- | For every @(old, new)@ pair in the substitution map, visited in
-- ascending key order, copy the loop tape of @old@ (if any) to @new@
-- using 'substLoopTape'.
renameLoopTape :: Substitutions -> ADM ()
renameLoopTape substs =
  forM_ (M.toList substs) $ \(old, new) ->
    substLoopTape old new