module Language.LBNF.TypeChecker where

import Control.Monad
import Data.List
import Data.Char

import Language.LBNF.CF
import Language.LBNF.Runtime

-- | A non-function type: either a named type or a list of such types.
data Base
    = BaseT String   -- ^ a type identified by its name
    | ListT Base     -- ^ a list whose elements have the given type
    deriving (Eq)

-- | The type of a symbol: the types of its arguments followed by its result
-- type.  A symbol without arguments has an empty argument list.
data Type
    = FunT [Base] Base
    deriving (Eq)

-- | Named types are shown as their bare name, list types as the element type
-- wrapped in square brackets.
instance Show Base where
    show base = case base of
        BaseT name -> name
        ListT el   -> '[' : show el ++ "]"

-- | Shows the argument types, then @->@, then the result type, all separated
-- by single spaces.
instance Show Type where
    show (FunT args result) = unwords ([ show a | a <- args ] ++ ["->", show result])

-- | The environment that definitions are checked against.
data Context = Ctx
    { ctxLabels :: [(String, Type)]
      -- ^ symbols paired with their types; a symbol may occur more than once
    , ctxTokens :: [String]
      -- ^ names that 'isToken' recognises as token types
    }

-- | Run a computation and, if it ended in 'Bad', hand its error message to
-- the given handler.  A successful ('Ok') result is passed through unchanged.
catchErr :: ParseMonad a -> (String -> ParseMonad a) -> ParseMonad a
catchErr action handler =
    case action of
        Bad msg -> handler msg
        Ok val  -> Ok val

-- | Build the checking context for a grammar.
--
-- Every rule whose label is neither a coercion nor a nil/cons label
-- contributes one @(label, type)@ entry.  The token list is @\"Ident\"@
-- followed by the grammar's own token names.
buildContext :: CF -> Context
buildContext cf@(_, rules) = Ctx
    { ctxLabels = labelTypes
    , ctxTokens = "Ident" : tokenNames cf
    }
  where
    labelTypes =
        [ (label, mkType cat rhs)
        | (label, (cat, rhs)) <- rules
        , not (isCoercion label || isNilCons label)
        ]

    -- An ordinary right-hand side contributes one argument per category
    -- item (the 'Left' items), leaving out the internal category.
    mkType cat (Left items) =
        FunT [ mkBase c | Left c <- items, c /= internalCat ] (mkBase cat)
    -- The other kind of right-hand side yields a single String argument.
    mkType cat (Right _) =
        FunT [BaseT "String"] (mkBase cat)

    -- List categories become list types over their element category;
    -- everything else becomes a named type over the normalised category.
    mkBase c =
        if isList c
            then ListT (mkBase (normCatOfList c))
            else BaseT (normCat c)

-- | Is the given name one of the context's token types?
isToken :: String -> Context -> Bool
isToken x ctx = x `elem` ctxTokens ctx

-- | Add bindings in front of the existing labels, so that 'lookup' finds the
-- new bindings first.  The token list is left untouched.
extendContext :: Context -> [(String,Type)] -> Context
extendContext (Ctx labels tokens) xs =
    Ctx { ctxLabels = xs ++ labels, ctxTokens = tokens }

-- | Find the type of a symbol.
--
-- A token name @T@ has type @String -> T@.  Any other symbol is looked up
-- among the context's labels (first match wins); an unknown symbol is an
-- error.
lookupCtx :: String -> Context -> ParseMonad Type
lookupCtx x ctx
    | isToken x ctx = return (FunT [BaseT "String"] (BaseT x))
    | otherwise     = maybe undefinedSymbol return (lookup x (ctxLabels ctx))
  where
    undefinedSymbol = fail $ "Undefined symbol '" ++ x ++ "'."

-- | Check the grammar's label context for consistency, then type check every
-- function definition ('FunDef' pragma) of the grammar, in order.
checkDefinitions :: CF -> ParseMonad ()
checkDefinitions cf@((ps, _), _) = do
    checkContext ctx
    mapM_ checkOne [ (f, xs, e) | FunDef f xs e <- ps ]
  where
    ctx = buildContext cf
    checkOne (f, xs, e) = checkDefinition ctx f xs e

-- | Verify that every label in the context is used at one type only.
-- Fails with a message listing the conflicting types otherwise.
checkContext :: Context -> ParseMonad ()
checkContext ctx = mapM_ checkEntry (groupSnd (ctxLabels ctx))
  where
    -- Turn a lookup table that may contain duplicate keys into a table with
    -- unique keys, each mapped to the list of all values it had.
    -- The result is sorted by key.
    groupSnd :: Ord a => [(a,b)] -> [(a,[b])]
    groupSnd table =
        [ (key, map snd grp)
        | grp@((key, _) : _) <- groupBy sameKey (sortBy byKey table)
        ]
      where
        sameKey p q = fst p == fst q
        byKey p q   = compare (fst p) (fst q)

    -- A label is fine when all of its types coincide.
    checkEntry (f, ts) =
        case nub ts of
            [_]       -> return ()
            conflicts ->
                fail $ "The symbol '" ++ f ++ "' is used at conflicting types:\n" ++
                        unlines (map (\ty -> "  " ++ show ty) conflicts)

