module Language.LBNF.TypeChecker where
import Control.Monad
import Data.List
import Data.Char
import Language.LBNF.CF
import Language.LBNF.Runtime
-- | A non-function type: a named type (e.g. a normalised category name or a
-- built-in such as @String@), or a list of some base type.
data Base
  = BaseT String  -- ^ a named type
  | ListT Base    -- ^ a list whose elements have the given base type
  deriving Eq

-- | Named types print as their name; list types are wrapped in brackets.
instance Show Base where
  show b = case b of
    BaseT name   -> name
    ListT elemTy -> '[' : show elemTy ++ "]"

-- | The type of a label or defined function: its argument base types and
-- its result base type.
data Type = FunT [Base] Base
  deriving Eq

-- | Prints the argument types, then @->@, then the result type, all
-- separated by single spaces (e.g. @A [B] -> C@).
instance Show Type where
  show (FunT args result) =
    unwords ([ show a | a <- args ] ++ ["->", show result])
-- | Typing environment used while checking definitions.
data Context = Ctx { ctxLabels :: [(String, Type)]
                     -- ^ Types of rule labels; 'extendContext' prepends
                     -- extra bindings (e.g. definition parameters) here.
                   , ctxTokens :: [String]
                     -- ^ Names treated as token categories ('buildContext'
                     -- always includes @Ident@).
                   }
-- | Run a parse computation; if it failed, pass its error message to the
-- handler, otherwise return the successful result unchanged.
catchErr :: ParseMonad a -> (String -> ParseMonad a) -> ParseMonad a
catchErr action handler =
  case action of
    Ok value -> Ok value
    Bad msg  -> handler msg
-- | Build the initial typing context for a grammar.
--
-- Every rule label that is neither a coercion nor a list nil/cons label gets
-- a function type from its right-hand-side categories to the rule's own
-- category. @Ident@ and the grammar's token names are recorded as tokens.
buildContext :: CF -> Context
buildContext cf@(_, rules) = Ctx labelTypes tokenCats
  where
    labelTypes =
      [ (lbl, ruleType cat rhs)
      | (lbl, (cat, rhs)) <- rules
      , not (isCoercion lbl)
      , not (isNilCons lbl)
      ]
    tokenCats = "Ident" : tokenNames cf

    -- Only the Left (category) items of the right-hand side become argument
    -- types, and the internal category is skipped.
    ruleType cat (Left items) =
      FunT [ toBase c | Left c <- items, c /= internalCat ] (toBase cat)
    -- A Right right-hand side (presumably a token regex) takes a String.
    ruleType cat (Right _) = FunT [BaseT "String"] (toBase cat)

    toBase c
      | isList c  = ListT (toBase (normCatOfList c))
      | otherwise = BaseT (normCat c)
-- | Whether the name is one of the context's token categories.
isToken :: String -> Context -> Bool
isToken x ctx = any (== x) (ctxTokens ctx)
-- | Prepend bindings to the context's labels. Because 'lookup' returns the
-- first match, the new bindings take precedence over existing entries with
-- the same name.
extendContext :: Context -> [(String,Type)] -> Context
extendContext ctx bindings = ctx { ctxLabels = bindings ++ existing }
  where
    existing = ctxLabels ctx
-- | Find the type of a symbol.
--
-- Token names are checked first and typed as functions from @String@ to the
-- token type. Any other name must be a label in the context, otherwise the
-- lookup fails with an "Undefined symbol" message.
lookupCtx :: String -> Context -> ParseMonad Type
lookupCtx x ctx
  | isToken x ctx = return (FunT [BaseT "String"] (BaseT x))
  | otherwise     = maybe notFound return (lookup x (ctxLabels ctx))
  where
    notFound = fail ("Undefined symbol '" ++ x ++ "'.")
-- | Type-check every 'FunDef' in the grammar.
--
-- First ensures no label is recorded at conflicting types, then checks each
-- definition against the context built from the grammar's rules.
checkDefinitions :: CF -> ParseMonad ()
checkDefinitions cf@((pragmas, _), _) = do
    checkContext context
    forM_ [ (f, xs, e) | FunDef f xs e <- pragmas ] $ \(f, xs, e) ->
      checkDefinition context f xs e
  where
    context = buildContext cf
-- | Fail if any label in the context is recorded with more than one
-- distinct type; the error lists every conflicting type.
checkContext :: Context -> ParseMonad ()
checkContext ctx = mapM_ checkEntry (collectTypes (ctxLabels ctx))
  where
    -- Turn a lookup table with duplicate keys into one entry per key,
    -- holding all of that key's values. The sort is stable, so values keep
    -- their original relative order. Groups are never empty, so the cons
    -- pattern below always matches.
    collectTypes :: Ord a => [(a,b)] -> [(a,[b])]
    collectTypes entries =
      [ (key, map snd grp)
      | grp@((key, _) : _) <- groupBy sameKey (sortBy byKey entries)
      ]
    byKey (a, _) (b, _) = compare a b
    sameKey (a, _) (b, _) = a == b

    checkEntry (f, ts) =
      case nub ts of
        [_] -> return ()
        distinct ->
          fail $ "The symbol '" ++ f ++ "' is used at conflicting types:\n" ++
            unlines (map ((" " ++) . show) distinct)
