{-# LANGUAGE CPP
           , DataKinds
           , GADTs
           , TypeOperators
           , PolyKinds
           , FlexibleContexts
           , ScopedTypeVariables
           , UndecidableInstances
           #-}

-- TODO: all the instances here are orphans. To ensure that we don't
-- have issues about orphan instances, we should give them all
-- newtypes and only provide the instance for those newtypes!
-- (and\/or: for the various op types, it's okay to move them to
-- AST.hs to avoid orphanage. It's just the instances for 'Term'
-- itself which are morally suspect outside of testing.)
{-# OPTIONS_GHC -Wall -fwarn-tabs -fno-warn-orphans #-}
----------------------------------------------------------------
--                                                    2016.05.24
-- |
-- Module      :  Language.Hakaru.Syntax.AST.Eq
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- Warning: The following module is for testing purposes only. Using
-- the 'JmEq1' instance for 'Term' is inefficient and should not
-- be done accidentally. To implement that (orphan) instance we
-- also provide the following (orphan) instances:
--
-- > SArgs      : JmEq1
-- > Term       : JmEq1, Eq1, Eq
-- > TrivialABT : JmEq2, JmEq1, Eq2, Eq1, Eq
--
-- TODO: because this is only for testing, everything else should
-- move to the @Tests@ directory.
----------------------------------------------------------------
module Language.Hakaru.Syntax.AST.Eq where

import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.TypeOf

import Control.Monad.Reader

import qualified Data.Foldable      as F
import qualified Data.List.NonEmpty as L
import qualified Data.Sequence      as S
import qualified Data.Traversable   as T

#if __GLASGOW_HASKELL__ < 710
import           Data.Functor ((<$>))
import           Data.Traversable
#endif


import Data.Maybe
-- import Data.Number.Nat

import Unsafe.Coerce

---------------------------------------------------------------------
-- | John Major equality for a @(:$)@ node of the AST: compare an
-- 'SCon' together with its 'SArgs' against another such pair.
--
-- It's necessary to break it out like this since we can't just
-- give a 'JmEq1' instance for 'SCon' due to polymorphism issues
-- (e.g., we can't just say that 'Lam_' is John Major equal to
-- 'Lam_', since they may be at different types). However, once the
-- 'SArgs' associated with the 'SCon' is given, that resolves the
-- polymorphism.
--
-- For the constructors whose argument list determines the result
-- type, comparing the 'SArgs' with 'jmEq1' suffices. Constructors
-- carrying extra indices ('CoerceTo_', 'UnsafeFrom_', the @*Op_@
-- wrappers and 'Summate') compare those as well. Mismatched
-- constructors, and any constructor not listed here, yield 'Nothing'.
jmEq_S
    :: (ABT Term abt, JmEq2 abt)
    => SCon args  a  -> SArgs abt args
    -> SCon args' a' -> SArgs abt args'
    -> Maybe (TypeEq a a', TypeEq args args')
jmEq_S Lam_      es Lam_       es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S App_      es App_       es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S Let_      es Let_       es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S (CoerceTo_ c) (es :* End) (CoerceTo_ c') (es' :* End) = do
    (Refl, Refl) <- jmEq2 es es'
    -- the result type is computed from the coercion, so check it too
    let t1 = coerceTo c  (typeOf es)
    let t2 = coerceTo c' (typeOf es')
    Refl <- jmEq1 t1 t2
    return (Refl, Refl)
jmEq_S (UnsafeFrom_ c) (es :* End) (UnsafeFrom_ c') (es' :* End) = do
    (Refl, Refl) <- jmEq2 es es'
    let t1 = coerceFrom c  (typeOf es)
    let t2 = coerceFrom c' (typeOf es')
    Refl <- jmEq1 t1 t2
    return (Refl, Refl)
