{-# LANGUAGE FlexibleContexts, TypeFamilies #-}
-- | This module defines a convenience typeclass for creating
-- normalised programs.
module Futhark.Binder.Class
  ( Bindable (..)
  , mkLet
  , MonadBinder (..)
  , mkLetM
  , bodyStms
  , insertStms
  , insertStm
  , letBind
  , letBind_
  , letBindNames
  , letBindNames_
  , collectStms_
  , bodyBind

  , module Futhark.MonadFreshNames
  )
where

import Control.Monad.Writer
import qualified Data.Kind

import Futhark.Representation.AST
import Futhark.MonadFreshNames

-- | Lores whose syntax can be built from nothing more than an
-- expression, possibly inside some monad.
--
-- Very important: the methods should not have any significant side
-- effects!  They may be called more often than you think, and the
-- results thrown away.  If used exclusively within a 'MonadBinder'
-- instance, it is acceptable for them to create new bindings, however.
class ( Attributes lore
      , FParamAttr lore ~ DeclType
      , LParamAttr lore ~ Type
      , RetType lore ~ DeclExtType
      , BranchType lore ~ ExtType
      , SetType (LetAttr lore)
      ) => Bindable lore where
  -- | Build a pattern from context identifiers, value identifiers and
  -- the expression that is being bound.
  mkExpPat :: [Ident] -> [Ident] -> Exp lore -> Pattern lore

  -- | Build the expression attribute for binding the given pattern to
  -- the given expression.
  mkExpAttr :: Pattern lore -> Exp lore -> ExpAttr lore

  -- | Build a body from its statements and result.
  mkBody :: Stms lore -> Result -> Body lore

  -- | Build a statement that binds the given names to an expression.
  mkLetNames :: (MonadFreshNames m, HasScope lore m)
             => [VName] -> Exp lore -> m (Stm lore)

-- | A monad that supports the creation of bindings from expressions
-- and bodies from bindings, with a specific lore.  This is the main
-- typeclass that a monad must implement in order for it to be useful
-- for generating or modifying Futhark code.
--
-- Very important: the methods should not have any significant side
-- effects!  They may be called more often than you think, and the
-- results thrown away.  It is acceptable for them to create new
-- bindings, however.
class (Attributes (Lore m),
       MonadFreshNames m, Applicative m, Monad m,
       LocalScope (Lore m) m) =>
      MonadBinder m where
  -- | The lore of the code being constructed in this monad.
  type Lore m :: Data.Kind.Type

  -- | Compute the expression attribute for binding the pattern to the
  -- expression.
  mkExpAttrM :: Pattern (Lore m) -> Exp (Lore m) -> m (ExpAttr (Lore m))

  -- | Construct a body from statements and a result.
  mkBodyM :: Stms (Lore m) -> Result -> m (Body (Lore m))

  -- | Construct (but do not add) a statement binding the given names
  -- to the expression.
  mkLetNamesM :: [VName] -> Exp (Lore m) -> m (Stm (Lore m))

  -- | Add a single statement.  The default implementation wraps it
  -- with 'oneStm' and delegates to 'addStms'.
  addStm      :: Stm (Lore m) -> m ()
  addStm      = addStms . oneStm

  -- | Add several statements.
  addStms     :: Stms (Lore m) -> m ()

  -- | Run an action, returning its result together with a collection
  -- of statements (presumably those added while running it — each
  -- instance defines this).
  collectStms :: m a -> m (a, Stms (Lore m))

  -- | Run an action under the given certificates; the exact semantics
  -- are instance-defined.
  certifying :: Certificates -> m a -> m a

-- | Construct, without adding it, a 'Let' statement that binds the
-- pattern to the expression.  The statement carries no certificates,
-- and its expression attribute is computed by 'mkExpAttrM'.
mkLetM :: MonadBinder m => Pattern (Lore m) -> Exp (Lore m) -> m (Stm (Lore m))
mkLetM pat e = do
  attr <- mkExpAttrM pat e
  pure $ Let pat (StmAux mempty attr) e

-- | Bind the pattern to the expression, add the resulting statement
-- via 'addStm', and return the value identifiers of the pattern.
-- Context identifiers of the pattern are not returned.
letBind :: MonadBinder m =>
           Pattern (Lore m) -> Exp (Lore m) -> m [Ident]
letBind pat e = do
  stm <- mkLetM pat e
  addStm stm
  pure $ patternValueIdents $ stmPattern stm

-- | Like 'letBind', but discards the returned identifiers.
letBind_ :: MonadBinder m =>
            Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind_ pat e = void $ letBind pat e

-- | Purely construct a 'Let' statement from context identifiers, value
-- identifiers and an expression.  The pattern and expression attribute
-- come from the lore's 'mkExpPat' and 'mkExpAttr'; no certificates are
-- attached.
mkLet :: Bindable lore => [Ident] -> [Ident] -> Exp lore -> Stm lore
mkLet ctx val e = Let pat (StmAux mempty attr) e
  where pat  = mkExpPat ctx val e
        attr = mkExpAttr pat e

-- | Bind the given names to the expression using 'mkLetNamesM', add the
-- resulting statement via 'addStm', and return the value identifiers of
-- the statement's pattern.
letBindNames :: MonadBinder m =>
                [VName] -> Exp (Lore m) -> m [Ident]
letBindNames names e = do
  stm <- mkLetNamesM names e
  addStm stm
  pure $ patternValueIdents $ stmPattern stm

-- | Like 'letBindNames', but discards the returned identifiers.
letBindNames_ :: MonadBinder m =>
                 [VName] -> Exp (Lore m) -> m ()
letBindNames_ names e = void $ letBindNames names e

-- | Like 'collectStms', but discards the action's own result and keeps
-- only the statements.
collectStms_ :: MonadBinder m => m a -> m (Stms (Lore m))
collectStms_ = fmap snd . collectStms

-- | Add the statements of a body via 'addStms' and return its result.
-- The body's attribute is ignored.
bodyBind :: MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind (Body _ stms es) = do
  addStms stms
  pure es

-- | Add several bindings at the outermost level of a 'Body'.
--
-- The new statements are placed before the body's existing ones.  The
-- body is rebuilt with 'mkBody', so its original attribute is discarded.
insertStms :: Bindable lore => Stms lore -> Body lore -> Body lore
insertStms stms1 (Body _ stms2 res) = mkBody (stms1 <> stms2) res

-- | Add a single binding at the outermost level of a 'Body', before the
-- body's existing statements (see 'insertStms').
insertStm :: Bindable lore => Stm lore -> Body lore -> Body lore
insertStm = insertStms . oneStm