{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE DeriveAnyClass        #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE UndecidableInstances  #-}
-- | This module defines a variation of
-- free scoped (relative) monads that relies on the foil for
-- scope-safe and efficient handling of binders.
--
-- See description of the approach in [«Free Foil: Generating Efficient and Scope-Safe Abstract Syntax»](https://arxiv.org/abs/2405.16384).
module Control.Monad.Free.Foil where

import           Control.DeepSeq
import qualified Control.Monad.Foil.Internal as Foil
import qualified Control.Monad.Foil.Relative as Foil
import           Data.Bifunctor
import           GHC.Generics                (Generic)

-- | A term under a single name binder.
--
-- The binder extends the outer scope @n@ to an inner scope @l@, and the body
-- lives in @l@. The inner scope is existentially quantified: it only becomes
-- visible when pattern matching on 'ScopedAST'.
data ScopedAST sig n where
  ScopedAST :: Foil.NameBinder n l -> AST sig l -> ScopedAST sig n

-- | Forces the binder to normal form first, then the body.
instance (forall l. NFData (AST sig l)) => NFData (ScopedAST sig n) where
  rnf (ScopedAST boundName scopedBody) = boundName `deepseq` rnf scopedBody

-- | A term generated by a signature 'Bifunctor' @sig@,
-- with (free) variables in scope @n@.
--
-- The first type argument of @sig@ holds scoped subterms (subterms under a
-- single binder, see 'ScopedAST'). The second holds ordinary subterms in the
-- same scope @n@.
data AST sig n where
  -- | A (free) variable in scope @n@.
  Var :: Foil.Name n -> AST sig n
  -- | A non-variable syntactic construction specified by the signature 'Bifunctor' @sig@.
  Node :: sig (ScopedAST sig n) (AST sig n) -> AST sig n

-- 'Generic' supplies the generic representation used by the default
-- (Generic-based) 'rnf' of the anyclass-derived 'NFData' instance below.
deriving instance Generic (AST sig n)
-- Deep evaluation of terms. The quantified constraint requires @sig@ to be
-- 'NFData' whenever both of its arguments are, so that 'Node' payloads can
-- be forced.
deriving instance (forall scope term. (NFData scope, NFData term) => NFData (sig scope term)) => NFData (AST sig n)

-- | Renames every free variable of a term with the given function.
--
-- Under each binder, the renaming is first extended past that binder (via
-- 'Foil.extendRenaming'). The body is then renamed with the extended function.
instance Bifunctor sig => Foil.Sinkable (AST sig) where
  sinkabilityProof rename term =
    case term of
      Var name  -> Var (rename name)
      Node node -> Node (bimap renameScoped renameTerm node)
    where
      -- Ordinary subterms live in the same scope: reuse the renaming as is.
      renameTerm t = Foil.sinkabilityProof rename t

      -- Scoped subterms: extend the renaming over the binder, then recurse
      -- into the body with the extended renaming and the new binder.
      renameScoped (ScopedAST binder body) =
        Foil.extendRenaming rename binder $ \extendedRename newBinder ->
          ScopedAST newBinder (Foil.sinkabilityProof extendedRename body)

-- | A bare name is injected into a term as a variable.
instance Foil.InjectName (AST sig) where
  injectName name = Var name

-- | Applies a substitution to a term of a free scoped monad.
--
-- Each free variable is replaced by its image under the substitution (see
-- 'Foil.lookupSubst'). For each binder:
--
-- * a fresh binder extending @o@ is obtained with 'Foil.withRefreshed';
--
-- * the substitution is sunk into the extended scope and extended with a
--   renaming from the old bound name to the new one;
--
-- * the body is processed recursively in the extended scope.
substitute
  :: (Bifunctor sig, Foil.Distinct o)
  => Foil.Scope o                     -- ^ Names in scope in the result.
  -> Foil.Substitution (AST sig) i o  -- ^ Images of the free variables of the input.
  -> AST sig i                        -- ^ Input term.
  -> AST sig o
substitute scope subst term =
    case term of
      Var name  -> Foil.lookupSubst subst name
      Node node -> Node (bimap substituteScoped (substitute scope subst) node)
  where
    substituteScoped (ScopedAST oldBinder body) =
      Foil.withRefreshed scope (Foil.nameOf oldBinder) $ \newBinder ->
        ScopedAST newBinder
          (substitute
            (Foil.extendScope newBinder scope)
            (Foil.addRename (Foil.sink subst) oldBinder (Foil.nameOf newBinder))
            body)

-- | @'AST' sig@ is a monad relative to 'Foil.Name'.
--
-- * 'Foil.rreturn' wraps a name as a 'Var'.
--
-- * 'Foil.rbind' replaces every free variable with the term produced by the
--   given function. Each binder met on the way down is replaced by a binder
--   obtained from 'Foil.withRefreshed' against the target scope.
instance Bifunctor sig => Foil.RelMonad Foil.Name (AST sig) where
  rreturn name = Var name

  rbind scope term subst =
      case term of
        Var name  -> subst name
        Node node -> Node (bimap bindScoped bindTerm node)
    where
      -- Ordinary subterms stay in the same scope.
      bindTerm t = Foil.rbind scope t subst

      -- Scoped subterms: obtain a binder extending the target scope, then bind
      -- the body in the extended scope. A name that 'Foil.unsinkName' cannot
      -- move back out of the old binder (i.e. the bound name) maps to the new
      -- binder's name. Any other name is sent through @subst@, and the result
      -- is sunk into the extended scope.
      bindScoped (ScopedAST oldBinder body) =
        Foil.withRefreshed scope (Foil.nameOf oldBinder) $ \newBinder ->
          let innerScope = Foil.extendScope newBinder scope
              innerSubst name =
                maybe (Var (Foil.nameOf newBinder))
                      (\outer -> Foil.sink (subst outer))
                      (Foil.unsinkName oldBinder name)
          in ScopedAST newBinder (Foil.rbind innerScope body innerSubst)