module DDC.Type.Equiv
        (equivT)
where
import DDC.Type.Exp
import DDC.Type.Compounds
import DDC.Type.Transform.Crush
import DDC.Type.Transform.Trim
import DDC.Base.Pretty
import Data.Maybe
import qualified DDC.Type.Sum   as Sum


-- | Decide whether two types are equivalent.
--
--   The comparison works up to alpha-renaming of variables bound by
--   enclosing foralls, and is performed after crushing compound effects
--   and trimming closures.
--
--   * Variables that are not bound by an enclosing forall are only
--     considered equivalent when the two bounds are equal.
--
--   * The types are assumed to be well-kinded, so that the type annotations
--     on bound variables agree with their binders. For ill-kinded input the
--     result is indeterminate.
--
--   This is the entry point: it starts the worker with empty binder stacks
--   and a binding depth of zero on both sides.
--
equivT  :: (Ord n, Pretty n) => Type n -> Type n -> Bool
equivT  = equivT' [] 0 [] 0


-- | Worker for 'equivT'.
--
--   Carries one binder stack (and its depth) per side. A binder is pushed on
--   each stack whenever we descend under a pair of foralls, so the two stacks
--   and the two depths always advance in lockstep.
--
--   Before comparing, the head of each type is crushed / trimmed and
--   single-element sums are unwrapped, so that for example a singleton sum
--   and its only element compare as equivalent.
--
equivT' :: (Ord n, Pretty n)
        => [Bind n] -> Int      -- ^ Binder stack and depth for the first type.
        -> [Bind n] -> Int      -- ^ Binder stack and depth for the second type.
        -> Type n   -> Type n   -- ^ The types to compare.
        -> Bool

