module F where
import Unbound.LocallyNameless hiding (prec,empty,Data,Refl)
import Control.Monad
import Control.Monad.Trans.Except
import qualified Data.List as List
import Util
import Text.PrettyPrint as PP
-- | Names of type variables.
type TyName = Name Ty
-- | Names of term variables.
type TmName = Name Tm
-- | Types of the object language: System F extended with integers and
-- n-ary products.
data Ty = TyVar TyName          -- ^ type variable
        | TyInt                 -- ^ the integer type
        | Arr Ty Ty             -- ^ function type @t1 -> t2@
        | All (Bind TyName Ty)  -- ^ universal type @forall a. t@ (binds @a@)
        | TyProd [Ty]           -- ^ n-ary product (tuple) type
  deriving Show
-- | Terms of the object language.
data Tm = TmInt Int             -- ^ integer literal
        | TmVar TmName          -- ^ term variable
        | Fix (Bind (TmName, TmName, Embed (Ty, Ty)) Tm)
          -- ^ recursive function @fix f (x:t1):t2. e@: binds the
          -- self-reference @f@ and the argument @x@, and carries the
          -- argument type @t1@ and the result type @t2@
        | App Tm Tm             -- ^ application
        | TmProd [Tm]           -- ^ tuple
        | TmPrj Tm Int          -- ^ projection of a tuple component (0-based)
        | TmPrim Tm Prim Tm     -- ^ binary primitive operation
        | TmIf0 Tm Tm Tm        -- ^ @if0 e0 then e1 else e2@: selects @e1@ when @e0@ is 0
        | TLam (Bind TyName Tm) -- ^ type abstraction (binds a type variable)
        | TApp Tm Ty            -- ^ type application
        | Ann Tm Ty             -- ^ term annotated with its type (produced by 'typecheck')
  deriving Show
-- Generate the generic representations that the Alpha and Subst instances
-- below are built on.
$(derive [''Ty, ''Tm])

-- Alpha-equivalence and binder handling use the generic default methods.
instance Alpha Ty
instance Alpha Tm

-- Substitutions that need no variable recognition of their own: the
-- generic default methods are used unchanged.
instance Subst Tm Prim
instance Subst Tm Ty
instance Subst Ty Prim
instance Subst Ty Tm

-- Substituting terms for terms: a 'TmVar' is the substitutable position.
instance Subst Tm Tm where
  isvar t = case t of
    TmVar v -> Just (SubstName v)
    _       -> Nothing

-- Substituting types for types: a 'TyVar' is the substitutable position.
instance Subst Ty Ty where
  isvar t = case t of
    TyVar v -> Just (SubstName v)
    _       -> Nothing
-- | Fixed term-variable names used when building the example programs.
x, y, z, f, n :: Name Tm
x = string2Name "x"
y = string2Name "y"
z = string2Name "z"
f = string2Name "f"
n = string2Name "n"
-- | Fixed type-variable names used when building the example programs.
a, b, c :: Name Ty
a = string2Name "a"
b = string2Name "b"
c = string2Name "c"
-- | Polymorphic identity: @/\\a. fix y (x:a):a. x@.
polyid :: Tm
polyid = TLam (bind a identityAtA)
  where
    tyA         = TyVar a
    identityAtA = Fix (bind (y, x, Embed (tyA, tyA)) (TmVar x))
-- | Polymorphic constant function: @/\\a. fix y (x:a):Int. 3@.
polyconst :: Tm
polyconst = TLam $ bind a $ Fix $ bind (y, x, Embed (TyVar a, TyInt)) (TmInt 3)
-- | The type @forall a. a -> a@.
polyidty :: Ty
polyidty = All $ bind a (Arr va va)
  where
    va = TyVar a
-- | @(fix y (x:Int):Int. x Plus 1) 1@ -- presumably evaluates to 2,
-- assuming 'Plus' denotes addition.
two :: Tm
two = App addOne (TmInt 1)
  where
    addOne = Fix (bind (y, x, Embed (TyInt, TyInt))
                       (TmPrim (TmVar x) Plus (TmInt 1)))
-- | The primitive expression @1 Plus 1@.
onePlusOne :: Tm
onePlusOne = TmPrim one Plus one
  where
    one = TmInt 1
-- | A recursive factorial applied to 6:
-- @(fix f (n:Int):Int. if0 n then 1 else n Times f (n Minus 1)) 6@.
sixfact :: Tm
sixfact = App fact (TmInt 6)
  where
    nv      = TmVar n
    recCall = App (TmVar f) (TmPrim nv Minus (TmInt 1))
    body    = TmIf0 nv (TmInt 1) (TmPrim nv Times recCall)
    fact    = Fix (bind (f, n, Embed (TyInt, TyInt)) body)
