{-|
This module contains a simple lambda calculus parser. This parser is not optimized for maximum
performance; instead it's written in a style which emulates the look and feel of conventional
monadic parsers. An optimized implementation would use low-level `switch` expressions more often.
-}

{-# language StrictData #-}

module FlatParse.Examples.BasicLambda.Parser where

import Data.Char (ord)
import qualified Data.ByteString as B

import FlatParse.Basic hiding (Parser, runParser, string, char, cut)
import FlatParse.Examples.BasicLambda.Lexer

--------------------------------------------------------------------------------

-- | Names of variables and binders (see `Var`, `Lam`, `Let`), represented as a `B.ByteString`.
type Name = B.ByteString

{-|
A term in the language. The precedences of different constructs are the following, in decreasing
order of strength:

* Identifiers, literals and parenthesized expressions
* Function application (left assoc)
* Multiplication (left assoc)
* Addition (left assoc)
* Equality, less-than (non-assoc)
* @lam@, @let@, @if@ (right assoc)

-}
data Tm
  = Var Name        -- ^ @x@
  | App Tm Tm       -- ^ @t u@
  | Lam Name Tm     -- ^ @lam x. t@
  | Let Name Tm Tm  -- ^ @let x = t in u@
  | BoolLit Bool    -- ^ @true@ or @false@.
  | IntLit Int      -- ^ A non-negative `Int` literal (the `int` parser accepts @0@).
  | If Tm Tm Tm     -- ^ @if t then u else v@
  | Add Tm Tm       -- ^ @t + u@
  | Mul Tm Tm       -- ^ @t * u@
  | Eq Tm Tm        -- ^ @t == u@
  | Lt Tm Tm        -- ^ @t < u@
  deriving Show


-- | Parse an identifier. This parser uses `isKeyword` to check that an identifier is not a
--   keyword.
--
--   An identifier is one `identStartChar` followed by any number of `identChar`s. The matched
--   span is handed to `isKeyword`, and the identifier is accepted only if that check fails. The
--   result is the `B.ByteString` produced by `byteStringOf` for the match, and the whole thing is
--   wrapped in `token` (defined in the lexer module).
ident :: Parser Name
ident = token $ byteStringOf $
  withSpan (identStartChar *> skipMany identChar) (\_ sp -> fails (isKeyword sp))

-- | Parse an identifier, throw a precise error on failure. The error is produced by `cut'` with
--   the expected item @Msg "identifier"@.
ident' :: Parser Name
ident' = ident `cut'` (Msg "identifier")

-- | Parse a single character accepted by `isDigit` and return its numeric value, computed as
--   the offset of the character from @\'0\'@.
digit :: Parser Int
digit = (\c -> ord c - ord '0') <$> satisfyAscii isDigit

-- | Parse a non-empty run of digits as a decimal `Int`, wrapped in `token`.
--
--   The digits are folded from the right with `chainr`, carrying a pair @(place, acc)@ that
--   starts at @(1, 0)@: each digit @d@ contributes @place * d@ to the accumulator and multiplies
--   the place value by 10. If the place value is still 1 afterwards, no digit was read and the
--   parser fails with `empty`. No overflow check is performed.
int :: Parser Int
int = token do
  (place, n) <- chainr (\d (!place, !acc) -> (place * 10, acc + place * d)) digit (pure (1, 0))
  case place of
    1 -> empty   -- no digits were consumed
    _ -> pure n

-- | Parse a literal, identifier or parenthesized expression.
--
--   Alternatives are tried in order: an identifier (`Var`), the keywords @true@ and @false@
--   (`BoolLit`), an integer literal (`IntLit`), and finally a term parsed by `tm'` between @(@ and
--   @)@. The closing parenthesis uses the primed `symbol'` variant.
atom :: Parser Tm
atom =
       (Var           <$> ident)
   <|> (BoolLit True  <$  $(keyword "true"))
   <|> (BoolLit False <$  $(keyword "false"))
   <|> (IntLit        <$> int)
   <|> ($(symbol "(") *> tm' <* $(symbol' ")"))

