{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms, ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

{- | An encoding of algebraic effects based on the @freer@ monad: a program is
a tree whose nodes are operations drawn from an open sum of effects.
-}

module Prog (
  -- * Effectful program
    Prog(..)
  , EffectSum
  , Member(..)
  -- * Auxiliary functions
  , run
  , call
  , discharge) where

import Control.Monad ( (>=>) )
import Data.Kind (Constraint)
import FindElem ( Idx(unIdx), FindElem(..) )
import GHC.TypeLits ( TypeError, ErrorMessage(Text, (:<>:), (:$$:), ShowType) )
import Unsafe.Coerce ( unsafeCoerce )

-- | A computation that may invoke operations from any effect in the signature
-- @es@ and finally yields a value of type @a@. Structurally it is a tree:
-- every 'Op' node holds one operation together with the continuation that
-- consumes its result, and every 'Val' leaf holds a final value.
data Prog es a where
  -- | A finished computation carrying its result.
  Val :: a -> Prog es a
  -- | An operation (from some effect in @es@) returning an @x@, followed by
  -- the rest of the program as a function of that @x@.
  Op  :: EffectSum es x -> (x -> Prog es a) -> Prog es a

-- | Mapping over a program only rewrites its leaves: operation nodes are kept
-- as they are and the function is pushed into their continuations.
instance Functor (Prog es) where
  fmap f (Val a)   = Val (f a)
  fmap f (Op fx k) = Op fx (fmap f . k)

-- | Applicative sequencing: the operations of the function-producing program
-- are emitted first, and only once it reaches a 'Val' are the operations of
-- the argument program emitted.
instance Applicative (Prog es) where
  pure = Val

  Val f   <*> x = fmap f x
  -- Keep the outer operation and defer applying to @x@ until its continuation
  -- has produced the function.
  Op fx k <*> x = Op fx ((<*> x) . k)

-- | Monadic bind substitutes the continuation at every 'Val' leaf.
--
-- 'return' is intentionally not defined: its default is 'pure', which is
-- 'Val', and defining it separately triggers @-Wnoncanonical-monad-instances@.
instance Monad (Prog es) where
  Val a   >>= f = f a
  -- Keep the operation and compose the bound function after its continuation.
  Op fx k >>= f = Op fx (k >=> f)

-- | An open union over the effect signature @es@. It stores a single
-- operation @e x@ whose effect @e@ is hidden existentially, together with an
-- 'Int' tag recording which entry of @es@ it is meant to belong to. The type
-- itself does not enforce that @e@ occurs in @es@; that is up to the code
-- that builds values of this type.
data EffectSum (es :: [* -> *]) (x :: *) :: * where
  EffectSum
    :: Int   -- ^ position tag of the operation's effect within @es@
    -> e x   -- ^ the operation itself (effect @e@ is existential)
    -> EffectSum es x

-- | @Member e es@ witnesses that the effect @e@ can be found in the signature
-- @es@, and converts operations of @e@ to and from 'EffectSum'.
class (FindElem e es) => Member (e :: * -> *) (es :: [* -> *]) where
  -- | Embed an operation of effect @e@ into the open sum.
  inj :: e x -> EffectSum es x
  -- | Try to extract an operation of effect @e@ from the open sum; 'Nothing'
  -- when the stored operation belongs to a different effect.
  prj :: EffectSum es x -> Maybe (e x)

-- | Membership in a singleton signature. The @e ~ e'@ constraint lets type
-- inference identify the requested effect with the only one available.
-- The instance is INCOHERENT so that it can coexist with the general
-- @Member e es@ instance, which it overlaps.
instance {-# INCOHERENT #-} (e ~ e') => Member e '[e'] where
  -- The only possible position in a one-element signature is 0.
  inj = EffectSum 0

  -- The tag is not inspected: in a one-element signature the stored
  -- operation is expected to belong to @e'@, and @e ~ e'@. This relies on the
  -- sum having been built through 'inj' (or an equivalent), because the
  -- existential inside 'EffectSum' is not tied to @es@ by its type.
  prj (EffectSum _ op) = Just (unsafeCoerce op)

-- | General membership: @e@ sits at the position of @es@ computed by
-- 'FindElem', and operations are tagged with that position.
instance (FindElem e es) => Member e es where
  inj = EffectSum (unIdx (findElem :: Idx e es))

  prj = prj' (unIdx (findElem :: Idx e es))
    where
      -- The signature pins the coercion target to the instance-head effect
      -- @e@. Without it the helper is inferred as returning @Maybe a@ for any
      -- @a@, which hides what 'unsafeCoerce' is coercing to. @e@ and @es@
      -- are in scope here via ScopedTypeVariables.
      prj' :: Int -> EffectSum es y -> Maybe (e y)
      prj' n (EffectSum n' op)
        -- A matching tag means the operation was injected as effect @e@, so
        -- the coercion only recovers the type hidden by the existential.
        | n == n'   = Just (unsafeCoerce op)
        | otherwise = Nothing

-- | @Members es ess@ is the conjunction of @Member e ess@ for every effect @e@
-- listed in @es@; the empty list yields the trivial constraint.
type family Members (es :: [* -> *]) (ess :: [* -> *]) = (c :: Constraint) | c -> es where
  Members '[]       ess = ()
  Members (e ': es) ess = (Member e ess, Members es ess)

-- | Extract the result of a computation over the empty effect signature.
--
-- An 'Op' node would need an 'EffectSum' over @'[]@, which presumably cannot
-- be produced through 'inj' (there is no effect to be a member of). If one
-- is reached anyway, this calls 'error'.
run :: Prog '[] a -> a
run (Val x)  = x
run (Op _ _) = error "'run' isn't defined for non-pure computations"

-- | Invoke a single operation of type @e x@ as a computation. The operation
-- is injected into the effect sum, and its result is returned unchanged
-- through a 'Val' continuation.
call :: (Member e es) => e x -> Prog es x
call op = Op (inj op) Val

-- | Split off the first effect of a signature.
--
-- A tag of 0 means the operation belongs to the head effect @e@, and it is
-- returned in 'Right'. Any other tag belongs to the tail @es@; it is returned
-- in 'Left' with the tag decremented so that it indexes into @es@.
discharge :: EffectSum (e ': es) x -> Either (EffectSum es x) (e x)
-- The coercion recovers the head effect's type, which is hidden by the
-- existential inside 'EffectSum'. It relies on tag 0 having been assigned
-- to operations of @e@.
discharge (EffectSum 0 op) = Right (unsafeCoerce op)
discharge (EffectSum n op) = Left (EffectSum (n - 1) op)

-- | Matches an operation that does /not/ belong to the head effect of the
-- signature, and binds it re-tagged as an operation of the tail signature
-- (see 'discharge'). This is a one-way (match-only) pattern.
pattern Other :: EffectSum es x -> EffectSum (e ': es) x
pattern Other u <- (discharge -> Left u)