{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module defines a convenience typeclass for creating
-- normalised programs.
--
-- See "Futhark.Construct" for a high-level description.
module Futhark.Builder.Class
  ( Buildable (..),
    mkLet,
    mkLet',
    MonadBuilder (..),
    insertStms,
    insertStm,
    letBind,
    letBindNames,
    collectStms_,
    bodyBind,
    attributing,
    auxing,
    module Futhark.MonadFreshNames,
  )
where

import qualified Data.Kind
import Futhark.IR
import Futhark.MonadFreshNames

-- | The class of representations that can be constructed solely from
-- an expression, within some monad.  Very important: the methods
-- should not have any significant side effects!  They may be called
-- more often than you think, and the results thrown away.  If used
-- exclusively within a 'MonadBuilder' instance, it is acceptable for
-- them to create new bindings, however.
--
-- The superclass equalities pin the representation's 'FParamInfo',
-- 'LParamInfo', 'RetType' and 'BranchType' to 'DeclType', 'Type',
-- 'DeclExtType' and 'ExtType' respectively, so generic construction
-- code can build those annotations without knowing the representation.
class
  ( ASTRep rep,
    FParamInfo rep ~ DeclType,
    LParamInfo rep ~ Type,
    RetType rep ~ DeclExtType,
    BranchType rep ~ ExtType,
    SetType (LetDec rep)
  ) =>
  Buildable rep
  where
  -- | Construct the pattern that binds the given identifiers to the
  -- result of the expression.
  mkExpPat :: [Ident] -> Exp rep -> Pat rep

  -- | Compute the expression decoration for a statement with the given
  -- pattern and expression.
  mkExpDec :: Pat rep -> Exp rep -> ExpDec rep

  -- | Construct a body from its statements and result.
  mkBody :: Stms rep -> Result -> Body rep

  -- | Construct a statement binding the given names to the expression.
  -- Runs in a monad that can supply fresh names and inspect the
  -- current scope.
  mkLetNames ::
    (MonadFreshNames m, HasScope rep m) =>
    [VName] ->
    Exp rep ->
    m (Stm rep)

-- | A monad that supports the creation of bindings from expressions
-- and bodies from bindings, with a specific rep.  This is the main
-- typeclass that a monad must implement in order for it to be useful
-- for generating or modifying Futhark code.  Most importantly
-- maintains a current state of 'Stms' (as well as a 'Scope') that
-- have been added with 'addStm'.
--
-- Very important: the construction methods ('mkExpDecM', 'mkBodyM'
-- and 'mkLetNamesM') should not have any significant side effects!
-- They may be called more often than you think, and the results
-- thrown away.  It is acceptable for them to create new bindings,
-- however.  ('addStm' and 'addStms' are, by design, the methods that
-- modify the statement state.)
class
  ( ASTRep (Rep m),
    MonadFreshNames m,
    Applicative m,
    Monad m,
    LocalScope (Rep m) m
  ) =>
  MonadBuilder m
  where
  -- | The representation of the code constructed in this monad.
  type Rep m :: Data.Kind.Type

  -- | Compute the expression decoration for a statement with the given
  -- pattern and expression.
  mkExpDecM :: Pat (Rep m) -> Exp (Rep m) -> m (ExpDec (Rep m))

  -- | Construct a body from its statements and result.
  mkBodyM :: Stms (Rep m) -> Result -> m (Body (Rep m))

  -- | Construct a statement binding the given names to the expression.
  -- Use 'letBindNames' to also add the statement.
  mkLetNamesM :: [VName] -> Exp (Rep m) -> m (Stm (Rep m))

  -- | Add a statement to the 'Stms' under construction.  Defaults to
  -- adding a singleton sequence with 'addStms'.
  addStm :: Stm (Rep m) -> m ()
  addStm = addStms . oneStm

  -- | Add multiple statements to the 'Stms' under construction.
  addStms :: Stms (Rep m) -> m ()

  -- | Obtain the statements constructed during a monadic action,
  -- instead of adding them to the state.
  collectStms :: m a -> m (a, Stms (Rep m))

  -- | Add the provided certificates to any statements added during
  -- execution of the action.  Defaults to rewriting every collected
  -- statement with 'certify' via 'censorStms'.
  certifying :: Certs -> m a -> m a
  certifying = censorStms . fmap . certify

-- | Apply a function to the statements added by this action.
--
-- The action is run under 'collectStms', so its statements are not
-- added directly.  Instead the function is applied to them, the
-- transformed statements are added with 'addStms', and the action's
-- result is returned unchanged.
censorStms ::
  MonadBuilder m =>
  (Stms (Rep m) -> Stms (Rep m)) ->
  m a ->
  m a
censorStms f m = do
  (x, stms) <- collectStms m
  addStms $ f stms
  pure x

-- | Add the given attributes to any statements added by this action.
--
-- The new attributes are combined in front of whatever attributes each
-- statement already carries (@attrs <> existing@).  Certificates,
-- patterns and expressions are left untouched.
attributing :: MonadBuilder m => Attrs -> m a -> m a
attributing attrs = censorStms $ fmap onStm
  where
    -- No local signature: one mentioning 'Rep m' would need
    -- ScopedTypeVariables to refer to the outer 'm'.
    onStm (Let pat aux e) =
      Let pat aux {stmAuxAttrs = attrs <> stmAuxAttrs aux} e

-- | Add the certificates and attributes to any statements added by
-- this action.
--
-- Only the certificates and attributes of the given 'StmAux' are used;
-- its decoration is ignored.  Both are combined in front of the values
-- each statement already carries.
auxing :: MonadBuilder m => StmAux anyrep -> m a -> m a
auxing (StmAux cs attrs _) = censorStms $ fmap onStm
  where
    onStm (Let pat aux e) = Let pat aux' e
      where
        aux' =
          aux
            { stmAuxAttrs = attrs <> stmAuxAttrs aux,
              stmAuxCerts = cs <> stmAuxCerts aux
            }

-- | Add a statement with the given pattern and expression.
--
-- The expression decoration is computed with 'mkExpDecM' and wrapped
-- with 'defAux'; the resulting statement is then added with 'addStm'.
letBind ::
  MonadBuilder m =>
  Pat (Rep m) ->
  Exp (Rep m) ->
  m ()
letBind pat e = do
  dec <- mkExpDecM pat e
  addStm $ Let pat (defAux dec) e

-- | Construct a 'Stm' from the identifiers bound by the pattern, as
-- well as the expression.
--
-- The pattern comes from 'mkExpPat' and the expression decoration from
-- 'mkExpDec'.  The decoration is wrapped with 'defAux'.
mkLet :: Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet ids e = Let pat (defAux dec) e
  where
    pat = mkExpPat ids e
    dec = mkExpDec pat e

-- | Like 'mkLet', but also take attributes and certificates from the
-- given 'StmAux'.  The decoration stored in that 'StmAux' is ignored
-- and replaced by the one computed with 'mkExpDec'.
mkLet' :: Buildable rep => [Ident] -> StmAux a -> Exp rep -> Stm rep
mkLet' ids (StmAux cs attrs _) e = Let pat (StmAux cs attrs dec) e
  where
    pat = mkExpPat ids e
    dec = mkExpDec pat e

-- | Add a statement with the given pattern element names and
-- expression.  The statement is constructed with 'mkLetNamesM' and
-- then added with 'addStm'.
letBindNames :: MonadBuilder m => [VName] -> Exp (Rep m) -> m ()
letBindNames names e = addStm =<< mkLetNamesM names e

-- | As 'collectStms', but throw away the ordinary result.
collectStms_ :: MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ = fmap snd . collectStms

-- | Add the statements of the body, then return the body result.
-- The body's decoration is discarded.
bodyBind :: MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body _ stms res) = do
  addStms stms
  pure res

-- | Add several bindings at the outermost level of a t'Body'.
--
-- The new statements are placed before the body's existing ones.  The
-- old body decoration is discarded and a fresh body is built with
-- 'mkBody', so the decoration reflects the combined statements.
insertStms :: Buildable rep => Stms rep -> Body rep -> Body rep
insertStms stms1 (Body _ stms2 res) = mkBody (stms1 <> stms2) res

-- | Add a single binding at the outermost level of a t'Body'.
-- Equivalent to 'insertStms' applied to a singleton sequence.
insertStm :: Buildable rep => Stm rep -> Body rep -> Body rep
insertStm = insertStms . oneStm