-- | Church-style "true" at a type variable:
-- @/\\a. fix y (n:a):(a -> a). fix z (x:a):a. n@, which returns its first
-- argument.
ctrue :: Tm
ctrue = TLam (bind a outer)
  where
    ta    = TyVar a
    inner = Fix (bind (z, x, Embed (ta, ta)) (TmVar n))
    outer = Fix (bind (y, n, Embed (ta, Arr ta ta)) inner)
-- | Apply a function twice:
-- @/\\a. fix y (f:a -> a):(a -> a). fix z (x:a):a. f (f x)@.
twice :: Tm
twice = TLam (bind a outer)
  where
    ta    = TyVar a
    endo  = Arr ta ta
    inner = Fix (bind (z, x, Embed (ta, ta))
                      (App (TmVar f) (App (TmVar f) (TmVar x))))
    outer = Fix (bind (y, f, Embed (endo, endo)) inner)
-- | Type-variable context: the type variables currently in scope.
type Delta = [ TyName ]
-- | Term-variable context: an association list from term variables to their
-- types.
type Gamma = [ (TmName, Ty) ]
-- | Typing context pairing the type-variable and term-variable contexts.
data Ctx = Ctx { getDelta :: Delta , getGamma :: Gamma }
-- | The empty typing context: no type variables and no term variables in scope.
emptyCtx :: Ctx
emptyCtx = Ctx { getDelta = [], getGamma = [] }
-- | Succeed when the type variable is bound in the context's 'Delta';
-- otherwise throw @"NotFound"@.
checkTyVar :: Ctx -> TyName -> M ()
checkTyVar g v =
  unless (v `List.elem` getDelta g) $
    throwE "NotFound"
-- | Look up a term variable's type in the context's 'Gamma' (the first
-- matching entry wins); throw @"NotFound"@ when it is absent.
lookupTmVar :: Ctx -> TmName -> M Ty
lookupTmVar g v = maybe (throwE "NotFound") return (lookup v (getGamma g))
-- | Bring a type variable into scope by prepending it to 'Delta'.
extendTy :: TyName -> Ctx -> Ctx
extendTy tv ctx@Ctx{ getDelta = delta } = ctx { getDelta = tv : delta }
-- | Bind a term variable to a type by prepending the pair to 'Gamma', so it
-- shadows any earlier binding of the same name.
extendTm :: TmName -> Ty -> Ctx -> Ctx
extendTm tm ty ctx@Ctx{ getGamma = gamma } = ctx { getGamma = (tm, ty) : gamma }
-- | Check that a type is well formed in the given context: every type
-- variable occurring in it must be bound, either in 'Delta' or by an
-- enclosing 'All'.  Throws (via 'checkTyVar') on an unbound variable.
tcty :: Ctx -> Ty -> M ()
tcty g (TyVar v) = checkTyVar g v
tcty g (All bnd) = do
  -- the variable bound by the forall is in scope in its body
  (v, body) <- unbind bnd
  tcty (extendTy v g) body
tcty g (Arr dom cod) = do
  tcty g dom
  tcty g cod
tcty _ TyInt = return ()
tcty g (TyProd tys) = mapM_ (tcty g) tys
-- | Type-check a term and return an annotated copy of it.
--
-- The result is always an 'Ann' node carrying the term's type, and every
-- term subterm inside it is annotated in the same way.  Ill-typed terms and
-- unbound term or type variables are reported with 'throwE'.  An input
-- that is already annotated is re-checked: the annotation must agree
-- (up to alpha-equivalence) with the type that is inferred.
typecheck :: Ctx -> Tm -> M Tm
typecheck g e@(TmVar x) = do
  ty <- lookupTmVar g x
  return $ Ann e ty
