{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances    #-}

-- Syntactic equality of types, up to renaming of forall-bound type variables.

module Language.Haskell.Liquid.Types.Equality where 

import qualified Language.Fixpoint.Types as F
import           Language.Haskell.Liquid.Types
import qualified Language.Haskell.Liquid.GHC.API as Ghc

import Control.Monad.Writer.Lazy
-- import Control.Monad
import qualified Data.List as L

-- | Refinement types are compared with 'compareRType': structurally, while
--   allowing type variables to be renamed consistently.
instance REq SpecType where
  (=*=) = compareRType
  
-- | Structural comparison of two 'SpecType's, up to a consistent renaming of
--   type variables.
--
--   'go' walks both types in lock-step and returns whether their shapes and
--   refinements agree. Binders introduced by 'RAllT' are aligned by
--   substituting the left binder for the right one in the right body. Every
--   pair of 'RVar's met along the way is recorded (in a 'Writer'). The
--   recorded pairs must then describe a one-to-one renaming ('unify').
compareRType :: SpecType -> SpecType -> Bool
compareRType i1 i2 = res && unify vs
  where
    (res, vs) = runWriter (go i1 i2)

    -- The recorded variable pairs must form a bijection: each left variable
    -- is paired with exactly one right variable, and vice versa. Checking
    -- only one direction would accept @a -> b@ against @a -> a@ but not the
    -- reverse, making '=*=' asymmetric.
    unify :: (Eq a, Eq b) => [(a, b)] -> Bool
    unify pairs = functional distinct && functional (flipPair <$> distinct)
      where
        distinct = L.nub pairs

    flipPair :: (a, b) -> (b, a)
    flipPair (x, y) = (y, x)

    -- Pairs sharing a left component must agree on the right component.
    -- Unlike a plain 'L.groupBy', this does not require equal keys to be
    -- adjacent in the list.
    functional :: (Eq a, Eq b) => [(a, b)] -> Bool
    functional ps = and [ y == y' | (x, y) <- ps, (x', y') <- ps, x == x' ]

    go :: SpecType -> SpecType -> Writer [(RTyVar, RTyVar)] Bool
    go (RAllT x1 t1 r1) (RAllT x2 t2 r2)
      | RTV v1 <- ty_var_value x1
      , RTV v2 <- ty_var_value x2
      , r1 =*= r2
      -- align the binders: rename v2 to v1 inside the right-hand body
      = go t1 (subt (v2, Ghc.mkTyVarTy v1) t2)

    go (RVar v1 r1) (RVar v2 r2)
      -- the variables themselves need not be equal; the pairing is
      -- validated afterwards by 'unify'
      = do tell [(v1, v2)]
           return (r1 =*= r2)

    go (RFun x1 t11 t12 r1) (RFun x2 t21 t22 r2)
      | x1 == x2 && r1 =*= r2
      = liftM2 (&&) (go t11 t21) (go t12 t22)

    go (RImpF x1 t11 t12 r1) (RImpF x2 t21 t22 r2)
      | x1 == x2 && r1 =*= r2
      = liftM2 (&&) (go t11 t21) (go t12 t22)

    go (RAllP x1 t1) (RAllP x2 t2)
      | x1 == x2
      = go t1 t2

    go (RApp x1 ts1 ps1 r1) (RApp x2 ts2 ps2 r2)
      | x1 == x2
      , r1 =*= r2
      -- zipWithM truncates, so differing arities must be rejected explicitly
      , length ts1 == length ts2
      -- abstract-refinement props are compared pairwise (extra props on one
      -- side are ignored by zipWith)
      , and (zipWith (=*=) ps1 ps2)
      = and <$> zipWithM go ts1 ts2

    go (RAllE x1 t11 t12) (RAllE x2 t21 t22)
      | x1 == x2
      = liftM2 (&&) (go t11 t21) (go t12 t22)

    go (REx x1 t11 t12) (REx x2 t21 t22)
      | x1 == x2
      = liftM2 (&&) (go t11 t21) (go t12 t22)

    go (RExprArg e1) (RExprArg e2)
      = return (e1 =*= e2)

    go (RAppTy t11 t12 r1) (RAppTy t21 t22 r2)
      | r1 =*= r2
      = liftM2 (&&) (go t11 t21) (go t12 t22)

    go (RRTy _ _ _ t1) (RRTy _ _ _ t2)
      -- only the underlying types are compared (env, ref and obligation are
      -- ignored); recursing with 'go' keeps their variable pairs in the same
      -- renaming as the rest of the type
      = go t1 t2

    go (RHole r1) (RHole r2)
      = return (r1 =*= r2)

    go _t1 _t2
      = return False

-- | Syntactic equality used to compare refinement types and their parts.
--   Instances may normalise their arguments or rename binders before
--   comparing, so '=*=' can be coarser than '=='.
class REq a where
  -- | 'True' when the two arguments are considered syntactically equal.
  (=*=) :: a -> a -> Bool

-- | Abstract-refinement properties are compared on their bodies only; the
--   parameter list (first field of 'RProp') is not inspected.
instance REq t2 => REq (Ref t1 t2) where
  RProp _ body1 =*= RProp _ body2 = body1 =*= body2

-- | A 'UReft' is compared component-wise: the 'F.Reft' part with '=*=' and
--   the predicate part with plain '=='.
instance REq (UReft F.Reft) where
  MkUReft reft1 pred1 =*= MkUReft reft2 pred2 =
    (reft1 =*= reft2) && (pred1 == pred2)
  
-- | @{v1 | e1}@ and @{v2 | e2}@ are compared by renaming @v1@ to @v2@ inside
--   @e1@ and then comparing the resulting body against @e2@ with '=*='.
--   NOTE(review): if @v2@ occurs free in @e1@, the substitution could capture
--   it — confirm whether callers rule this out.
instance REq F.Reft where
  F.Reft (v1, e1) =*= F.Reft (v2, e2) = renamedBody =*= e2
    where
      renamedBody = F.subst1 e1 (v1, F.EVar v2)

-- | Expressions are passed through 'F.simplify' and then compared with '=='.
--   The 'F.notracepp' wrapper is kept as a debugging hook around the final
--   comparison (judging by its name it does not print; switch to a tracing
--   variant to log each comparison).
instance REq F.Expr where
  e1 =*= e2 = sameExpr (F.simplify e1) (F.simplify e2)
    where
      sameExpr lhs rhs =
        F.notracepp ("comparing " ++ showpp (F.toFix lhs, F.toFix rhs)) (lhs == rhs)

-- | Located values are compared on their payloads ('val'); source spans are
--   ignored.
instance REq r => REq (Located r) where
  lhs =*= rhs = val lhs =*= val rhs