{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase         #-}
{-# LANGUAGE TemplateHaskell    #-}
{-# LANGUAGE ViewPatterns       #-}

module Language.Haskell.TH.TypeInterpreter.Expression
    ( TypeAtom (..)
    , TypeEquation (..)
    , TypeExp (..)
    , substitute
    , substituteAll
    , reduce
    , match )
where

import Control.Monad

import           Data.Data
import           Data.List  (intercalate)
import qualified Data.Map   as Map
import           Data.Maybe

import Language.Haskell.TH        (Name)
import Language.Haskell.TH.Syntax (Lift (..))

-- | Type atom: a leaf of a type expression. Each constructor wraps a single
-- literal or name and has no sub-expressions.
data TypeAtom
    = Integer Integer
      -- ^ Integer literal (presumably a type-level numeric literal — confirm)
    | String String
      -- ^ String literal (presumably a type-level symbol literal — confirm)
    | Name Name
      -- ^ A plain name, shown as-is
    | PromotedName Name
      -- ^ A name shown with a leading tick; presumably a promoted data
      -- constructor (DataKinds) — confirm against the code producing atoms
    deriving (Eq, Data)

-- No methods are defined; lifting relies on the defaults of 'Lift'.
instance Lift TypeAtom

-- | Numbers, strings and names are rendered with their own 'show'; promoted
-- names additionally get a leading tick. Only 'show' is defined, so the
-- default 'showsPrec' ignores the surrounding precedence.
instance Show TypeAtom where
    show atom =
        case atom of
            Integer value     -> show value
            String text       -> show text
            Name name         -> show name
            PromotedName name -> "'" ++ show name

-- | Type equation: the argument patterns of the left-hand side together with
-- the body on the right-hand side. When the enclosing 'Family' is applied,
-- patterns are consumed one argument at a time; once none are left, the body
-- is the result of the family application.
--
-- The derived equality compares patterns and body with the 'Eq' instance of
-- 'TypeExp'.
data TypeEquation = TypeEquation [TypeExp] TypeExp
    deriving (Eq, Data)

-- No methods are defined; lifting relies on the defaults of 'Lift'.
instance Lift TypeEquation

-- | Renders an equation as @patterns => body@, wrapped in parentheses when the
-- surrounding precedence is 10 or higher.
instance Show TypeEquation where
    showsPrec prec (TypeEquation patterns body) =
        showParen (prec >= 10) renderEquation
        where
            renderEquation =
                showsPrec 9 patterns . showString " => " . showsPrec 9 body

-- | Type expression
--
-- Equality is not derived: the hand-written 'Eq' instance compares
-- expressions after running them through 'reduce'.
data TypeExp
    = Atom TypeAtom
      -- ^ A literal or name, see 'TypeAtom'
    | Apply TypeExp TypeExp
      -- ^ Application of the first expression to the second
    | Variable Name
      -- ^ Type variable; inside a pattern given to 'match' it matches any input
    | Synonym Name TypeExp
      -- ^ One-parameter type-level function, shown as @λname. body@; applying
      -- it substitutes the argument for the parameter in the body
    | Family [TypeEquation]
      -- ^ Type family given by its equations, see 'TypeEquation'
    deriving Data

-- No methods are defined; lifting relies on the defaults of 'Lift'.
instance Lift TypeExp

-- | Equality up to reduction: both sides are first reduced with 'reduce' and
-- then compared structurally. Two synonyms are equal when renaming either
-- one's parameter to the other's (via 'substitute') makes the bodies equal.
--
-- NOTE(review): the renaming goes through 'substitute', which does not rename
-- inner binders, so alpha-equivalence of nested synonyms whose binder names
-- clash may be reported as inequality.
instance Eq TypeExp where
    lhs == rhs =
        equivalent (reduce lhs) (reduce rhs)
        where
            equivalent (Atom atomA) (Atom atomB) =
                atomA == atomB
            equivalent (Apply funA argA) (Apply funB argB) =
                equivalent funA funB && equivalent argA argB
            equivalent (Variable nameA) (Variable nameB) =
                nameA == nameB
            equivalent (Synonym paramA bodyA) (Synonym paramB bodyB) =
                substitute paramA (Variable paramB) bodyA == bodyB
                    && substitute paramB (Variable paramA) bodyB == bodyA
            equivalent (Family equationsA) (Family equationsB) =
                equationsA == equationsB
            equivalent _ _ =
                False

-- | Renders a type expression for debugging:
--
-- * application is juxtaposition (@f x@), parenthesised at precedence >= 10;
-- * synonyms are rendered as @λname. body@, parenthesised at precedence >= 10;
-- * families are rendered as their equations between braces, separated by @; @.
instance Show TypeExp where
    showsPrec prec (Atom atom) =
        showsPrec prec atom
    showsPrec _ (Variable name) =
        showString (show name)
    showsPrec prec (Apply fun param) =
        showParen (prec >= 10) $
            showsPrec 10 fun . showChar ' ' . showsPrec 10 param
    showsPrec prec (Synonym varName body) =
        showParen (prec >= 10) $
            showChar 'λ' . showsPrec 0 varName . showString ". " . showsPrec 0 body
    showsPrec _ (Family equations) =
        showChar '{' . showString (intercalate "; " (map show equations)) . showChar '}'

-- | @substitute name typ exp@ replaces every free occurrence of the variable
-- @name@ in @exp@ with @typ@.
--
-- Binders shadow: a 'Synonym' whose parameter is @name@, or a 'Family'
-- equation whose patterns mention @name@, introduces a new variable of that
-- name, so nothing beneath it is replaced. The substitution is not
-- capture-avoiding: variables occurring in @typ@ can be captured by binders
-- in @exp@.
substitute :: Name -> TypeExp -> TypeExp -> TypeExp
substitute name typ =
    subst
    where
        subst = \case
            Variable varName
                | varName == name -> typ

            Apply fun param ->
                Apply (subst fun) (subst param)

            Synonym subName body
                -- The synonym re-binds 'name': occurrences in its body refer to
                -- the synonym's own parameter, so the synonym is left intact.
                | subName == name -> Synonym subName body
                | otherwise       -> Synonym subName (subst body)

            Family equations ->
                Family (map substEquation equations)

            expression -> expression

        -- Variables in an equation's patterns are bound by that equation, so
        -- only substitute into the body when the patterns do not mention 'name'.
        substEquation equation@(TypeEquation patterns body)
            | any (usesVariable name) patterns = equation
            | otherwise                        = TypeEquation patterns (subst body)

-- | Just like 'substitute' but for several variables at once.
--
-- The substitution is simultaneous: every variable in the mapping is replaced
-- by its image in a single pass. Variables that occur inside the replacement
-- types are never substituted again. Folding 'substitute' over the map is not
-- equivalent, because there the result would depend on the order of the keys.
-- Binders shadow: a 'Synonym' parameter, or a variable mentioned in a 'Family'
-- equation's patterns, is not replaced beneath its binder.
substituteAll :: Map.Map Name TypeExp -> TypeExp -> TypeExp
substituteAll mapping expression
    | Map.null mapping = expression
    | otherwise        =
        case expression of
            Variable varName
                | Just typ <- Map.lookup varName mapping -> typ

            Apply fun param ->
                Apply (substituteAll mapping fun) (substituteAll mapping param)

            Synonym var body ->
                -- 'var' is re-bound here, so it must not be replaced in 'body'.
                Synonym var (substituteAll (Map.delete var mapping) body)

            Family equations ->
                Family (map substituteEquation equations)

            _ -> expression
    where
        -- Variables mentioned in an equation's patterns are bound by that
        -- equation, so they are dropped from the mapping for its body.
        substituteEquation (TypeEquation patterns body) =
            TypeEquation patterns (substituteAll unbound body)
            where
                unbound =
                    Map.filterWithKey
                        (\ name _ -> not (any (usesVariable name) patterns))
                        mapping

-- | Check whether the given variable occurs free in the given type expression.
--
-- A 'Synonym' whose parameter has the same name shadows the variable. A
-- 'Family' equation counts only if its patterns do not mention the variable,
-- because pattern variables are bound by the equation. Atoms never use any
-- variable.
usesVariable :: Name -> TypeExp -> Bool
usesVariable name = \case
    Variable varName -> name == varName
    Apply fun param  -> usesVariable name fun || usesVariable name param
    Synonym pat body -> name /= pat && usesVariable name body
    Family equations ->
        any (\ (TypeEquation patterns body) ->
                not (any (usesVariable name) patterns) && usesVariable name body)
            equations
    _ -> False

-- | Try to reduce the given type expression as much as possible.
--
-- Reduction rules:
--
-- * Applying a 'Synonym' substitutes the (reduced) argument for its parameter
--   and reduces the result again.
-- * Applying a 'Family' matches the argument against the first pattern of each
--   equation. Equations that match drop that pattern and get the resulting
--   bindings substituted into their (then reduced) body. If no equation
--   matches, the application is kept as is (the family is stuck).
-- * @λa. f a@ is eta-reduced to @f@ when @f@ does not use @a@.
-- * A 'Family' with an equation whose patterns are all consumed reduces to the
--   body of the first such equation.
--
-- NOTE(review): each argument is matched independently, so a pattern variable
-- repeated across different arguments (e.g. @F a a@) is not checked for
-- consistency, and closed-family apartness is not checked.
reduce :: TypeExp -> TypeExp
reduce = \case
    Apply (reduce -> fun) (reduce -> param)
        | Synonym var body <- fun -> reduce (substitute var param body)
        | Family equations <- fun -> applyFamily equations param
        | otherwise               -> Apply fun param

    Synonym pat (reduce -> body)
        -- Eta reduction: λpat. fun pat  ==>  fun, if pat is not used in fun.
        | Apply fun (Variable var) <- body, pat == var, not (usesVariable pat fun) -> fun
        | otherwise -> Synonym pat body

    Family (map reduceEquation -> equations)
        | body : _ <- completeBodies equations -> body
        | otherwise                            -> Family equations

    expression -> expression

    where
        -- Reduce every pattern and the body of an equation.
        reduceEquation (TypeEquation patterns body) =
            TypeEquation (map reduce patterns) (reduce body)

        -- Bodies of the equations that have no patterns left, in order.
        completeBodies equations =
            [body | TypeEquation [] body <- equations]

        -- Consume the first pattern of an equation by matching it against the
        -- argument; equations without patterns or with a failing match vanish.
        matchEquation param = \case
            TypeEquation (firstPattern : restPatterns) body ->
                TypeEquation restPatterns . reduce . flip substituteAll body
                    <$> match firstPattern param

            _ -> Nothing

        -- Apply a family to one argument, keeping only the matching equations.
        applyFamily equations param =
            case mapMaybe (matchEquation param) equations of
                [] -> Apply (Family equations) param -- Family is stuck
                remaining
                    | body : _ <- completeBodies remaining -> body
                    | otherwise                            -> Family remaining

-- | @match pattern input@ pattern matches @input@ against the given @pattern@.
--
-- Both sides are reduced first. Every 'Variable' in @pattern@ is a binder that
-- matches any input. The result maps each such variable to the sub-expression
-- of @input@ it matched. A variable that occurs several times in @pattern@
-- (e.g. @Either a a@) must match equal sub-expressions, according to the 'Eq'
-- instance of 'TypeExp', every time; otherwise the match fails. Returns
-- 'Nothing' when @input@ does not have the shape of @pattern@.
match :: TypeExp -> TypeExp -> Maybe (Map.Map Name TypeExp)
match pattern input =
    matchOnly (reduce pattern) (reduce input)
    where
        matchOnly (Variable n)  x             = Just (Map.singleton n x)
        matchOnly (Apply f x )  (Apply g y)   = do
            funBindings   <- matchOnly f g
            paramBindings <- matchOnly x y
            mergeBindings funBindings paramBindings
        matchOnly (Atom l    )  (Atom r   )   = Map.empty <$ guard (l == r)
        -- NOTE(review): binder names are ignored and any bindings made inside
        -- the bodies are discarded; only whether the bodies match is kept.
        matchOnly (Synonym _ l) (Synonym _ r) = Map.empty <$ matchOnly l r
        matchOnly _             _             = Nothing

        -- Union of two binding maps. Fails when they bind the same variable to
        -- different types, which is what makes non-linear patterns work.
        mergeBindings left right
            | and (Map.intersectionWith (==) left right) = Just (Map.union left right)
            | otherwise                                  = Nothing