jmEq_S (PrimOp_ op) es (PrimOp_ op') es' = do
    Refl <- jmEq1 es es'
    (Refl, Refl) <- jmEq2 op op'
    return (Refl, Refl)
jmEq_S (ArrayOp_ op) es (ArrayOp_ op') es' = do
    Refl <- jmEq1 es es'
    (Refl, Refl) <- jmEq2 op op'
    return (Refl, Refl)
jmEq_S (MeasureOp_ op) es (MeasureOp_ op') es' = do
    Refl <- jmEq1 es es'
    (Refl, Refl) <- jmEq2 op op'
    return (Refl, Refl)
jmEq_S Dirac     es Dirac      es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S MBind     es MBind      es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
-- 'Plate' and 'Chain' are also handled by 'alphaEq'; without these
-- clauses two identical Plate/Chain nodes would never compare equal.
jmEq_S Plate     es Plate      es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S Chain     es Chain      es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S Integrate es Integrate  es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S (Summate h1 h2) es (Summate h1' h2') es' = do
    -- the discrete and semiring indices are not fixed by the arguments
    Refl <- jmEq1 (sing_HDiscrete h1) (sing_HDiscrete h1')
    Refl <- jmEq1 (sing_HSemiring h2) (sing_HSemiring h2')
    Refl <- jmEq1 es es'
    Just (Refl, Refl)
jmEq_S Expect    es Expect     es' =
    jmEq1 es es' >>= \Refl -> Just (Refl, Refl)
jmEq_S _         _  _          _   = Nothing


-- | Compare the right-hand sides of corresponding 'Case_' branches,
-- returning a proof that the branch result types agree.
--
-- The input is a list of already-paired branches (e.g. the result of
-- zipping two branch lists), so any length check must happen before
-- the pairing. An empty list yields 'Nothing', since there is no
-- branch body from which to recover the type equality.
--
-- TODO: Handle jmEq2 of pat and pat'; the patterns are currently
-- ignored, so branches that differ only in their patterns are not
-- distinguished here.
jmEq_Branch
    :: (ABT Term abt, JmEq2 abt)
    => [(Branch a abt b, Branch a abt b')]
    -> Maybe (TypeEq b b')
jmEq_Branch []                               = Nothing
jmEq_Branch [(Branch _ e, Branch _ e')]      = do
    (Refl, Refl) <- jmEq2 e e'
    return Refl
jmEq_Branch ((Branch _ e, Branch _ e') : es) = do
    -- every branch body must match, not only the last one
    (Refl, Refl) <- jmEq2 e e'
    jmEq_Branch es

-- | Argument lists are John Major equal when they have the same shape
-- and each pair of corresponding arguments is 'jmEq2'-equal. That
-- equates both the binding lists and the types of the arguments.
instance JmEq2 abt => JmEq1 (SArgs abt) where
    jmEq1 End End = Just Refl
    jmEq1 (x :* xs) (y :* ys) = do
        (Refl, Refl) <- jmEq2 x y
        Refl         <- jmEq1 xs ys
        Just Refl
    jmEq1 _ _ = Nothing


-- | Structural John Major equality on 'Term's. This is intended for
-- testing only (see the module header): it is inefficient, and the
-- 'Datum_' case is unsound (see the comment on that clause).
instance (ABT Term abt, JmEq2 abt) => JmEq1 (Term abt) where
    jmEq1 (o :$ es) (o' :$ es') = do
        (Refl, Refl) <- jmEq_S o es o' es'
        return Refl
    jmEq1 (NaryOp_ o es) (NaryOp_ o' es') = do
        Refl <- jmEq1 o o'
        () <- all_jmEq2 es es'
        return Refl
    jmEq1 (Literal_ v)  (Literal_ w)   = jmEq1 v w
    jmEq1 (Empty_ a)    (Empty_ b)     = jmEq1 a b
    jmEq1 (Array_ i f)  (Array_ j g)   = do
        (Refl, Refl) <- jmEq2 i j
        (Refl, Refl) <- jmEq2 f g
        Just Refl
    jmEq1 (Datum_ (Datum hint _ _)) (Datum_ (Datum hint' _ _))
        -- BUG: We need to compare structurally rather than using the hint.
        -- Only the hints are compared: the datum contents are ignored,
        -- and 'unsafeCoerce' asserts the type equality without evidence,
        -- so two datums that share a hint but live at different types
        -- are wrongly declared equal.
        | hint == hint' = unsafeCoerce (Just Refl)
        | otherwise     = Nothing
    jmEq1 (Case_  a bs) (Case_  a' bs')
        -- 'zip' truncates, so differing branch counts must be rejected
        -- up front or a strict prefix would compare equal.
        | length bs /= length bs' = Nothing
        | otherwise               = do
            (Refl, Refl) <- jmEq2 a a'
            jmEq_Branch (zip bs bs')
    jmEq1 (Superpose_ pms) (Superpose_ pms')
        -- same truncation concern as for 'Case_'
        | L.length pms /= L.length pms' = Nothing
        | otherwise                     = do
            (Refl,Refl) L.:| _ <- T.traverse jmEq_Tuple (L.zip pms pms')
            return Refl
    -- Previously missing, which made @Reject_ t == Reject_ t@ False.
    jmEq1 (Reject_ x) (Reject_ y) = do
        Refl <- jmEq1 x y
        return Refl
    jmEq1 _              _              = Nothing


-- | Succeed exactly when the two sequences have the same length and
-- each pair of corresponding elements is 'jmEq2'-equal.
--
-- The explicit length check matters because 'S.zipWith' truncates to
-- the shorter sequence: without it, @[a, b]@ and @[a, b, c]@ would be
-- reported equal.
all_jmEq2
    :: (ABT Term abt, JmEq2 abt)
    => S.Seq (abt '[] a)
    -> S.Seq (abt '[] a)
    -> Maybe ()
all_jmEq2 xs ys
    | S.length xs /= S.length ys = Nothing
    | F.and (S.zipWith eq xs ys) = Just ()
    | otherwise                  = Nothing
    where
    eq x y = isJust (jmEq2 x y)


-- | Compare two pairs component-wise, first components first, and
-- return proofs that both component types agree.
jmEq_Tuple :: (ABT Term abt, JmEq2 abt)
           => ((abt '[] a , abt '[] b),
               (abt '[] a', abt '[] b'))
           -> Maybe (TypeEq a a', TypeEq b b')
jmEq_Tuple ((l1, r1), (l2, r2)) = do
    (Refl, Refl) <- jmEq2 l1 l2
    (Refl, Refl) <- jmEq2 r1 r2
    Just (Refl, Refl)


-- TODO: a more general function of type:
--   (JmEq2 abt) => Term abt a -> Term abt b -> Maybe (Sing a, TypeEq a b)
-- This can then be used to define typeOf and instance JmEq2 Term

-- | Homogeneous equality on terms: equal iff 'jmEq1' finds a proof.
instance (ABT Term abt, JmEq2 abt) => Eq1 (Term abt) where
    eq1 x y = maybe False (const True) (jmEq1 x y)

-- | Ordinary equality on terms, delegating to 'eq1'.
instance (ABT Term abt, JmEq2 abt) => Eq (Term abt a) where
    x == y = eq1 x y

-- | John Major equality on 'TrivialABT' by comparing views: syntax
-- nodes via the underlying 'JmEq1' instance, variables by their type
-- singletons, and binders by the bound variable's type and then the
-- body (recursively).
instance ( Show1 (Sing :: k -> *)
         , JmEq1 (Sing :: k -> *)
         , JmEq1 (syn (TrivialABT syn))
         , Foldable21 syn
         ) => JmEq2 (TrivialABT (syn :: ([k] -> k -> *) -> k -> *))
    where
    jmEq2 x y =
        case (viewABT x, viewABT y) of
        (Syn s, Syn s') -> do
            Refl <- jmEq1 s s'
            Just (Refl, Refl)
        (Var (Variable _ _ ty), Var (Variable _ _ ty')) -> do
            Refl <- jmEq1 ty ty'
            Just (Refl, Refl)
        (Bind (Variable _ _ vty) body, Bind (Variable _ _ vty') body') -> do
            Refl         <- jmEq1 vty vty'
            (Refl, Refl) <- jmEq2 (unviewABT body) (unviewABT body')
            Just (Refl, Refl)
        _ -> Nothing

-- | Derived from the 'JmEq2' instance. The binding list @xs@ is fixed
-- by the instance head, so only the type-index proof is kept.
instance ( Show1 (Sing :: k -> *)
         , JmEq1 (Sing :: k -> *)
         , JmEq1 (syn (TrivialABT syn))
         , Foldable21 syn
         ) => JmEq1 (TrivialABT (syn :: ([k] -> k -> *) -> k -> *) xs)
    where
    jmEq1 x y = do
        (Refl, Refl) <- jmEq2 x y
        Just Refl

-- | Two 'TrivialABT's are equal iff 'jmEq2' finds a proof.
instance ( Show1 (Sing :: k ->  *)
         , JmEq1 (Sing :: k -> *)
         , Foldable21 syn
         , JmEq1 (syn (TrivialABT syn))
         ) => Eq2 (TrivialABT (syn :: ([k] -> k -> *) -> k -> *))
    where
    eq2 x y =
        case jmEq2 x y of
        Nothing -> False
        Just _  -> True

-- | Delegates to 'eq2' at the fixed binding list @xs@.
instance ( Show1 (Sing :: k ->  *)
         , JmEq1 (Sing :: k -> *)
         , Foldable21 syn
         , JmEq1 (syn (TrivialABT syn))
         ) => Eq1 (TrivialABT (syn :: ([k] -> k -> *) -> k -> *) xs)
    where
    eq1 x y = eq2 x y

-- | Ordinary equality on 'TrivialABT', delegating to 'eq1'.
instance ( Show1 (Sing :: k ->  *)
         , JmEq1 (Sing :: k -> *)
         , Foldable21 syn
         , JmEq1 (syn (TrivialABT syn))
         ) => Eq (TrivialABT (syn :: ([k] -> k -> *) -> k -> *) xs a)
    where
    x == y = eq1 x y


-- | Renaming environment used by 'alphaEq'. Each entry associates a
-- variable bound in the left-hand term with the variable bound at the
-- same position in the right-hand term. The kind annotation pins
-- 'Variable' to the 'Hakaru' kind.
type Varmap = Assocs (Variable :: Hakaru -> *)

-- | Succeed (with no result) exactly when the two singletons denote
-- the same Hakaru type. The type-equality proof itself is discarded.
void_jmEq1
    :: Sing (a :: Hakaru)
    -> Sing (b :: Hakaru)
    -> ReaderT Varmap Maybe ()
void_jmEq1 x y = lift (fmap (const ()) (jmEq1 x y))

-- | Succeed (with no result) exactly when 'varEq' considers the two
-- variables equal. Any proof it returns is discarded.
void_varEq
    :: Variable (a :: Hakaru)
    -> Variable (b :: Hakaru)
    -> ReaderT Varmap Maybe ()
void_varEq x y = lift $ maybe Nothing (\_ -> Just ()) (varEq x y)

-- | Lift a boolean test into the comparison monad: 'True' continues,
-- 'False' aborts the whole comparison.
try_bool :: Bool -> ReaderT Varmap Maybe ()
try_bool b
    | b         = return ()
    | otherwise = lift Nothing

-- | Decide whether two terms are alpha-equivalent, i.e. equal up to a
-- consistent renaming of bound variables.
--
-- Both terms are walked in lockstep inside a @'ReaderT' 'Varmap'
-- 'Maybe'@ computation:
--
-- * every pair of binders met at the same position is recorded in the
--   'Varmap';
--
-- * a failed comparison is signalled by 'Nothing';
--
-- * the result is 'True' exactly when the whole walk succeeds.
--
-- Variables not bound by an enclosing pair of binders are treated as
-- free and must be equal according to 'varEq'. Collections of
-- subterms ('NaryOp_' arguments, 'Case_' branches, 'Superpose_'
-- components) must also have the same length.
--
-- NOTE(review): known limitations that remain:
--
-- * free-variable lookups consult only the left-to-right map, so a free
--   variable on the left may match a same-identity variable that is
--   bound on the right;
--
-- * the patterns of 'Case_' branches, and the coercions of
--   'CoerceTo_' \/ 'UnsafeFrom_', are not compared.
alphaEq
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> abt '[] a
    -> Bool
alphaEq lhs rhs =
    isJust $ runReaderT (go (viewABT lhs) (viewABT rhs)) emptyAssocs
    where
    -- Don't compare @x@ to @y@ directly; instead,
    -- look up whatever @x@ renames to (i.e., @y'@)
    -- and then see whether that is equal to @y@.
    go  :: forall xs1 xs2 a
        .  View (Term abt) xs1 a
        -> View (Term abt) xs2 a
        -> ReaderT Varmap Maybe ()
    go (Var x) (Var y) = do
        s <- ask
        case lookupAssoc x s of
            Nothing -> void_varEq x  y -- free variables
            Just y' -> void_varEq y' y

    -- remember that @x@ renames to @y@ and recurse
    go (Bind x body1) (Bind y body2) = do
        Refl <- lift $ jmEq1 (varType x) (varType y)
        local (insertAssoc (Assoc x y)) (go body1 body2)

    -- perform the core comparison for syntactic equality
    go (Syn t1) (Syn t2) = termEq t1 t2

    -- if the views don't match, then clearly they are not equal.
    go _ _ = lift Nothing

    -- Compare two syntax nodes of the same type. The zips used below
    -- all truncate, so each collection's length is checked first.
    termEq :: forall a
        .  Term abt a
        -> Term abt a
        -> ReaderT Varmap Maybe ()
    termEq t1 t2 =
        case (t1, t2) of
        (o1 :$ es1, o2 :$ es2)             -> sConEq o1 es1 o2 es2
        (NaryOp_ op1 es1, NaryOp_ op2 es2) -> do
            try_bool (op1 == op2)
            try_bool (S.length es1 == S.length es2)
            F.sequence_ $ S.zipWith go (viewABT <$> es1) (viewABT <$> es2)
        (Literal_ x, Literal_ y)           -> try_bool (x == y)
        (Empty_ x, Empty_ y)               -> void_jmEq1 x y
        (Datum_ d1, Datum_ d2)             -> datumEq d1 d2
        (Array_ n1 f1, Array_ n2 f2)       -> do
            go (viewABT n1) (viewABT n2)
            go (viewABT f1) (viewABT f2)
        (Case_ s1 bs1, Case_ s2 bs2)       -> do
            try_bool (length bs1 == length bs2)
            Refl <- lift $ jmEq1 (typeOf s1) (typeOf s2)
            go (viewABT s1) (viewABT s2)
            zipWithM_ sBranch bs1 bs2
        (Superpose_ pms1, Superpose_ pms2) -> do
            try_bool (L.length pms1 == L.length pms2)
            F.sequence_ $ L.zipWith pairEq pms1 pms2
        (Reject_ x, Reject_ y)             -> void_jmEq1 x y
        (_, _)                             -> lift Nothing

    sArgsEq
        :: forall args
        .  SArgs abt args
        -> SArgs abt args
        -> ReaderT Varmap Maybe ()
    sArgsEq End         End         = return ()
    sArgsEq (e1 :* es1) (e2 :* es2) = do
        go (viewABT e1) (viewABT e2)
        sArgsEq es1 es2
    sArgsEq _ _ = lift Nothing

    -- Compare two applications of constructors with the same result
    -- type. Where the argument lists are not determined by that type,
    -- the relevant types are checked first.
    sConEq
        :: forall a args1 args2
        .  SCon  args1 a
        -> SArgs abt args1
        -> SCon args2 a
        -> SArgs abt args2
        -> ReaderT Varmap Maybe ()
    sConEq Lam_   e1
           Lam_   e2 = sArgsEq e1 e2

    sConEq App_   (e1  :* e2  :* End)
           App_   (e1' :* e2' :* End) = do
        Refl <- lift $ jmEq1 (typeOf e2) (typeOf e2')
        go (viewABT e1) (viewABT e1')
        go (viewABT e2) (viewABT e2')

    sConEq Let_   (e1  :* e2  :* End)
           Let_   (e1' :* e2' :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e1')
        go (viewABT e1) (viewABT e1')
        go (viewABT e2) (viewABT e2')

    -- The operands must have the same type *and* be alpha-equivalent;
    -- previously only the types were checked.
    sConEq (CoerceTo_ _) (e1 :* End)
           (CoerceTo_ _) (e2 :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e2)
        go (viewABT e1) (viewABT e2)

    sConEq (UnsafeFrom_ _) (e1 :* End)
           (UnsafeFrom_ _) (e2 :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e2)
        go (viewABT e1) (viewABT e2)

    sConEq (PrimOp_ o1) es1
           (PrimOp_ o2) es2    = primOpEq o1 es1 o2 es2

    sConEq (ArrayOp_ o1) es1
           (ArrayOp_ o2) es2   = arrayOpEq o1 es1 o2 es2

    sConEq (MeasureOp_ o1) es1
           (MeasureOp_ o2) es2 = measureOpEq o1 es1 o2 es2

    sConEq Dirac e1
           Dirac e2            = sArgsEq e1 e2

    sConEq MBind (e1  :* e2  :* End)
           MBind (e1' :* e2' :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e1')
        go (viewABT e1) (viewABT e1')
        go (viewABT e2) (viewABT e2')

    sConEq Plate     e1 Plate     e2    = sArgsEq e1 e2
    sConEq Chain     e1 Chain     e2    = sArgsEq e1 e2
    sConEq Integrate e1 Integrate e2    = sArgsEq e1 e2

    sConEq (Summate h1 h2) e1 (Summate h1' h2') e2 = do
        Refl <- lift $ jmEq1 (sing_HDiscrete h1) (sing_HDiscrete h1')
        Refl <- lift $ jmEq1 (sing_HSemiring h2) (sing_HSemiring h2')
        sArgsEq e1 e2

    sConEq Expect (e1  :* e2  :* End)
           Expect (e1' :* e2' :* End) = do
        Refl <- lift $ jmEq1 (typeOf e1) (typeOf e1')
        go (viewABT e1) (viewABT e1')
        go (viewABT e2) (viewABT e2')

    sConEq _ _ _ _ = lift Nothing


    primOpEq
        :: forall a typs1 typs2 args1 args2
        .  (typs1 ~ UnLCs args1, args1 ~ LCs typs1,
            typs2 ~ UnLCs args2, args2 ~ LCs typs2)
        => PrimOp typs1 a -> SArgs abt args1
        -> PrimOp typs2 a -> SArgs abt args2
        -> ReaderT Varmap Maybe ()
    primOpEq p1 e1 p2 e2 = do
        (Refl, Refl) <- lift $ jmEq2 p1 p2
        sArgsEq e1 e2

    arrayOpEq
        :: forall a typs1 typs2 args1 args2
        .  (typs1 ~ UnLCs args1, args1 ~ LCs typs1,
            typs2 ~ UnLCs args2, args2 ~ LCs typs2)
        => ArrayOp typs1 a -> SArgs abt args1
        -> ArrayOp typs2 a -> SArgs abt args2
        -> ReaderT Varmap Maybe ()
    arrayOpEq p1 e1 p2 e2 = do
        (Refl, Refl) <- lift $ jmEq2 p1 p2
        sArgsEq e1 e2

    measureOpEq
        :: forall a typs1 typs2 args1 args2
        . (typs1 ~ UnLCs args1, args1 ~ LCs typs1,
            typs2 ~ UnLCs args2, args2 ~ LCs typs2)
        => MeasureOp typs1 a -> SArgs abt args1
        -> MeasureOp typs2 a -> SArgs abt args2
        -> ReaderT Varmap Maybe ()
    measureOpEq m1 e1 m2 e2 = do
        (Refl,Refl) <- lift $ jmEq2 m1 m2
        sArgsEq e1 e2

    -- Datums are compared structurally; the hints are ignored.
    datumEq :: forall a
        .  Datum (abt '[]) a
        -> Datum (abt '[]) a
        -> ReaderT Varmap Maybe ()
    datumEq (Datum _ _ d1) (Datum _ _ d2) = datumCodeEq d1 d2

    datumCodeEq
        :: forall xss a
        .  DatumCode xss (abt '[]) a
        -> DatumCode xss (abt '[]) a
        -> ReaderT Varmap Maybe ()
    datumCodeEq (Inr c) (Inr d) = datumCodeEq c d
    datumCodeEq (Inl c) (Inl d) = datumStructEq c d
    datumCodeEq _       _       = lift Nothing

    datumStructEq
        :: forall xs a
        .  DatumStruct xs (abt '[]) a
        -> DatumStruct xs (abt '[]) a
        -> ReaderT Varmap Maybe ()
    datumStructEq (Et c1 c2) (Et d1 d2) = do
        datumFunEq c1 d1
        datumStructEq c2 d2
    datumStructEq Done       Done       = return ()
    datumStructEq _          _          = lift Nothing

    datumFunEq
        :: forall x a
        .  DatumFun x (abt '[]) a
        -> DatumFun x (abt '[]) a
        -> ReaderT Varmap Maybe ()
    datumFunEq (Konst e) (Konst f) = go (viewABT e) (viewABT f)
    datumFunEq (Ident e) (Ident f) = go (viewABT e) (viewABT f)
    datumFunEq _          _        = lift Nothing

    pairEq
        :: forall a b
        .  (abt '[] a, abt '[] b)
        -> (abt '[] a, abt '[] b)
        -> ReaderT Varmap Maybe ()
    pairEq (x1, y1) (x2, y2) = do
        go (viewABT x1) (viewABT x2)
        go (viewABT y1) (viewABT y2)

    -- Only the branch bodies are compared (pattern-bound variables
    -- appear as binders inside them); the patterns are ignored.
    sBranch
        :: forall a b
        .  Branch a abt b
        -> Branch a abt b
        -> ReaderT Varmap Maybe ()
    sBranch (Branch _ e1) (Branch _ e2) = go (viewABT e1) (viewABT e2)