-- | Type check a single definition @f xs = e@, discarding the elaborated
-- result.  List constructors are left as @[]@ and @(:)@.
checkDefinition :: Context -> String -> [String] -> Exp -> ParseMonad ()
checkDefinition ctx f xs e =
    checkDefinition' dummyConstructors ctx f xs e >> return ()

-- | Names to substitute for the list constructors while checking an
-- expression.  Each function receives the element type of the list.
data ListConstructors = LC
    { nil  :: Base -> String   -- ^ name that replaces @[]@
    , cons :: Base -> String   -- ^ name that replaces @(:)@
    }

-- | List constructors that ignore the element type and keep the original
-- names @[]@ and @(:)@.
dummyConstructors :: ListConstructors
dummyConstructors = LC
    { nil  = \_ -> "[]"
    , cons = \_ -> "(:)"
    }

-- | Type check the definition @f xs = e@.
--
-- The function name must start with a lowercase letter, must have a type in
-- the context (i.e. be used as a label in some rule), and must be defined
-- with as many parameters as that type has arguments.  The body is checked
-- against the result type with the parameters bound to the argument types.
--
-- Returns the parameters paired with their types, together with the
-- elaborated body and its type.  Every error is prefixed with the definition
-- it occurred in.
checkDefinition' :: ListConstructors -> Context -> String -> [String] -> Exp -> ParseMonad ([(String,Base)],(Exp,Base))
checkDefinition' list ctx f xs e =
    check `catchErr` \err ->
        fail $ "In the definition " ++ unwords (f : xs ++ ["=",show e,";"]) ++ "\n  " ++ err
  where
    -- Total replacement for @isLower (head f)@: an empty name is reported
    -- as an ordinary error instead of crashing on 'head'.
    startsLower = case f of
        (c:_) -> isLower c
        []    -> False

    check = do
        unless startsLower $ fail "Defined functions must start with a lowercase letter."
        -- An unknown name here means that no rule uses it as a label.
        t@(FunT ts t') <- lookupCtx f ctx `catchErr` \_ ->
            fail $ "'" ++ f ++ "' must be used in a rule."
        let expect = length ts
            given  = length xs
        unless (expect == given) $ fail $ "'" ++ f ++ "' is used with type " ++ show t ++ " but defined with " ++ show given ++ " argument" ++ plural given ++ "."
        -- Parameters are bound as symbols without arguments.
        e' <- checkExp list (extendContext ctx $ zip xs (map (FunT []) ts)) e t'
        return (zip xs ts, (e', t'))

    plural 1 = ""
    plural _ = "s"

-- | Check an expression against an expected type and return the elaborated
-- expression.
--
-- * @[]@ and @(:)@ are only accepted at list types and are renamed using the
--   given 'ListConstructors', applied to the element type.
-- * Any other application is looked up in the context; its arguments are
--   checked against the argument types and its result type must equal the
--   expected type.
-- * Literals are accepted at @Integer@, @Double@, @Char@ and @String@
--   respectively.
--
-- A correctly applied @[]@ or @(:)@ that is expected at a non-list type is
-- reported as a type mismatch, not as an arity error.
checkExp :: ListConstructors -> Context -> Exp -> Base -> ParseMonad Exp
checkExp list ctx expr expected =
    case (expr, expected) of
        (App "[]" [], ListT t) -> return (App (nil list t) [])
        -- Right arity, wrong type.
        (App "[]" [], _)       -> mismatch
        (App "[]" _, _)        -> fail $ "[] is applied to too many arguments."
        (App "(:)" [e,es], ListT t) -> do
            e'  <- checkExp list ctx e t
            es' <- checkExp list ctx es (ListT t)
            return $ App (cons list t) [e',es']
        -- Right arity, wrong type.
        (App "(:)" [_,_], _)   -> mismatch
        (App "(:)" es, _)      -> fail $ "(:) takes 2 arguments, but has been given " ++ show (length es) ++ "."
        (App x es, _) -> do
            FunT ts t' <- lookupCtx x ctx
            es' <- matchArgs x es ts
            unless (expected == t') $ fail $ show expr ++ " has type " ++ show t' ++ ", but something of type " ++ show expected ++ " was expected."
            return $ App x es'
        (LitInt _, BaseT "Integer")   -> return expr
        (LitDouble _, BaseT "Double") -> return expr
        (LitChar _, BaseT "Char")     -> return expr
        (LitString _, BaseT "String") -> return expr
        _                             -> mismatch
  where
    mismatch = fail $ show expr ++ " does not have type " ++ show expected ++ "."

    -- Check the arguments pairwise against the argument types, after making
    -- sure that their numbers agree.
    matchArgs x es ts
        | expect /= given = fail $ "'" ++ x ++ "' takes " ++ show expect ++ " arguments, but has been given " ++ show given ++ "."
        | otherwise       = zipWithM (checkExp list ctx) es ts
      where
        expect = length ts
        given  = length es