{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase         #-}
{-# LANGUAGE PatternSynonyms    #-}
{-# LANGUAGE ViewPatterns       #-}

module Language.Haskell.TH.TypeInterpreter.Expression
    ( TypeAtom (..)
    , TypeEquation (..)
    , TypeExp (..)
    , pattern Synonym
    , substitute
    , substituteAll
    , reduce
    , match )
where

import Control.Monad

import           Data.Data
import           Data.List  (intercalate)
import qualified Data.Map   as Map
import           Data.Maybe

import Language.Haskell.TH        (Name)
import Language.Haskell.TH.Syntax (Lift (..))

-- | Type atom: a leaf of a 'TypeExp' without further structure. Atoms are
-- matched by plain equality; pattern variables use 'Variable' instead.
data TypeAtom
    = Integer Integer
      -- ^ Integer literal (rendered with 'show')
    | String String
      -- ^ String literal (rendered quoted, via 'show')
    | Name Name
      -- ^ A name used as a constant
    | PromotedName Name
      -- ^ A promoted name (rendered with a leading @'@)
    deriving (Eq, Data)

-- No method definitions: every 'Lift' method comes from the class defaults.
-- NOTE(review): which default applies depends on the template-haskell
-- version (older versions default to the 'Data'-based 'liftData'); confirm
-- this instance is complete for every version the package supports.
instance Lift TypeAtom

-- | Atoms render like the values they wrap; promoted names additionally get
-- a leading tick (@'@). The precedence is ignored.
instance Show TypeAtom where
    show atom = case atom of
        Integer value     -> show value
        String text       -> show text
        Name name         -> show name
        PromotedName name -> "'" ++ show name

-- | Type equation: the argument patterns and the body the equation yields
-- once every pattern has been matched. 'reduce' consumes the patterns one
-- argument at a time; an equation with no patterns left stands for its body.
--
-- The derived 'Eq' compares patterns and body with the 'Eq' instance of
-- 'TypeExp', so plain pattern variables must carry the same names to be equal.
data TypeEquation = TypeEquation [TypeExp] TypeExp
    deriving (Eq, Data)

-- No method definitions: every 'Lift' method comes from the class defaults.
-- NOTE(review): which default applies depends on the template-haskell
-- version (older versions default to the 'Data'-based 'liftData'); confirm
-- this instance is complete for every version the package supports.
instance Lift TypeEquation

-- | Renders as @[patterns] => body@, parenthesised at precedence 10 and above.
instance Show TypeEquation where
    showsPrec prec (TypeEquation patterns body) =
        showParen (prec >= 10) $
            showsPrec 9 patterns . showString " => " . showsPrec 9 body

-- | Type expression. 'Eq' and 'Show' are hand-written instances; 'Eq'
-- compares two expressions after running 'reduce' on both.
data TypeExp
    = Atom TypeAtom
      -- ^ An atomic leaf
    | Apply TypeExp TypeExp
      -- ^ Application of the first expression to the second
    | Variable Name
      -- ^ A variable; in a pattern it binds, in a body 'substitute' replaces it
    | Function [TypeEquation]
      -- ^ A function given by its equations; 'reduce' resolves applications
    deriving Data

-- No method definitions: every 'Lift' method comes from the class defaults.
-- NOTE(review): which default applies depends on the template-haskell
-- version (older versions default to the 'Data'-based 'liftData'); confirm
-- this instance is complete for every version the package supports.
instance Lift TypeExp

-- | Expressions are compared after 'reduce'. Two synonyms are equal when
-- renaming the bound variable of one to that of the other (checked in both
-- directions) makes the bodies agree; other functions are compared
-- equation by equation.
instance Eq TypeExp where
    left == right =
        same (reduce left) (reduce right)
        where
            -- Structural equality of two already reduced expressions.
            same (Atom a)           (Atom b)             = a == b
            same (Apply f x)        (Apply g y)          = same f g && same x y
            same (Variable a)       (Variable b)         = a == b
            same (Synonym var body) (Synonym var' body') = alphaEqual var body var' body'
            same (Function l)       (Function r)         = l == r
            same _                  _                    = False

            -- Rename the variable @from@ to @to@ in @body@, then reduce again.
            renameAndReduce from to body =
                reduce (substitute from (Variable to) body)

            -- Renaming in only one direction could be fooled when the target
            -- name also occurs free in the renamed body, hence both checks.
            alphaEqual leftVar leftBody rightVar rightBody =
                same (renameAndReduce leftVar rightVar leftBody) rightBody
                    && same (renameAndReduce rightVar leftVar rightBody) leftBody

-- | Atoms and variables render directly (variables ignore the precedence).
-- Applications are parenthesised at precedence 10 and above, and both sides
-- are shown at precedence 10. Functions render as their equations between
-- braces, separated by @"; "@.
instance Show TypeExp where
    showsPrec prec expression = case expression of
        Atom atom ->
            showsPrec prec atom

        Variable name ->
            showString (show name)

        Apply fun param ->
            showParen (prec >= 10) $
                showsPrec 10 fun . showChar ' ' . showsPrec 10 param

        Function equations ->
            showChar '{'
                . showString (intercalate "; " (map show equations))
                . showChar '}'

-- | A function with exactly one equation whose only pattern is a plain
-- variable, i.e. @Synonym x body@ is @{[x] => body}@. Usable both for
-- matching and for construction; 'reduce' applies it by substitution.
pattern Synonym :: Name -> TypeExp -> TypeExp
pattern Synonym name body = Function [TypeEquation [Variable name] body]

-- | @substitute name typ exp@ replaces all occurrences of the variable @name@
-- in @exp@ with @typ@. An equation whose patterns mention @name@ binds it, so
-- that equation's body is left untouched. Patterns are never rewritten.
--
-- NOTE(review): the substitution is not capture-avoiding: free variables of
-- @typ@ can be captured by pattern variables of an enclosing equation.
substitute :: Name -> TypeExp -> TypeExp -> TypeExp
substitute name typ = go
    where
        go expression = case expression of
            Variable varName | varName == name -> typ
            Apply fun param                    -> Apply (go fun) (go param)
            Function equations                 -> Function (map goEquation equations)
            _                                  -> expression

        goEquation equation@(TypeEquation params body)
            | any (usesVariable name) params = equation
            | otherwise                      = TypeEquation params (go body)

-- | Just like 'substitute' but replaces several variables at once.
--
-- The replacement is simultaneous: a mapped name occurring inside another
-- name's replacement is not substituted again, so the result does not depend
-- on the order of the keys. As with 'substitute', an equation whose patterns
-- mention a name shields its body from that name (other names still apply).
substituteAll :: Map.Map Name TypeExp -> TypeExp -> TypeExp
substituteAll mapping expression
    | Map.null mapping = expression
    | otherwise        = case expression of
        Variable name ->
            fromMaybe expression (Map.lookup name mapping)

        Apply fun param ->
            Apply (substituteAll mapping fun) (substituteAll mapping param)

        Function equations ->
            Function (map substEquation equations)

        _ -> expression
    where
        substEquation (TypeEquation params body) =
            TypeEquation params (substituteAll (unboundIn params) body)

        -- Drop the names bound by an equation's patterns.
        unboundIn params =
            Map.filterWithKey (\ name _ -> not (any (usesVariable name) params)) mapping

-- | Check whether the given variable name occurs in a type expression.
-- Inside a 'Function', an equation whose patterns mention the name binds it,
-- so that equation's body is not searched. (A 'Synonym' is just the
-- one-equation case of this rule.)
usesVariable :: Name -> TypeExp -> Bool
usesVariable name = go
    where
        go = \case
            Variable other     -> other == name
            Apply fun param    -> go fun || go param
            Function equations -> any occursIn equations
            _                  -> False

        occursIn (TypeEquation patterns body) =
            not (any go patterns) && go body

-- | Try to reduce the given type expression as much as possible.
--
-- * Applying a 'Synonym' substitutes the (reduced) argument into its body,
--   whatever the argument looks like.
-- * Applying any other 'Function' matches the argument against the first
--   pattern of each equation (see 'match'). Matching equations keep their
--   remaining patterns, with the bindings substituted into their bodies; if
--   one of them has no patterns left, the first such body is the result.
--   If no equation matches, the application is kept as is (it is stuck).
-- * A function @{[x] => f x}@ where @x@ does not occur in @f@ is eta-reduced
--   to @f@, and a function with an equation without patterns reduces to the
--   body of the first such equation. These rules are applied to the function
--   left over after a partial application too, so a partially applied
--   function is reduced as far as a written one.
--
-- NOTE(review): equations whose first pattern fails to match are dropped even
-- when the failure is only due to a 'Variable' in the argument, and bindings
-- from one pattern are not carried into the later patterns of the same
-- equation (non-linear patterns across parameters are not enforced). Confirm
-- this matches the intended type family semantics.
reduce :: TypeExp -> TypeExp
reduce = \case
    Apply (reduce -> fun) (reduce -> param)
        | Synonym var body <- fun   -> reduce (substitute var param body)
        | Function equations <- fun -> applyFamily equations param
        | otherwise                 -> Apply fun param

    Function equations ->
        normalizeFunction (map reduceEquation equations)

    expression -> expression

    where
        reduceEquation (TypeEquation patterns body) =
            TypeEquation (map reduce patterns) (reduce body)

        -- Bodies of the equations that have no patterns left to match.
        onlyEmptyEquations =
            mapMaybe (\ (TypeEquation patterns body) -> body <$ guard (null patterns))

        -- Simplify a function whose equations are already reduced: eta-reduce
        -- @{[x] => f x}@ to @f@, or return the body of the first equation
        -- without patterns. Shared by written functions and by the function
        -- that remains after a partial application.
        normalizeFunction = \case
            [TypeEquation [Variable pat] (Apply fun (Variable var))]
                | pat == var
                , not (usesVariable pat fun) ->
                    fun

            equations
                | body : _ <- onlyEmptyEquations equations -> body
                | otherwise                                -> Function equations

        -- Match the argument against the first pattern of an equation and, on
        -- success, substitute the bindings into the body and reduce it again.
        matchEquation param = \case
            TypeEquation (first : rest) body ->
                TypeEquation rest . reduce . flip substituteAll body <$> match first param

            _ -> Nothing

        applyFamily equations param =
            case mapMaybe (matchEquation param) equations of
                []      -> Apply (Function equations) param -- Function is stuck
                matched -> normalizeFunction matched

-- | @match pat input@ pattern matches @input@ against the given @pat@.
--
-- Both sides are reduced first. On success, the result binds each variable of
-- @pat@ to the corresponding sub-expression of @input@. A variable occurring
-- several times in @pat@ (a non-linear pattern) must be bound to equal
-- expressions at every occurrence, otherwise the match fails.
--
-- Matching is conservative: wherever the input has a 'Variable', the result
-- is 'Nothing', even if the pattern is a variable at that position.
match :: TypeExp -> TypeExp -> Maybe (Map.Map Name TypeExp)
match pat input =
    matchOnly (reduce pat) (reduce input)
    where
        -- An input variable could stand for anything, so it never matches.
        matchOnly _            (Variable _) = Nothing
        matchOnly (Variable n) x            = Just (Map.singleton n x)
        matchOnly (Apply f x)  (Apply g y)  = do
            funBindings   <- matchOnly f g
            paramBindings <- matchOnly x y
            merge funBindings paramBindings
        matchOnly (Atom l)     (Atom r)     = Map.empty <$ guard (l == r)
        matchOnly _            _            = Nothing

        -- Combine the bindings of two sub-matches. A variable bound on both
        -- sides must be bound to equal expressions; a plain left-biased union
        -- would silently accept e.g. @Either a a@ against @Either Int Bool@.
        merge left right =
            Map.union left right <$ guard (and (Map.intersectionWith (==) left right))