module Language.Lambda.Untyped.Eval
  ( EvalState(..),
    evalExpr,
    subGlobals,
    betaReduce,
    alphaConvert,
    etaConvert,
    freeVarsOf
  ) where

import Control.Monad.Except
import Prettyprinter
import RIO
import RIO.List (find)
import qualified RIO.Map as Map

import Language.Lambda.Shared.Errors
import Language.Lambda.Untyped.Expression
import Language.Lambda.Untyped.State

-- | Evaluate a top-level expression against the current global bindings.
--
-- Known globals are substituted into the expression (see 'subGlobals') before
-- it is reduced. A @let@ additionally stores the reduced value under its name
-- in the globals and yields the @let@ rebuilt around that value.
evalExpr :: (Pretty name, Ord name) => LambdaExpr name -> Eval name (LambdaExpr name)
evalExpr expr = do
  globals' <- getGlobals
  let reduce e = evalExpr' (subGlobals globals' e)
  case expr of
    Let name body -> do
      result <- reduce body
      -- Extend the snapshot read above with the freshly reduced binding.
      setGlobals (Map.insert name result globals')
      pure (Let name result)
    _ -> reduce expr

-- | Reduce an expression that contains no @let@.
--
-- Variables are returned as they are, abstractions are reduced underneath the
-- binder, and applications reduce both sides before handing them to
-- 'betaReduce'. A @let@ is rejected with 'InvalidLet', carrying the
-- pretty-printed expression.
evalExpr' :: (Eq name, Pretty name) => LambdaExpr name -> Eval name (LambdaExpr name)
evalExpr' expr = case expr of
  Var _ -> pure expr
  Abs name body -> Abs name <$> evalExpr' body
  App fun arg -> do
    fun' <- evalExpr' fun
    arg' <- evalExpr' arg
    betaReduce fun' arg'
  Let _ _ -> throwError (InvalidLet (prettyPrint expr))

-- | Look up free vars that have global bindings and substitute them
--
-- A variable is replaced by its global binding when one exists and is left
-- alone otherwise. An abstraction whose parameter has the same name as a
-- global shadows that global inside its body, so that single name is not
-- substituted there; every other global still is. Expressions other than
-- variables, applications and abstractions are returned unchanged.
subGlobals
  :: Ord name
  => Map name (LambdaExpr name)
  -> LambdaExpr name
  -> LambdaExpr name
subGlobals globals' expr@(Var x) = Map.findWithDefault expr x globals'
subGlobals globals' (App e1 e2) = App (subGlobals globals' e1) (subGlobals globals' e2)
subGlobals globals' (Abs name expr) =
  -- Remove only the shadowed binding. Skipping the whole body would leave
  -- references to unrelated globals unresolved.
  Abs name (subGlobals (Map.delete name globals') expr)
subGlobals _ expr = expr

-- | Apply the first expression to the second.
--
-- * A variable in function position cannot be reduced, so the application is
--   rebuilt unchanged.
-- * An application in function position is reduced first; its result is then
--   paired with the argument in a new application without further reduction.
-- * An abstraction has its body alpha-converted against the argument's free
--   variables, the argument substituted for the parameter, and the result
--   evaluated again.
-- * Any other function position raises 'ImpossibleError'.
betaReduce
  :: (Eq name, Pretty name)
  => LambdaExpr name
  -> LambdaExpr name
  -> Eval name (LambdaExpr name)
betaReduce fun arg = case fun of
  Var _ -> pure (App fun arg)
  App inner innerArg -> (`App` arg) <$> betaReduce inner innerArg
  Abs param body -> do
    renamed <- alphaConvert (freeVarsOf arg) body
    evalExpr' (substitute renamed param arg)
  _ -> throwError ImpossibleError

-- | Rename abstraction parameters to avoid name captures
--
-- @freeVars@ are the names that must not be captured (in 'betaReduce', the
-- free variables of the argument about to be substituted in). Every
-- abstraction in the expression, including those nested under other binders
-- or inside applications, whose parameter occurs in @freeVars@ is renamed.
--
-- The replacement is the first name from 'getUniques' that is neither in
-- @freeVars@ nor used anywhere (free or bound) in the abstraction's body, so
-- the rename itself cannot capture anything. If no such name is available
-- the parameter is kept as is.
alphaConvert :: Eq name => [name] -> LambdaExpr name -> Eval name (LambdaExpr name)
alphaConvert freeVars (Abs name body) = do
  uniques' <- getUniques
  let taken = freeVars ++ varsOf body
      nextVar = fromMaybe name $ find (`notElem` taken) uniques'

  if name `elem` freeVars
    -- Keep converting below the renamed binder: inner parameters may clash
    -- with @freeVars@ as well.
    then Abs nextVar <$> alphaConvert freeVars (substitute body name (Var nextVar))
    else Abs name <$> alphaConvert freeVars body
  where
    -- Every name mentioned in an expression, bound or free. 'substitute'
    -- does no renaming of its own, so the fresh name must avoid all of them.
    varsOf (Var v) = [v]
    varsOf (Abs v e) = v : varsOf e
    varsOf (App e1 e2) = varsOf e1 ++ varsOf e2
    varsOf _ = []

-- Binders hidden inside an application can capture just as well.
alphaConvert freeVars (App e1 e2) =
  App <$> alphaConvert freeVars e1 <*> alphaConvert freeVars e2

alphaConvert _ expr = return expr

-- | Eliminate superfluous abstractions
--
-- Rewrites @\\x. e x@ to @e@, but only when @x@ does not occur free in @e@;
-- otherwise the abstraction is kept and only its sub-expressions are
-- converted. Nested abstractions are converted from the inside out and
-- retried until the inner abstraction stops changing.
etaConvert :: Eq n => LambdaExpr n -> LambdaExpr n
etaConvert (Abs n (App e1 (Var n')))
  -- Dropping the binder is only sound if @e1@ does not itself refer to it
  -- (e.g. @\\x. x x@ must not become @x@).
  | n == n' && n `notElem` freeVarsOf e1 = etaConvert e1
  | otherwise = Abs n (App (etaConvert e1) (Var n'))
etaConvert (Abs n e@(Abs _ _))
  -- If `etaConvert e == e` then etaConverting it will create an infinite loop
  | e == e'   = Abs n e'
  | otherwise = etaConvert (Abs n e')
  where e' = etaConvert e
etaConvert (Abs n expr) = Abs n (etaConvert expr)
etaConvert (App e1 e2)  = App (etaConvert e1) (etaConvert e2)
etaConvert expr = expr

-- | Substitute an expression for a variable name in another expression
--
-- @substitute expr subName inExpr@ walks @expr@ and puts @inExpr@ in place of
-- every free occurrence of the variable @subName@. An abstraction that binds
-- @subName@ again shadows it, so that abstraction is returned untouched.
-- No renaming happens here: 'betaReduce' runs 'alphaConvert' beforehand so
-- that substituting cannot capture variables of @inExpr@.
--
-- Expressions other than variables, abstractions and applications are
-- returned unchanged.
substitute :: Eq name => LambdaExpr name -> name -> LambdaExpr name -> LambdaExpr name
substitute subExpr@(Var name) subName inExpr
  | name == subName = inExpr
  | otherwise = subExpr

substitute subExpr@(Abs name expr) subName inExpr
  | name == subName = subExpr
  | otherwise = Abs name (substitute expr subName inExpr)

substitute (App e1 e2) subName inExpr
  = App (sub e1) (sub e2)
  where sub expr = substitute expr subName inExpr

-- Leave any other node as it is; returning the replacement here would
-- overwrite the whole node with @inExpr@.
substitute expr _ _ = expr

-- | Collect the free variables of an expression.
--
-- Names appear in left-to-right order of occurrence and are not
-- de-duplicated. Only variables, applications and abstractions contribute;
-- any other expression yields the empty list.
freeVarsOf :: Eq n => LambdaExpr n -> [n]
freeVarsOf expr = case expr of
  Var v -> [v]
  App fun arg -> freeVarsOf fun ++ freeVarsOf arg
  -- The abstraction's own parameter is bound, so strip it from the body's set.
  Abs bound body -> [v | v <- freeVarsOf body, v /= bound]
  _ -> []