-- | Parse an `atom`; on failure, pass the list of expected items (identifier, @true@, @false@,
--   parenthesized expression, integer literal) to `cut` for error reporting.
atom' :: Parser Tm
atom' = atom
  `cut` [Msg "identifier", "true", "false", Msg "parenthesized expression", Msg "integer literal"]

-- | Parse an `App`-level expression: a head parsed by `atom'`, followed by further `atom`s that
--   are combined to the left with `App` via `chainl`.
app' :: Parser Tm
app' = chainl App atom' atom

-- | Parse a `Mul`-level expression: `app'` operands separated by @*@, combined to the left with
--   `Mul` via `chainl`.
mul' :: Parser Tm
mul' = chainl Mul app' ($(symbol "*") *> app')

-- | Parse an `Add`-level expression: `mul'` operands separated by @+@, combined to the left with
--   `Add` via `chainl`.
add' :: Parser Tm
add' = chainl Add mul' ($(symbol "+") *> mul')

-- | Parse an `FlatParse.Examples.BasicLambda.Parser.Eq` or `Lt`-level expression.
--
--   An `add'` operand is parsed first. If @==@ follows, a second `add'` operand is parsed and the
--   result is an `Eq`; otherwise, if @<@ follows, the result is an `Lt`; otherwise the first
--   operand is returned unchanged. At most one comparison operator is consumed, so comparisons do
--   not chain.
eqLt' :: Parser Tm
eqLt' =
  add' >>= \e1 ->
  branch $(symbol "==") (Eq e1 <$> add') $
  branch $(symbol "<")  (Lt e1 <$> add') $
  pure e1

-- | Parse a `Let`: the keyword @let@, a binder name, @=@, the bound term, the keyword @in@ and the
--   body. Everything after the initial @let@ uses the primed parser variants (`ident'`,
--   `symbol'`, `keyword'`, `tm'`).
pLet :: Parser Tm
pLet = do
  $(keyword "let")
  x <- ident'
  $(symbol' "=")
  t <- tm'
  $(keyword' "in")
  u <- tm'
  pure $ Let x t u

-- | Parse a `Lam`: the keyword @lam@, a binder name, @.@ and the body. Everything after the
--   initial @lam@ uses the primed parser variants (`ident'`, `symbol'`, `tm'`).
lam :: Parser Tm
lam = do
  $(keyword "lam")
  x <- ident'
  $(symbol' ".")
  t <- tm'
  pure $ Lam x t

-- | Parse an `If`: the keyword @if@, the condition, @then@, the first branch, @else@ and the
--   second branch. Everything after the initial @if@ uses the primed parser variants
--   (`keyword'`, `tm'`).
pIf :: Parser Tm
pIf = do
  $(keyword "if")
  t <- tm'
  $(keyword' "then")
  u <- tm'
  $(keyword' "else")
  v <- tm'
  pure $ If t u v

-- | Parse any `Tm`. The alternatives `pLet`, `lam`, `pIf` and `eqLt'` are tried in that order; if
--   all of them fail, `cut` is given the expected items @let@, @lam@ and @if@.
tm' :: Parser Tm
tm' = (pLet <|> lam <|> pIf <|> eqLt') `cut` ["let", "lam", "if"]

-- | Parse a complete source file: run `ws`, parse a term with `tm'`, then require `eof`.
--
--   The `cut` applies only to the final `eof` check, so leftover input is reported as
--   @"end of input (lexical error)"@. The parentheses make that grouping explicit.
src' :: Parser Tm
src' = ws *> tm' <* (eof `cut` [Msg "end of input (lexical error)"])


-- Examples
--------------------------------------------------------------------------------

-- | A small sample program, one source line per list element, joined with `unlines`.
--
-- testParser src' p1
p1 :: String
p1 = unlines [
  "let f = lam x. lam y. x (x (x y)) in",
  "let g = if f true then false else true in",
  "let h = f x y + 200 in",
  "f g g h"
  ]