{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
-- | Type inference over a clause problem: every function argument, function
-- result and variable gets its own fresh name, names that must agree are
-- merged with union-find, and the problem is rewritten so that each symbol
-- carries a type built from the representative of its name.
module Jukebox.Tools.InferTypes where

-- NOTE(review): the @__@ placeholder used below is not defined in this file;
-- presumably it is an "internal error" macro supplied by errors.h - confirm.
#include "errors.h"
import Control.Monad
import Jukebox.Form
import Jukebox.Name
import qualified Data.Map.Strict as Map
import Data.Map(Map)
import Jukebox.UnionFind hiding (rep)
import qualified Data.Set as Set
import Data.MemoUgly

-- | Per-function bookkeeping: one @(fresh name, original type)@ pair for each
-- argument position, plus one such pair for the result.
type Function' = ([(Name, Type)], (Name, Type))

-- | Infer types for a list of clauses.
--
-- Returns the rewritten clauses together with a function on types: given a
-- type, it finds the representative of that type's name and looks up the
-- original type that was recorded for that representative.
inferTypes :: [Input Clause] -> NameM ([Input Clause], Type -> Type)
inferTypes prob = do
  funMap <- fmap Map.fromList (mapM freshFunction (functions prob))
  varMap <- fmap Map.fromList (mapM freshVariable (vars prob))
  let -- Every fresh name (and O's own name) mapped to the type it stands for.
      tyMap =
        Map.fromList $
          (name O, O)
            : concatMap (\(args, res) -> res : args) (Map.elems funMap)
            ++ Map.elems varMap
      (prob', rep) = solve funMap varMap prob
      rep' ty = Map.findWithDefault __ (rep (name ty)) tyMap
  return (prob', rep')
  where
    -- The result name is allocated first, then one name per argument.
    freshFunction f = do
      res <- newName (typ f)
      args <- mapM newName (funArgs f)
      return (name f, (zip args (funArgs f), (res, typ f)))

    -- A single fresh name standing for the variable's type.
    freshVariable v = do
      ty <- newName (typ v)
      return (name v, (ty, typ v))

-- | Rewrite the problem using the equivalence classes computed by 'generate'.
--
-- The second component of the result maps a name to its representative.
solve :: Map Name Function' -> Map Name (Name, Type) ->
         [Input Clause] -> ([Input Clause], Name -> Name)
solve funMap varMap prob = (prob', rep)
  where
    prob' = rewrite prob

    -- Generic traversal: binders and terms get special treatment, anything
    -- else is handled by recursing into its children.
    rewrite :: Symbolic a => a -> a
    rewrite t =
      case typeOf t of
        Bind_ -> rewriteBind t
        Term -> rewriteTerm t
        _ -> recursively rewrite t

    rewriteBind :: Symbolic a => Bind a -> Bind a
    rewriteBind (Bind vs body) = Bind (Set.map retypeVar vs) (rewrite body)

    rewriteTerm (f :@: ts) = retypeFun f :@: map rewriteTerm ts
    rewriteTerm (Var x) = Var (retypeVar x)

    -- The three retyping functions are memoised; they are deliberately left
    -- as plain (non eta-expanded) bindings so each memo table is shared.
    retypeFun = memo retypeFun_
    retypeFun_ (f ::: _) = f ::: FunType (map retype args) (retype res)
      where
        (args, res) = Map.findWithDefault __ f funMap

    retypeVar = memo retypeVar_
    retypeVar_ (x ::: _) = x ::: retype (Map.findWithDefault __ x varMap)

    retype = memo retype_
    -- An original type of O is kept as O; any other slot becomes a type
    -- built from the representative of the slot's fresh name.
    retype_ (n, ty) =
      case ty of
        O -> O
        _ -> Type (rep n)

    rep = evalUF initial (generate funMap varMap prob >> reps)

-- | Walk every literal of every clause and merge (with '=:=') the names that
-- have to coincide: the two sides of an equation, and each argument of a
-- function application with the matching parameter of that function.
generate :: Map Name Function' -> Map Name (Name, Type) -> [Input Clause] -> UF Name ()
generate funMap varMap cs =
  forM_ cs $ \c ->
    mapM_ (atomic . the) (toLiterals (what c))
  where
    atomic (Tru p) = void (term p)
    atomic (t :=: u) = do
      t' <- term t
      u' <- term u
      void (t' =:= u')

    -- Yields the fresh name associated with the term.
    term (Var x) = return (fst (Map.findWithDefault __ (name x) varMap))
    term (f :@: xs) = do
      argNames <- mapM term xs
      let (params, result) = Map.findWithDefault __ (name f) funMap
      zipWithM_ (=:=) argNames (map fst params)
      return (fst result)