{-# LANGUAGE DataKinds      #-}
{-# LANGUAGE GADTs          #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase     #-}
{-# LANGUAGE TypeOperators  #-}
{-# LANGUAGE ViewPatterns   #-}

module Language.Haskell.TH.TypeInterpreter.Expression2
    ( Equation (..)
    , Kind (..)
    , Type (..)
    , typeEquality
    , substituteType
    , reduceType
    , SomeType (..)
    , fromSomeType )
where

import Type.Reflection ((:~~:) (..), TypeRep, Typeable, eqTypeRep, typeRep)

import Control.Monad (guard)

import Data.List  (intercalate)
import Data.Maybe (mapMaybe)

import Language.Haskell.TH (Name)

-- | A single rewrite rule: a pattern of kind @a@ paired with a body of kind @b@.
--
-- The first field is the pattern an input is matched against, the second field is the
-- substitution target that is produced when the match succeeds.
data Equation a b = Equation (Type a) (Type b)

-- | Two equations are equal when their reduced forms are structurally equal.
instance (Typeable a, Typeable b) => Eq (Equation a b) where
    (==) left right = reduceEquation left `equationsEqual` reduceEquation right

-- | Renders an equation as @pattern => body@.
instance Show (Equation a b) where
    show (Equation pat body) = concat [show pat, " => ", show body]

-- | Apply 'reduceType' to both the pattern and the body of an equation.
reduceEquation :: Equation a b -> Equation a b
reduceEquation equation =
    case equation of
        Equation pat body -> Equation (reduceType pat) (reduceType body)

-- | Structural equality of two equations: the patterns must agree and so must the bodies.
-- No reduction is performed here.
equationsEqual :: (Typeable a, Typeable b) => Equation a b -> Equation a b -> Bool
equationsEqual (Equation lhsPat lhsBody) (Equation rhsPat rhsBody)
    | typesEqual lhsPat rhsPat = typesEqual lhsBody rhsBody
    | otherwise                = False

-- | Try one equation against an input. On success the equation's body is returned with the
-- variables bound by the pattern substituted.
matchEquation
    :: (Typeable a, Typeable b)
    => Equation a b -- ^ Pattern and substitution target
    -> Type a       -- ^ Input
    -> Maybe (Type b)
matchEquation equation input =
    case equation of
        Equation pat body -> matchAndSubstituteType pat input body

-- | Try the equations in order and return the body of the first one that matches the input.
-- When none matches, the application of the equations to the input is returned unevaluated.
matchEquations
    :: (Typeable a, Typeable b)
    => [Equation a b] -- ^ Patterns and substitution targets
    -> Type a         -- ^ Input
    -> Type b
matchEquations equations input =
    case mapMaybe (`matchEquation` input) equations of
        firstMatch : _ -> firstMatch
        []             -> App (Fun equations) input

-- | Type of a 'Type'
--
-- Used (promoted via @DataKinds@) as the index of 'Type'.
data Kind
    = KNat
    -- ^ Kind of 'Nat' literals
    | KSym
    -- ^ Kind of 'Sym' literals
    | KTyp
    -- ^ NOTE(review): presumably the kind of ordinary types; no constructor of 'Type' in this
    -- module fixes its index to this kind - confirm against callers.
    | KArr Kind Kind
    -- ^ Kind of functions from the first kind to the second, as built by 'Fun' and consumed by
    -- 'App'

-- | Type expression, indexed by its 'Kind'.
data Type :: Kind -> * where
    -- | Integer literal
    Nat :: Integer -> Type 'KNat
    -- | String literal
    Sym :: String -> Type 'KSym
    -- | Named variable; may stand for an expression of any kind
    Var :: Name -> Type a
    -- | Named constructor; may be of any kind
    Con :: Name -> Type a
    -- | Application of a function-kinded expression to an argument
    App :: Typeable a => Type ('KArr a b) -> Type a -> Type b
    -- | Function given by a list of equations (pattern and body)
    Fun :: (Typeable a, Typeable b) => [Equation a b] -> Type ('KArr a b)

-- | Two 'Type's are equal when their reduced forms are structurally equal.
instance Typeable a => Eq (Type a) where
    (==) left right = reduceType left `typesEqual` reduceType right

-- | Literals, variables and constructors are shown through their payload. Applications are
-- shown as the function and the argument separated by a space, and functions as their
-- equations separated by @"; "@ inside braces.
instance Show (Type a) where
    showsPrec prec typ =
        case typ of
            Nat number    -> showsPrec prec number
            Sym symbol    -> showsPrec prec symbol
            Var name      -> showsPrec prec name
            Con name      -> showsPrec prec name
            App fun arg   ->
                showParen (prec >= 10) (showsPrec 10 fun . showChar ' ' . showsPrec 10 arg)
            Fun equations ->
                showChar '{' . showString (intercalate "; " (map show equations)) . showChar '}'

-- | Produce a proof that the kind indices of two 'Type's coincide, if they do. Only the
-- indices are compared; the values themselves are never inspected.
typeEquality :: (Typeable a, Typeable b) => Type a -> Type b -> Maybe (a :~~: b)
typeEquality left right = eqTypeRep (indexRep left) (indexRep right)
    where
        -- Runtime representation of the index of a 'Type'; the argument only fixes the type.
        indexRep :: Typeable k => Type k -> TypeRep k
        indexRep _ = typeRep

-- | Compare two 'Type's whose kind indices are not known to agree. They are equal only if the
-- indices coincide and the values are structurally equal.
disparateTypesEqual :: (Typeable a, Typeable b) => Type a -> Type b -> Bool
disparateTypesEqual left right
    | Just HRefl <- typeEquality left right = typesEqual left right
    | otherwise                             = False

-- | Structural equality of two 'Type's of the same kind. No reduction is performed; values
-- built from different constructors are never equal.
typesEqual :: Typeable a => Type a -> Type a -> Bool
typesEqual left right =
    case (left, right) of
        (Nat l, Nat r) -> l == r
        (Sym l, Sym r) -> l == r
        (Var l, Var r) -> l == r
        (Con l, Con r) -> l == r

        -- The argument kinds of the two applications are hidden, hence the disparate check.
        (App lf lx, App rf rx) ->
            disparateTypesEqual lf rf && disparateTypesEqual lx rx

        -- Same number of equations, pairwise equal in order.
        (Fun ls, Fun rs) ->
            length ls == length rs && all (uncurry equationsEqual) (zip ls rs)

        _ -> False

-- | @substituteType name value haystack@ walks @haystack@ and replaces every variable called
-- @name@ whose kind agrees with that of @value@ by @value@. Variables with the same name but a
-- different kind are left in place.
--
-- NOTE(review): bodies of nested 'Fun' equations are rewritten even when the equation's own
-- pattern mentions @name@; patterns themselves are never rewritten. Confirm that callers only
-- use distinct names for distinct binders.
substituteType
    :: (Typeable a, Typeable b)
    => Name   -- ^ Variable name
    -> Type a -- ^ Replacement for variable
    -> Type b -- ^ Haystack
    -> Type b
substituteType name val haystack =
    case haystack of
        Var varName
            | name == varName
            , Just HRefl <- typeEquality haystack val -> val
            | otherwise                               -> Var varName

        App fun param ->
            App (substituteType name val fun) (substituteType name val param)

        Fun equations ->
            Fun (map (\ (Equation pat body) -> Equation pat (substituteType name val body)) equations)

        other -> other

-- | Reduce the given 'Type' as much as possible.
--
-- Two rewrites are applied, after first reducing the sub-expressions:
--
-- * An application whose function is a 'Fun' is resolved with 'matchEquations'. If no equation
--   matches, the application is kept with its reduced parts.
--
-- * A 'Fun' whose first equation has the shape @x => subject x@ is eta-reduced to @subject@,
--   provided that @x@ does not occur inside @subject@.
--
-- Everything else is returned unchanged.
reduceType :: Type a -> Type a
reduceType = \case
    App (reduceType -> fun) (reduceType -> param)
        -- NOTE(review): the body selected by 'matchEquations' is returned as-is; redexes that
        -- only appear after the substitution are not reduced again here.
        | Fun equations <- fun -> matchEquations equations param
        | otherwise            -> App fun param

    Fun (map reduceEquation -> equations)
        | Equation var@(Var varName) (App subject use@(Var useName)) : _ <- equations
        , Just HRefl <- typeEquality var use
        , varName == useName
        -- Dropping the binder is only sound when the bound variable is not used anywhere else;
        -- otherwise @subject@ would be left with a dangling variable.
        , not (mentionsVar varName subject) ->
            subject

        | otherwise -> Fun equations

    typ -> typ

-- | Does a variable with the given name occur anywhere inside the given 'Type'?
--
-- The check is deliberately conservative: it ignores kinds and also counts occurrences inside
-- the patterns of nested 'Fun' equations.
mentionsVar :: Name -> Type a -> Bool
mentionsVar name = \case
    Var varName   -> varName == name
    App fun param -> mentionsVar name fun || mentionsVar name param
    Fun equations ->
        any (\ (Equation pat body) -> mentionsVar name pat || mentionsVar name body) equations
    _             -> False

-- | Matches an input against a pattern. In case of successful matching, matched variables will
-- be substituted in the target.
--
-- Matching rules, tried in order:
--
-- * A variable pattern matches any input and is substituted in the target.
--
-- * A variable input never matches a non-variable pattern.
--
-- * An application matches an application when the argument kinds agree and both the functions
--   and the arguments match; substitutions from both are applied to the target.
--
-- * A constructor or a literal ('Nat', 'Sym') matches only itself and leaves the target
--   untouched.
--
-- Anything else (including 'Fun' patterns) does not match.
matchAndSubstituteType
    :: (Typeable a, Typeable b)
    => Type a -- ^ Pattern
    -> Type a -- ^ Input
    -> Type b -- ^ Substitution target
    -> Maybe (Type b)
matchAndSubstituteType (Var n)   x         target = Just (substituteType n x target)
matchAndSubstituteType _         (Var _  ) _      = Nothing
matchAndSubstituteType (App f x) (App g y) target = do
    -- The argument kinds of both applications are hidden; they have to agree before the parts
    -- can be matched against each other.
    HRefl <- typeEquality x y
    target' <- matchAndSubstituteType f g target
    matchAndSubstituteType x y target'
matchAndSubstituteType (Con l)   (Con r  ) target = target <$ guard (l == r)
-- Literal patterns bind nothing, exactly like constructors: they match an identical literal.
matchAndSubstituteType (Nat l)   (Nat r  ) target = target <$ guard (l == r)
matchAndSubstituteType (Sym l)   (Sym r  ) target = target <$ guard (l == r)
matchAndSubstituteType _         _         _      = Nothing

-- | Some 'Type'
--
-- Wraps a 'Type' of any kind index and hides that index. The 'Typeable' evidence stored
-- alongside allows the index to be recovered later, see 'fromSomeType'.
data SomeType where
    SomeType :: Typeable a => Type a -> SomeType

-- | Recover a 'Type' of the requested kind index from a 'SomeType'. Yields 'Nothing' when the
-- wrapped value has a different index.
fromSomeType :: Typeable a => SomeType -> Maybe (Type a)
fromSomeType (SomeType wrapped) = retag typeRep wrapped typeRep
    where
        -- Compare the representation of the stored index with the requested one and, if they
        -- coincide, hand the value back at the requested index.
        retag :: TypeRep from -> Type from -> TypeRep to -> Maybe (Type to)
        retag stored value wanted =
            case eqTypeRep stored wanted of
                Just HRefl -> Just value
                Nothing    -> Nothing