typecheck g (Fix bnd) = do
  ((f, x, Embed (ty1, ty2)), e1) <- unbind bnd
  tcty g ty1
  tcty g ty2
  -- the body sees the argument x and, for recursive calls, f itself
  ae1@(Ann _ ty2') <- typecheck (extendTm f (Arr ty1 ty2) (extendTm x ty1 g)) e1
  if not (ty2 `aeq` ty2')
    then throwE $ "Type Error: Can't match " ++ pp ty2 ++ " and " ++ pp ty2'
    else return $ Ann
      (Fix (bind (f, x, Embed (ty1, ty2)) ae1))
      (Arr ty1 ty2)
typecheck g (App e1 e2) = do
  ae1@(Ann _ ty1) <- typecheck g e1
  ae2@(Ann _ ty2) <- typecheck g e2
  case ty1 of
    Arr ty11 ty21 | ty2 `aeq` ty11 ->
      return (Ann (App ae1 ae2) ty21)
    _ -> throwE "TypeError"
typecheck g (TLam bnd) = do
  (x, e) <- unbind bnd
  ae@(Ann _ ty) <- typecheck (extendTy x g) e
  return $ Ann (TLam (bind x ae)) (All (bind x ty))
typecheck g (TApp e ty) = do
  ae@(Ann _ tyt) <- typecheck g e
  case tyt of
    All b -> do
      tcty g ty
      (n1, ty1) <- unbind b
      return $ Ann (TApp ae ty) (subst n1 ty ty1)
    -- type application of a term whose type is not a forall
    _ -> throwE "TypeError"
typecheck g (TmProd es) = do
  atys <- mapM (typecheck g) es
  let tys = map (\(Ann _ ty) -> ty) atys
  return $ Ann (TmProd atys) (TyProd tys)
typecheck g (TmPrj e i) = do
  ae@(Ann _ ty) <- typecheck g e
  case ty of
    -- both bounds are checked: a negative index would make (!!) crash
    TyProd tys | i >= 0 && i < length tys -> return $ Ann (TmPrj ae i) (tys !! i)
    _ -> throwE "TypeError"
typecheck _ (TmInt i) = return (Ann (TmInt i) TyInt)
typecheck g (TmPrim e1 p e2) = do
  ae1@(Ann _ ty1) <- typecheck g e1
  ae2@(Ann _ ty2) <- typecheck g e2
  case (ty1 , ty2) of
    (TyInt, TyInt) -> return (Ann (TmPrim ae1 p ae2) TyInt)
    _ -> throwE "TypeError"
typecheck g (TmIf0 e0 e1 e2) = do
  ae0@(Ann _ ty0) <- typecheck g e0
  ae1@(Ann _ ty1) <- typecheck g e1
  ae2@(Ann _ ty2) <- typecheck g e2
  if ty1 `aeq` ty2 && ty0 `aeq` TyInt then
    return (Ann (TmIf0 ae0 ae1 ae2) ty1)
  else
    throwE "TypeError"
typecheck g (Ann e ty) = do
  -- re-check the inner term; return its annotation rather than nesting a
  -- second Ann, so re-checking typecheck's own output is idempotent
  ae@(Ann _ ty') <- typecheck g e
  if ty `aeq` ty'
    then return ae
    else throwE $ "Type Error: Can't match " ++ pp ty ++ " and " ++ pp ty'
-- | Is the term a value?  Integer literals, functions ('Fix'), type
-- abstractions ('TLam') and tuples whose components are all values are
-- values; everything else (including 'Ann') is not.
value :: Tm -> Bool
value t = case t of
  TmInt _   -> True
  Fix _     -> True
  TLam _    -> True
  TmProd es -> all value es
  _         -> False
-- | One left-to-right evaluation step inside a list of terms: leading values
-- are kept and the first non-value is stepped.  Throws
-- @"can't step empty list"@ when every element is already a value
-- (including when the list is empty).
steps :: [Tm] -> M [Tm]
steps [] = throwE "can't step empty list"
steps (t : ts)
  | value t   = (t :) <$> steps ts
  | otherwise = (: ts) <$> step t
-- | Small-step call-by-value reduction: perform exactly one step, or throw
-- when the term is a value (@"can't step value"@), a free variable
-- (@"unbound variable"@), or stuck.  Clauses are tried top to bottom, so each
-- redex clause must come before the congruence clause for the same
-- constructor.  Annotations ('Ann') are dropped in one step.
step :: Tm -> M Tm
step e | value e = throwE "can't step value"
step (TmVar _) = throwE "unbound variable"
step (App e1@(Fix bnd) e2) =
  if value e2
    then do
      -- beta: substitute the argument for x and the function itself for f
      ((f, x, _), t) <- unbind bnd
      return $ substs [ (x, e2), (f, e1) ] t
    else do
      e2' <- step e2
      return (App e1 e2')
step (App e1 e2) = do
  e1' <- step e1
  return (App e1' e2)
-- Projection from an evaluated tuple.  Both bounds are checked so that a
-- negative index no longer crashes (!!); an out-of-range index instead falls
-- through to the next clause, which fails with "can't step value".
step (TmPrj e1@(TmProd es) i)
  | value e1 && i >= 0 && i < length es = return $ es !! i
step (TmPrj e1 i) = do
  e1' <- step e1
  return (TmPrj e1' i)
step (TmProd es) = do
  es' <- steps es
  return (TmProd es')
step (TmPrim (TmInt i1) p (TmInt i2)) =
  return (TmInt ((evalPrim p) i1 i2))
step (TmPrim e1 p e2)
  | value e1 = do
      e2' <- step e2
      return (TmPrim e1 p e2')
  | otherwise = do
      e1' <- step e1
      return (TmPrim e1' p e2)
step (TmIf0 (TmInt i) e1 e2) = if i == 0 then return e1 else return e2
step (TmIf0 e0 e1 e2) = do
  e0' <- step e0
  return (TmIf0 e0' e1 e2)
step (TApp (TLam bnd) ty) = do
  -- instantiate the type abstraction's bound variable with ty
  (a, e) <- unbind bnd
  return $ subst a ty e
step (TApp e ty) = do
  e' <- step e
  return $ TApp e' ty
step (Ann e _) = return e
-- | Repeatedly 'step' a term until it is a value, propagating any error that
-- 'step' throws (for example on a stuck term).  Does not terminate on a
-- diverging term.
evaluate :: Tm -> M Tm
evaluate e
  | value e   = return e
  | otherwise = step e >>= evaluate
-- | Pretty-printing of types.  In an arrow, the domain is printed at one
-- precedence level above "->" and the codomain at "->" itself, presumably
-- making "->" right-associative (depends on withPrec/binop, defined outside
-- this file).
instance Display Ty where
  display ty = case ty of
    TyVar v -> display v
    TyInt -> return $ text "Int"
    Arr dom cod -> do
      dDom <- withPrec (precedence "->" + 1) $ display dom
      dCod <- withPrec (precedence "->") $ display cod
      binop dDom "->" dCod
    All bnd -> lunbind bnd $ \(tv, body) -> do
      dTv <- display tv
      dBody <- display body
      prefix "forall" (dTv <> text "." <+> dBody)
    TyProd tys -> displayTuple tys
-- | Pretty-printing of terms.  Type annotations ('Ann') are not printed.
-- Parenthesization relies on the precedence helpers (precedence, withPrec,
-- binop, prefix), which are defined outside this file.
instance Display Tm where
  display (TmInt i) = return $ int i
  display (TmVar n) = display n
  -- Printed as: fix f (x:t1):t2. e
  display (Fix bnd) = lunbind bnd $ \((f,x,Embed (ty1,ty2)), e) -> do
    df <- display f
    dx <- display x
    d1 <- display ty1
    d2 <- display ty2
    de <- withPrec (precedence "fix") $ display e
    let arg = parens (dx <> colon <> d1)
    prefix "fix" (df <+> arg <> colon <> d2 <> text "." <+> de)
  -- Application is juxtaposition; the argument is printed one precedence
  -- level higher than the function, presumably so application associates
  -- to the left (TODO confirm against binop's parenthesization).
  display (App e1 e2) = do
    d1 <- withPrec (precedence " ") $ display e1
    d2 <- withPrec (precedence " " + 1) $ display e2
    binop d1 " " d2
  display (TmProd es) = displayTuple es
  -- Printed as "Pi<i> e".
  -- NOTE(review): neither the projected term nor the result goes through
  -- withPrec/prefix/binop, so a compound operand (or a projection used as an
  -- application argument) may print without parentheses -- confirm.
  display (TmPrj e i) = do
    de <- display e
    return $ text "Pi" <> int i <+> de
  -- The operator's 'show' string is used both as its printed form and as
  -- its precedence key; the right operand is printed one level higher.
  display (TmPrim e1 p e2) = do
    let str = show p
    d1 <- withPrec (precedence str) $ display e1
    d2 <- withPrec (precedence str + 1) $ display e2
    binop d1 str d2
  display (TmIf0 e0 e1 e2) = do
    d0 <- display e0
    d1 <- display e1
    d2 <- display e2
    prefix "if0" $ sep [d0 , text "then" <+> d1 , text "else" <+> d2]
  -- Printed as: /\a. e
  display (TLam bnd) = lunbind bnd $ \(a,e) -> do
    da <- display a
    de <- withPrec (precedence "/\\") $ display e
    prefix "/\\" (da <> text "." <+> de)
  -- Type application is printed exactly like term application
  -- (juxtaposition), so the two are not distinguished in the output.
  display (TApp e ty) = do
    d1 <- withPrec (precedence " ") $ display e
    d2 <- withPrec (precedence " " + 1) $ display ty
    binop d1 " " d2
  -- Annotations are dropped when printing.
  display (Ann e ty) = display e