{-# LANGUAGE TypeFamilies #-}
module Futhark.Builder.Class
( Buildable (..),
mkLet,
mkLet',
MonadBuilder (..),
insertStms,
insertStm,
letBind,
letBindNames,
collectStms_,
bodyBind,
attributing,
auxing,
module Futhark.MonadFreshNames,
)
where
import Data.Kind qualified
import Futhark.IR
import Futhark.MonadFreshNames
-- | Representations for which patterns, expression decorations,
-- bodies and whole statements can be constructed from their
-- constituent parts.  The superclass context fixes the parameter,
-- return and branch type annotations of the representation to the
-- plain 'DeclType', 'Type', 'DeclExtType' and 'ExtType'.
class
  ( ASTRep rep,
    FParamInfo rep ~ DeclType,
    LParamInfo rep ~ Type,
    RetType rep ~ DeclExtType,
    BranchType rep ~ ExtType
  ) =>
  Buildable rep
  where
  -- | Construct a pattern from the identifiers to bind and the
  -- expression being bound.
  mkExpPat :: [Ident] -> Exp rep -> Pat (LetDec rep)

  -- | Construct the expression decoration for a pattern and the
  -- expression it binds.
  mkExpDec :: Pat (LetDec rep) -> Exp rep -> ExpDec rep

  -- | Construct a body from its statements and result.
  mkBody :: Stms rep -> Result -> Body rep

  -- | Construct a statement binding the given names to the
  -- expression.  Runs in a monad providing fresh names and access to
  -- the scope.
  mkLetNames ::
    (MonadFreshNames m, HasScope rep m) =>
    [VName] ->
    Exp rep ->
    m (Stm rep)
-- | Monads in which statements of a particular representation
-- (@'Rep' m@) can be constructed and accumulated.
class
  ( ASTRep (Rep m),
    MonadFreshNames m,
    Applicative m,
    Monad m,
    LocalScope (Rep m) m
  ) =>
  MonadBuilder m
  where
  -- | The representation whose statements this monad builds.
  type Rep m :: Data.Kind.Type

  -- | Monadic construction of the expression decoration for a
  -- pattern and the expression it binds.
  mkExpDecM :: Pat (LetDec (Rep m)) -> Exp (Rep m) -> m (ExpDec (Rep m))

  -- | Monadic construction of a body from statements and a result.
  mkBodyM :: Stms (Rep m) -> Result -> m (Body (Rep m))

  -- | Monadic construction of a statement binding the given names to
  -- the expression.
  mkLetNamesM :: [VName] -> Exp (Rep m) -> m (Stm (Rep m))

  -- | Add a single statement.  The default wraps it with 'oneStm'
  -- and delegates to 'addStms'.
  addStm :: Stm (Rep m) -> m ()
  addStm = addStms . oneStm

  -- | Add a sequence of statements.
  addStms :: Stms (Rep m) -> m ()

  -- | Run an action, returning its result paired with a sequence of
  -- statements.  NOTE(review): presumably these are the statements
  -- the action added, withheld from the enclosing state; the exact
  -- behaviour is defined per instance -- confirm.
  collectStms :: m a -> m (a, Stms (Rep m))

  -- | Apply 'certify' with the given certificates to every statement
  -- that passes through 'censorStms' for this action.
  certifying :: Certs -> m a -> m a
  certifying = censorStms . fmap . certify
-- | Run an action under 'collectStms', transform the collected
-- statements with the given function, and re-add the transformed
-- statements with 'addStms'.  The result of the action is returned
-- unchanged.
censorStms ::
  (MonadBuilder m) =>
  (Stms (Rep m) -> Stms (Rep m)) ->
  m a ->
  m a
censorStms f m = do
  (x, stms) <- collectStms m
  addStms $ f stms
  pure x
-- | Prepend the given attributes to the attributes of every
-- statement collected from the action (via 'censorStms').
attributing :: (MonadBuilder m) => Attrs -> m a -> m a
attributing attrs = censorStms $ fmap onStm
  where
    -- The new attributes go on the left of the existing ones.
    onStm (Let pat aux e) =
      Let pat aux {stmAuxAttrs = attrs <> stmAuxAttrs aux} e
-- | Prepend the certificates and attributes of the given 'StmAux' to
-- those of every statement collected from the action (via
-- 'censorStms').  The decoration carried by the given 'StmAux' is
-- ignored, which is why it may belong to any representation.
auxing :: (MonadBuilder m) => StmAux anyrep -> m a -> m a
auxing (StmAux cs attrs _) = censorStms $ fmap onStm
  where
    onStm (Let pat aux e) =
      Let pat aux' e
      where
        -- The statement keeps its own decoration; only the
        -- attributes and certificates are extended, on the left.
        aux' =
          aux
            { stmAuxAttrs = attrs <> stmAuxAttrs aux,
              stmAuxCerts = cs <> stmAuxCerts aux
            }
-- | Add a statement binding the given pattern to the given
-- expression.  The expression decoration is computed with
-- 'mkExpDecM' and wrapped with 'defAux' before the resulting 'Let'
-- is passed to 'addStm'.
letBind ::
  (MonadBuilder m) =>
  Pat (LetDec (Rep m)) ->
  Exp (Rep m) ->
  m ()
letBind pat e = do
  dec <- mkExpDecM pat e
  addStm $ Let pat (defAux dec) e
-- | Construct a statement from the identifiers to bind and the
-- expression.  The pattern comes from 'mkExpPat', and the decoration
-- from 'mkExpDec' wrapped with 'defAux'.
mkLet :: (Buildable rep) => [Ident] -> Exp rep -> Stm rep
mkLet ids e =
  let pat = mkExpPat ids e
      dec = mkExpDec pat e
   in Let pat (defAux dec) e
-- | Like 'mkLet', but the certificates and attributes of the
-- resulting statement are taken from the given 'StmAux'.  The
-- decoration carried by that 'StmAux' is discarded and replaced by
-- one computed with 'mkExpDec'.
mkLet' :: (Buildable rep) => [Ident] -> StmAux a -> Exp rep -> Stm rep
mkLet' ids (StmAux cs attrs _) e =
  let pat = mkExpPat ids e
      dec = mkExpDec pat e
   in Let pat (StmAux cs attrs dec) e
-- | Build a statement binding the given names to the expression
-- with 'mkLetNamesM', and add it with 'addStm'.
letBindNames :: (MonadBuilder m) => [VName] -> Exp (Rep m) -> m ()
letBindNames names e = addStm =<< mkLetNamesM names e
-- | As 'collectStms', but discard the result of the action and
-- return only the statements.
collectStms_ :: (MonadBuilder m) => m a -> m (Stms (Rep m))
collectStms_ = fmap snd . collectStms
-- | Add the statements of the body with 'addStms' and return the
-- body's result.  The body decoration is ignored.
bodyBind :: (MonadBuilder m) => Body (Rep m) -> m Result
bodyBind (Body _ stms res) = do
  addStms stms
  pure res
-- | Put the given statements in front of the statements of a body.
-- The body is rebuilt with 'mkBody', so the decoration of the
-- original body is not carried over.
insertStms :: (Buildable rep) => Stms rep -> Body rep -> Body rep
insertStms stms1 (Body _ stms2 res) = mkBody (stms1 <> stms2) res
-- | Put a single statement in front of the statements of a body, by
-- wrapping it with 'oneStm' and delegating to 'insertStms'.
insertStm :: (Buildable rep) => Stm rep -> Body rep -> Body rep
insertStm = insertStms . oneStm