equivT' stack1 depth1 stack2 depth2 t1 t2
 = let  t1'     = unpackSumT $ crushSomeT t1
        t2'     = unpackSumT $ crushSomeT t2
   in case (t1', t2') of
        (TVar u1,         TVar u2)
         -- Neither variable is bound by a forall we have entered, so both
         -- are free and can only be compared by name.
         -- We must check the stacks *before* trusting name equality:
         -- otherwise (forall a. a) and (forall b. a) would be accepted,
         -- as would (forall a. forall b. a) and (forall b. forall a. a).
         | Nothing           <- getBindType stack1 u1
         , Nothing           <- getBindType stack2 u2
         -> u1 == u2

         -- Both variables are bound by foralls we have entered.
         -- They are equivalent when they refer to binders at the same
         -- position in their respective stacks, and those binders have
         -- equivalent types.
         | depth1 == depth2
         , Just (ix1, t1a)   <- getBindType stack1 u1
         , Just (ix2, t2a)   <- getBindType stack2 u2
         , ix1 == ix2
         -> equivT' stack1 depth1 stack2 depth2 t1a t2a

         -- If one variable is bound and the other is free, or they refer to
         -- different binders, no guard matches and we fall through to the
         -- final catch-all alternative, which yields False.

        -- Constructor names must be equal.
        (TCon tc1,        TCon tc2)
         -> tc1 == tc2

        -- Push binders on the stack as we enter foralls.
        -- The binder types are compared on their own, with empty stacks.
        (TForall b11 t12, TForall b21 t22)
         |  equivT  (typeOfBind b11) (typeOfBind b21)
         -> equivT' (b11 : stack1) (depth1 + 1) 
                    (b21 : stack2) (depth2 + 1) 
                    t12 t22

        -- Descend into applications.
        (TApp t11 t12,    TApp t21 t22)
         -> equivT' stack1 depth1 stack2 depth2 t11 t21
         && equivT' stack1 depth1 stack2 depth2 t12 t22
        
        -- Sums are equivalent if all of their components are.
        (TSum ts1,        TSum ts2)
         -> let ts1'      = Sum.toList ts1
                ts2'      = Sum.toList ts2
                equiv     = equivT' stack1 depth1 stack2 depth2

                -- If all the components of the sum were in the element
                -- arrays then they come out of Sum.toList sorted
                -- and we can compare corresponding pairs.
                checkFast = and $ zipWith equiv ts1' ts2'

                -- If any of the components use a higher kinded type variable
                -- like (c : % ~> !) then they won't necessarily be sorted,
                -- so we need to do this slower O(n^2) check: every component
                -- on each side must have an equivalent on the other side.
                checkSlow = and [ or (map (equiv t1c) ts2') | t1c <- ts1' ]
                         && and [ or (map (equiv t2c) ts1') | t2c <- ts2' ]

            -- The length test comes first because zipWith silently
            -- truncates to the shorter list.
            in  (length ts1' == length ts2')
            &&  (checkFast || checkSlow)

        -- Mismatched constructors, or a same-constructor pair whose guards
        -- all failed.
        (_, _)  -> False


-- | Replace a sum that has exactly one component by that component.
--   Any other type, including sums with zero or several components,
--   is returned unchanged.
unpackSumT :: Type n -> Type n
unpackSumT tt
 = case tt of
        TSum ts
         | [t]  <- Sum.toList ts -> t
        _                        -> tt


-- | Crush a compound effect or trim a closure, but only at the head of
--   the given type.
--
--   Only an application whose function part is one of the recognised
--   constructors is handed to the crusher / trimmer; every other type is
--   returned as-is. We inspect the head constructor first because, per the
--   original design note, the crushing function recursively crushes the
--   components itself. As equivT' already recurses over the type, calling it
--   unconditionally would re-crush the same non-crushable type over and over.
--
crushSomeT :: (Ord n, Pretty n) => Type n -> Type n
crushSomeT tt
        | TApp (TCon tc) _ <- tt
        = crushHead tc

        | otherwise
        = tt

 where  -- Dispatch on the constructor being applied.
        crushHead (TyConSpec    TcConDeepRead)   = crushEffect tt
        crushHead (TyConSpec    TcConDeepWrite)  = crushEffect tt
        crushHead (TyConSpec    TcConDeepAlloc)  = crushEffect tt

        -- 'trimClosure' can return Nothing if the closure is miskinded,
        -- in which case the term is left untrimmed.
        crushHead (TyConSpec    TcConDeepUse)    = fromMaybe tt (trimClosure tt)

        crushHead (TyConWitness TwConDeepGlobal) = crushEffect tt
        crushHead _                              = tt


-- | Lookup the type of a bound thing from the binder stack.
--   The binder stack contains the binders of all the `TForall`s we've
--   entered under so far, innermost first.
--
--   On success, returns the position of the matching binder in the stack
--   (0 for the head) together with the type attached to that binder.
--   Returns `Nothing` if the bound does not refer to any binder on the stack.
--
--   * A named bound matches the first named binder with the same name.
--
--   * An indexed bound counts anonymous binders only: index 0 refers to the
--     first anonymous binder encountered, index 1 to the second, and so on.
--     Named and `BNone` binders are skipped without consuming the index.
--     A negative index never matches.
--
getBindType :: Eq n => [Bind n] -> Bound n -> Maybe (Int, Type n)
getBindType bs' u'
 = go 0 u' bs'
 where  -- The first argument is the stack position of the binder being
        -- inspected; the second is the bound still being searched for,
        -- whose index shrinks as anonymous binders are passed.
        go n u (BName n1 t : bs)
         | UName n2 _   <- u
         , n1 == n2     = Just (n, t)
         | otherwise    = go (n + 1) u bs

        go n u (BAnon t   : bs)
         | UIx i k      <- u
         = case compare i 0 of
                LT      -> Nothing
                EQ      -> Just (n, t)

                -- Passing an anonymous binder uses up one level of the
                -- index. Without this decrement any index above zero
                -- could never be resolved.
                GT      -> go (n + 1) (UIx (i - 1) k) bs

         | otherwise    = go (n + 1) u bs

        go n u (BNone _   : bs)
         = go (n + 1) u bs

        go _ _ []       = Nothing