{-# LANGUAGE TypeOperators, GADTs, CPP #-}
module Jukebox.Tools.InferTypes where

#include "errors.h"
import Control.Monad
import Jukebox.Form
import Jukebox.Name
import qualified Data.Map.Strict as Map
import Data.Map(Map)
import Jukebox.UnionFind hiding (rep)
import qualified Data.Set as Set
import Data.MemoUgly

-- | Per-function typing information built by 'inferTypes': for each
-- argument position, a fresh type name paired with the original argument
-- type; and for the result, a fresh type name paired with the original
-- result type.
type Function' = ([(Name, Type)], (Name, Type))

-- | Infer a finer-grained typing for a clausal problem.
--
-- Every function argument position, every function result and every
-- variable in the problem is given its own fresh type name. 'solve' then
-- merges the type names that are forced to agree and rewrites the problem
-- to use the merged types.
--
-- Returns the rewritten problem together with a function that maps each
-- type occurring in the rewritten problem back to the original type it was
-- derived from. For a type it does not know, that function evaluates @__@
-- (NOTE(review): presumably the internal-error macro from errors.h).
inferTypes :: [Input Clause] -> NameM ([Input Clause], Type -> Type)
inferTypes prob = do
  -- For each function symbol: a fresh name per argument position and one
  -- for the result, each paired with the original type at that position.
  funMap <- fmap Map.fromList . forM (functions prob) $ \f -> do
    res  <- newName (typ f)
    args <- mapM newName (funArgs f)
    return (name f, (zip args (funArgs f), (res, typ f)))

  -- For each variable: a fresh type name paired with its original type.
  varMap <- fmap Map.fromList . forM (vars prob) $ \v -> do
    ty <- newName (typ v)
    return (name v, (ty, typ v))

  -- Maps every fresh type name back to the original type it stands for.
  -- O is included so that looking up the name of O yields O itself.
  let tyMap :: Map Name Type
      tyMap =
        Map.fromList $
          [(name O, O)] ++
          concatMap (\(args, res) -> res : args) (Map.elems funMap) ++
          Map.elems varMap

  let (prob', rep) = solve funMap varMap prob
      -- Types in prob' are named after the representative of their
      -- equivalence class, which is one of the fresh names in tyMap.
      rep' :: Type -> Type
      rep' ty = Map.findWithDefault __ (rep (name ty)) tyMap

  return (prob', rep')

-- | Solve the typing constraints of a problem and rewrite it accordingly.
--
-- The constraints emitted by 'generate' are solved with union-find. Every
-- function and variable is then retyped: a fresh type name whose original
-- type is O stays O, and any other fresh type name is replaced by the
-- representative of its equivalence class.
--
-- Returns the rewritten problem and the representative-lookup function.
solve :: Map Name Function' -> Map Name (Name, Type) ->
         [Input Clause] -> ([Input Clause], Name -> Name)
solve funMap varMap prob = (retype prob, canon)
  where
    -- Generic traversal: retype terms and binders, and recurse
    -- structurally through everything else.
    retype :: Symbolic a => a -> a
    retype t =
      case typeOf t of
        Bind_ -> retypeBind t
        Term  -> retypeTerm t
        _     -> recursively retype t

    retypeBind :: Symbolic a => Bind a -> Bind a
    retypeBind (Bind bound body) =
      Bind (Set.map retypeVar bound) (retype body)

    retypeTerm :: Term -> Term
    retypeTerm (Var v)     = Var (retypeVar v)
    retypeTerm (g :@: xs)  = retypeFun g :@: map retypeTerm xs

    -- The three helpers below are memoised. They are bound without
    -- arguments so that one memo table is shared by all calls.
    retypeFun :: Function -> Function
    retypeFun = memo $ \(g ::: _) ->
      let (argSlots, resSlot) = Map.findWithDefault __ g funMap
      in g ::: FunType (map retypeType argSlots) (retypeType resSlot)

    retypeVar :: Variable -> Variable
    retypeVar = memo $ \(v ::: _) ->
      v ::: retypeType (Map.findWithDefault __ v varMap)

    retypeType :: (Name, Type) -> Type
    retypeType = memo $ \slot ->
      case slot of
        (_, O)     -> O
        (fresh, _) -> Type (canon fresh)

    -- Representative of each fresh type name after solving the constraints.
    canon :: Name -> Name
    canon = evalUF initial (generate funMap varMap prob >> reps)

-- | Emit union-find constraints for every atomic formula in the clauses.
--
-- * Each argument of a function application is unified with the fresh type
--   name of the corresponding argument position of that function.
-- * The two sides of an equation are unified with each other.
-- * A predicate atom (@Tru p@) only contributes the constraints arising
--   from its subterms.
--
-- A variable's type is the fresh name allocated to it in @varMap@. An
-- application's type is the fresh result name of its function in @funMap@.
generate :: Map Name Function' -> Map Name (Name, Type) -> [Input Clause] -> UF Name ()
generate funMap varMap cs = mapM_ (mapM_ atomic) lss
  where
    -- The atoms of each clause, with literal signs dropped.
    lss :: [[Atomic]]
    lss = map (map the . toLiterals . what) cs

    atomic :: Atomic -> UF Name ()
    atomic (Tru p) = void (term p)
    atomic (t :=: u) = do
      t' <- term t
      u' <- term u
      -- (=:=) reports the replacement it made. That result is not needed
      -- here, so discard it explicitly rather than via an unused do-bind.
      void (t' =:= u')

    -- The fresh type name of a term's value, emitting constraints for its
    -- subterms along the way.
    term :: Term -> UF Name Name
    term (Var x) = return y
      where (y, _) = Map.findWithDefault __ (name x) varMap
    term (f :@: xs) = do
      ys <- mapM term xs
      let (zs, r) = Map.findWithDefault __ (name f) funMap
      -- NOTE(review): zipWithM_ truncates silently if the application's
      -- arity differs from the function's declared arity.
      zipWithM_ (=:=) ys (map fst zs)
      return (fst r)