{-# LANGUAGE FlexibleContexts, GeneralizedNewtypeDeriving, TypeFamilies, FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}
{-# LANGUAGE ConstraintKinds #-}
-- | This module defines a convenience monad/typeclass for creating
-- normalised programs.
module Futhark.Binder
  ( -- * A concrete @MonadBinder@ monad.
    BinderT
  , runBinderT, runBinderT_
  , runBinderT', runBinderT'_
  , BinderOps (..)
  , bindableMkExpAttrB
  , bindableMkBodyB
  , bindableMkLetNamesB
  , Binder
  , runBinder
  , runBinder_
  , joinBinder
  , runBodyBinder
  -- * Non-class interface
  , addBinderStms
  , collectBinderStms
  , certifyingBinder
  -- * The 'MonadBinder' typeclass
  , module Futhark.Binder.Class
  )
where

import Control.Arrow (second)
import Control.Monad.Writer
import Control.Monad.State.Strict
import Control.Monad.Reader
import Control.Monad.Error.Class
import qualified Data.Map.Strict as M

import Futhark.Binder.Class
import Futhark.Representation.AST

-- | The lore-specific operations that the 'MonadBinder' instance for
-- 'BinderT' delegates to.  Each method backs the correspondingly
-- named @...M@ method of that instance.
class Attributes lore => BinderOps lore where
  -- | Produce the 'ExpAttr' for the given expression and the pattern
  -- it is bound to.  Used as 'mkExpAttrM' for 'BinderT'.
  mkExpAttrB
    :: (MonadBinder m, Lore m ~ lore)
    => Pattern lore -> Exp lore -> m (ExpAttr lore)

  -- | Build a 'Body' from statements and a result.  Used as 'mkBodyM'
  -- for 'BinderT'.
  mkBodyB
    :: (MonadBinder m, Lore m ~ lore)
    => Stms lore -> Result -> m (Body lore)

  -- | Build a statement binding the given names to the expression.
  -- Used as 'mkLetNamesM' for 'BinderT'.
  mkLetNamesB
    :: (MonadBinder m, Lore m ~ lore)
    => [VName] -> Exp lore -> m (Stm lore)

-- | A definition with the shape of 'mkExpAttrB' that works for any
-- lore that is 'Bindable': the attribute is computed purely with
-- 'mkExpAttr' and returned in the monad.
bindableMkExpAttrB :: (MonadBinder m, Bindable (Lore m)) =>
                      Pattern (Lore m) -> Exp (Lore m) -> m (ExpAttr (Lore m))
bindableMkExpAttrB pat e = return $ mkExpAttr pat e

-- | A definition with the shape of 'mkBodyB' that works for any lore
-- that is 'Bindable': the body is built purely with 'mkBody' and
-- returned in the monad.
bindableMkBodyB :: (MonadBinder m, Bindable (Lore m)) =>
                   Stms (Lore m) -> Result -> m (Body (Lore m))
bindableMkBodyB stms res = return $ mkBody stms res

-- | A definition with the shape of 'mkLetNamesB' that works for any
-- lore that is 'Bindable'; it is simply 'mkLetNames'.
bindableMkLetNamesB :: (MonadBinder m, Bindable (Lore m)) =>
                       [VName] -> Exp (Lore m) -> m (Stm (Lore m))
bindableMkLetNamesB = mkLetNames

-- | A monad transformer for building up statements.  It is a state
-- monad whose state is a pair: the statements accumulated so far, and
-- the scope currently in effect.  The 'Functor', 'Applicative' and
-- 'Monad' instances are those of the underlying 'StateT', obtained
-- through @GeneralizedNewtypeDeriving@.
newtype BinderT lore m a = BinderT (StateT (Stms lore, Scope lore) m a)
  deriving (Functor, Monad, Applicative)

-- | Lifting goes through the underlying 'StateT'; the statement and
-- scope state is left untouched.
instance MonadTrans (BinderT lore) where
  lift = BinderT . lift

-- | 'BinderT' on top of a plain 'State' monad that carries the
-- 'VNameSource'.
type Binder lore = BinderT lore (State VNameSource)

-- | The name source is that of the underlying monad; both operations
-- are simply lifted.
instance MonadFreshNames m => MonadFreshNames (BinderT lore m) where
  getNameSource = lift getNameSource
  putNameSource = lift . putNameSource

-- | The scope is the second component of the 'BinderT' state.
instance (Attributes lore, Monad m) =>
         HasScope lore (BinderT lore m) where
  -- Look the name up in the scope held in the state.  A missing name
  -- is reported by calling 'error'; there is no recovery path here.
  lookupType name = do
    t <- BinderT $ gets $ M.lookup name . snd
    case t of
      Nothing -> error $ "BinderT.lookupType: unknown variable " ++ pretty name
      Just t' -> return $ typeOf t'
  -- The whole scope, exactly as currently stored in the state.
  askScope = BinderT $ gets snd

-- | Run an action with additional names in scope.
instance (Attributes lore, Monad m) =>
         LocalScope lore (BinderT lore m) where
  localScope types (BinderT m) = BinderT $ do
    -- 'M.union' is left-biased, so the entries in @types@ take
    -- precedence over existing entries with the same name.
    modify $ second (M.union types)
    x <- m
    -- Drop every name mentioned in @types@ again.  Scope entries added
    -- while running @m@ under other names are kept.
    -- NOTE(review): a name that was already in scope before the call
    -- and is also present in @types@ is removed here, not restored.
    modify $ second (`M.difference` types)
    return x

-- | 'BinderT' is a 'MonadBinder' for any lore with 'BinderOps'.  The
-- lore-specific construction methods delegate to 'BinderOps'; the
-- statement handling delegates to the non-class interface of this
-- module.
instance (Attributes lore, MonadFreshNames m, BinderOps lore) =>
         MonadBinder (BinderT lore m) where
  type Lore (BinderT lore m) = lore
  mkExpAttrM = mkExpAttrB
  mkBodyM = mkBodyB
  mkLetNamesM = mkLetNamesB

  addStms     = addBinderStms
  collectStms = collectBinderStms

  certifying = certifyingBinder

-- | Run a 'BinderT' action starting from no statements and the given
-- scope.  Returns the result of the action together with the
-- statements it accumulated; the final scope is discarded.
runBinderT :: MonadFreshNames m =>
              BinderT lore m a
           -> Scope lore
           -> m (a, Stms lore)
runBinderT (BinderT m) scope = do
  (x, (stms, _)) <- runStateT m (mempty, scope)
  return (x, stms)

-- | Like 'runBinderT', but throw away the result of the action and
-- return only the accumulated statements.
runBinderT_ :: MonadFreshNames m =>
                BinderT lore m a -> Scope lore -> m (Stms lore)
runBinderT_ m = fmap snd . runBinderT m

-- | Like 'runBinderT', but the initial scope is taken from the
-- surrounding monad with 'askScope' and converted to the target lore
-- with 'castScope'.
runBinderT' :: (MonadFreshNames m, HasScope somelore m, SameScope somelore lore) =>
               BinderT lore m a
            -> m (a, Stms lore)
runBinderT' m = do
  scope <- askScope
  runBinderT m $ castScope scope

-- | Like @runBinderT'@, but throw away the result of the action and
-- return only the accumulated statements.
runBinderT'_ :: (MonadFreshNames m, HasScope somelore m, SameScope somelore lore) =>
                BinderT lore m a -> m (Stms lore)
runBinderT'_ = fmap snd . runBinderT'

-- | Run a 'Binder' action inside another monad.  The initial scope is
-- the surrounding monad's scope (via 'askScope' and 'castScope').
-- The action itself runs in a pure 'State' over the name source,
-- which is threaded through the surrounding monad with
-- 'modifyNameSource'.
runBinder :: (MonadFreshNames m,
              HasScope somelore m, SameScope somelore lore) =>
              Binder lore a
           -> m (a, Stms lore)
runBinder m = do
  types <- askScope
  modifyNameSource $ runState $ runBinderT m $ castScope types

-- | Like 'runBinder', but throw away the result and just return the
-- added bindings.
runBinder_ :: (MonadFreshNames m,
               HasScope somelore m, SameScope somelore lore) =>
              Binder lore a
           -> m (Stms lore)
runBinder_ = fmap snd . runBinder

-- | As 'runBinder', but uses 'addStms' to add the returned
-- bindings to the surrounding monad.
joinBinder :: MonadBinder m => Binder (Lore m) a -> m a
joinBinder m = do
  (x, bnds) <- runBinder m
  addStms bnds
  return x

-- | Run a 'Binder' action that produces a 'Body', and pass the
-- statements accumulated along the way, together with that body, to
-- 'insertStms'.
runBodyBinder :: (Bindable lore, MonadFreshNames m,
                  HasScope somelore m, SameScope somelore lore) =>
                 Binder lore (Body lore) -> m (Body lore)
runBodyBinder = fmap (uncurry $ flip insertStms) . runBinder

-- | Append statements to those accumulated so far, and extend the
-- scope with the names they bind (as given by 'scopeOf').  'M.union'
-- is left-biased, so entries already in scope win over same-named
-- entries from the new statements.
addBinderStms :: Monad m =>
                 Stms lore -> BinderT lore m ()
addBinderStms stms = BinderT $
  modify $ \(cur_stms, scope) ->
    (cur_stms <> stms, scope `M.union` scopeOf stms)

-- | Run an action and return the statements it added alongside its
-- result, instead of leaving them in the accumulator.  Afterwards
-- both the accumulated statements and the scope are reset to what
-- they were before the action ran.
collectBinderStms :: Monad m =>
                     BinderT lore m a
                  -> BinderT lore m (a, Stms lore)
collectBinderStms m = do
  (old_stms, old_scope) <- BinderT get
  -- Start the action from an empty accumulator but the same scope.
  BinderT $ put (mempty, old_scope)
  x <- m
  (new_stms, _) <- BinderT get
  -- Restore the original state; scope changes made by @m@ are dropped.
  BinderT $ put (old_stms, old_scope)
  return (x, new_stms)

-- | Run an action, apply 'certify' with the given certificates to
-- every statement it produced, and only then add those statements to
-- the accumulator.
certifyingBinder :: (MonadFreshNames m, BinderOps lore) =>
                    Certificates -> BinderT lore m a
                 -> BinderT lore m a
certifyingBinder cs m = do
  (x, stms) <- collectStms m
  addStms $ certify cs <$> stms
  return x

-- Utility instance definitions for MTL classes.  These require
-- UndecidableInstances, but save on typing elsewhere.

-- | Transform the underlying monadic computation of a 'BinderT'
-- action.  The action is run from the current state, the given
-- function is applied to the resulting computation in @m@, and the
-- state that comes out of it becomes the new state.
mapInner :: Monad m =>
            (m (a, (Stms lore, Scope lore))
             -> m (b, (Stms lore, Scope lore)))
         -> BinderT lore m a -> BinderT lore m b
mapInner f (BinderT m) = BinderT $ do
  s <- get
  (x, s') <- lift $ f $ runStateT m s
  put s'
  return x

-- | The reader environment is that of the underlying monad.
instance MonadReader r m => MonadReader r (BinderT lore m) where
  ask = BinderT $ lift ask
  -- 'local' of the underlying monad is applied around the unwrapped
  -- computation via 'mapInner'.
  local f = mapInner $ local f

-- | The state exposed through 'MonadState' is that of the underlying
-- monad, not the internal statement and scope state of 'BinderT'.
instance MonadState s m => MonadState s (BinderT lore m) where
  get = BinderT $ lift get
  put = BinderT . lift . put

-- | Writer operations are forwarded to the underlying monad @m@.
-- 'pass' and 'listen' act on the unwrapped computation, whose result
-- is paired with the binder state, so each has to re-associate the
-- tuples to keep that state in the position 'mapInner' expects.
instance MonadWriter w m => MonadWriter w (BinderT lore m) where
  -- Emit output in the underlying monad.
  tell = BinderT . lift . tell
  -- The underlying 'pass' wants the output-transforming function as
  -- the second component of the result, so move the binder state @s@
  -- next to the value @x@ and hand @f@ to 'pass'.
  pass = mapInner $ \m -> pass $ do
    ((x, f), s) <- m
    return ((x, s), f)
  -- The underlying 'listen' pairs the whole (value, binder state)
  -- result with the collected output @y@; move @y@ next to the value
  -- so the binder state is again the outer second component.
  listen = mapInner $ \m -> do
    ((x, s), y) <- listen m
    return ((x, y), s)

-- | Errors are raised in the underlying monad @m@ and caught via the
-- 'MonadError' instance of the wrapped 'StateT'.
instance MonadError e m => MonadError e (BinderT lore m) where
  -- Raise the error in the underlying monad and lift it into the
  -- binder.
  throwError = lift . throwError
  -- Unwrap both the protected action and the handler's result so that
  -- the 'StateT'-level 'catchError' can be used, then rewrap.
  catchError (BinderT m) f =
    BinderT $ catchError m $ unBinder . f
    where -- Strip the 'BinderT' newtype constructor.
          unBinder (BinderT m') = m'