-- | Type-check the definition @f xs = e@, keeping only success or failure;
-- the list constructors are the plain placeholder names.
checkDefinition :: Context -> String -> [String] -> Exp -> ParseMonad ()
checkDefinition ctx f xs e =
  checkDefinition' dummyConstructors ctx f xs e >> return ()
-- | Names to substitute for the list constructors, given the element type
-- of the list being built. 'checkExp' replaces @[]@ with 'nil' and @(:)@
-- with 'cons'.
data ListConstructors = LC
  { nil  :: Base -> String  -- ^ name for the empty list
  , cons :: Base -> String  -- ^ name for cons
  }

-- | Keeps the plain names @[]@ and @(:)@ whatever the element type; used
-- when a definition is only being type-checked.
dummyConstructors :: ListConstructors
dummyConstructors = LC { nil = \_ -> "[]", cons = \_ -> "(:)" }
-- | Type-check the definition @f xs = e@ against the context.
--
-- Requirements:
--
-- * @f@ is non-empty and starts with a lowercase letter.
-- * @f@ has a type in the context (i.e. it is used in some rule).
-- * The number of parameters @xs@ equals the number of argument types.
--
-- The body is checked at the result type, with each parameter bound to its
-- argument type. On success, returns the parameter bindings and the
-- elaborated body paired with its result type. Every failure message is
-- prefixed with the definition being checked.
checkDefinition' :: ListConstructors -> Context -> String -> [String] -> Exp -> ParseMonad ([(String,Base)],(Exp,Base))
checkDefinition' list ctx f xs e =
    attempt `catchErr` \err ->
      fail $ "In the definition " ++ unwords (f : xs ++ ["=",show e,";"]) ++ "\n " ++ err
  where
    attempt = do
      unless (startsLowercase f) $
        fail "Defined functions must start with a lowercase letter."
      t@(FunT ts t') <- lookupCtx f ctx `catchErr` \_ ->
        fail $ "'" ++ f ++ "' must be used in a rule."
      let expect = length ts
          given  = length xs
      unless (expect == given) $
        fail $ "'" ++ f ++ "' is used with type " ++ show t ++ " but defined with " ++ show given ++ " argument" ++ plural given ++ "."
      -- Parameters are in scope in the body as zero-argument functions.
      e' <- checkExp list (extendContext ctx $ zip xs (map (FunT []) ts)) e t'
      return (zip xs ts, (e', t'))

    -- Total version of @isLower . head@: an empty name is rejected with the
    -- usual error instead of crashing.
    startsLowercase (c:_) = isLower c
    startsLowercase []    = False

    plural :: Int -> String
    plural 1 = ""
    plural _ = "s"
-- | Check that an expression has the expected base type.
--
-- Returns the expression with every @[]@ and @(:)@ replaced by the names
-- that the 'ListConstructors' give for the list's element type. Literals
-- must match the built-in types Integer, Double, Char and String. Any other
-- application is looked up in the context, its arguments are checked
-- against the argument types found there, and its result type must equal
-- the expected type.
checkExp :: ListConstructors -> Context -> Exp -> Base -> ParseMonad Exp
checkExp list _ (App "[]" []) (ListT t) = return (App (nil list t) [])
-- A well-formed [] at a non-list type is a type error, not an arity error.
checkExp _ _ e@(App "[]" []) t = fail $ show e ++ " does not have type " ++ show t ++ "."
checkExp _ _ (App "[]" _) _ = fail $ "[] is applied to too many arguments."
checkExp list ctx (App "(:)" [e,es]) (ListT t) =
  do e'  <- checkExp list ctx e t
     es' <- checkExp list ctx es (ListT t)
     return $ App (cons list t) [e',es']
-- Likewise, a two-argument (:) at a non-list type is a type error.
checkExp _ _ e@(App "(:)" [_,_]) t = fail $ show e ++ " does not have type " ++ show t ++ "."
checkExp _ _ (App "(:)" es) _ = fail $ "(:) takes 2 arguments, but has been given " ++ show (length es) ++ "."
checkExp list ctx e@(App x es) t =
  do FunT ts t' <- lookupCtx x ctx
     es' <- matchArgs ts
     unless (t == t') $ fail $ show e ++ " has type " ++ show t' ++ ", but something of type " ++ show t ++ " was expected."
     return $ App x es'
  where
    -- Check each argument against its declared type once the counts agree.
    matchArgs ts
      | expect /= given = fail $ "'" ++ x ++ "' takes " ++ show expect ++ " arguments, but has been given " ++ show given ++ "."
      | otherwise       = zipWithM (checkExp list ctx) es ts
      where
        expect = length ts
        given  = length es
checkExp _ _ e@(LitInt _) (BaseT "Integer") = return e
checkExp _ _ e@(LitDouble _) (BaseT "Double") = return e
checkExp _ _ e@(LitChar _) (BaseT "Char") = return e
checkExp _ _ e@(LitString _) (BaseT "String") = return e
checkExp _ _ e t = fail $ show e ++ " does not have type " ++ show t ++ "."