{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Language.Symantic.Compiling.Term where

import Data.Maybe (isJust)
import Data.Semigroup (Semigroup(..))
import qualified Data.Kind as K
import qualified Data.Set as Set
import qualified Data.Text as Text

import Language.Symantic.Grammar
import Language.Symantic.Interpreting
import Language.Symantic.Transforming.Trans
import Language.Symantic.Typing

-- * Type 'Term'
-- | A typed term, made of:
--
-- * the qualification @q@ of its type, as a 'Type' (the part 'QualOf' extracts),
-- * the unqualified 'Type' @t@,
-- * its symantic 'TeSym' at the qualified type @q #> t@.
--
-- The index of a 'Term' is therefore always a qualified type @q #> t@.
data Term src ss ts vs (t::K.Type) where
 Term :: Type src vs       q
      -> Type src vs       t
      -> TeSym ss ts       (q #> t)
      -> Term src ss ts vs (q #> t)
-- Two 'Term's are equal when both their qualification and their
-- unqualified 'Type' are equal; their 'TeSym' is not compared.
instance Source src => Eq (Term src ss ts vs t) where
        Term qa ta _ == Term qb tb _ =
                qa == qb && ta == tb
-- A 'Term' is shown as its full qualified 'Type' (@q #> t@), not its symantic.
instance Source src => Show (Term src ss ts vs t) where
        showsPrec d (Term q ty _) = showsPrec d (q #> ty)

-- Source
type instance SourceOf (Term src ss ts vs t) = src
instance Source src => Sourceable (Term src ss ts vs t) where
        -- The source of a 'Term' is the one of its unqualified 'Type'.
        sourceOf (Term _ ty _) = sourceOf ty
        -- Only the unqualified 'Type' is re-sourced;
        -- the qualification and the 'TeSym' are carried over as is.
        setSource (Term q ty te) s = Term q (setSource ty s) te

-- Const
instance ConstsOf (Term src ss ts vs t) where
        -- Constants of the qualification together with those of the 'Type'.
        constsOf (Term q ty _) = Set.union (constsOf q) (constsOf ty)

-- Var
type instance VarsOf (Term src ss ts vs t) = vs
instance LenVars (Term src ss ts vs t) where
        -- Delegates to the unqualified 'Type' only.
        lenVars (Term _ ty _) = lenVars ty
instance AllocVars (Term src ss ts) where
        -- Both the qualification and the 'Type' are shifted;
        -- the 'TeSym' is kept unchanged.
        allocVarsL n (Term q ty te) = Term (allocVarsL n q) (allocVarsL n ty) te
        allocVarsR n (Term q ty te) = Term (allocVarsR n q) (allocVarsR n ty) te

-- Fam
instance ExpandFam (Term src ss ts vs t) where
        -- Applies 'expandFam' to both the qualification and the 'Type',
        -- keeping the 'TeSym' unchanged.
        expandFam (Term q ty te) = Term (expandFam q) (expandFam ty) te

-- Type
instance SourceInj (TermT src ss ts vs) src => TypeOf (Term src ss ts vs) where
        -- The 'Type' given by 'typeOfTerm', with the 'Term' itself
        -- (wrapped in 'TermT') attached through 'withSource'.
        typeOf te = withSource (typeOfTerm te) (TermT te)

-- | The qualified 'Type' (@q #> t@) of a 'Term'.
typeOfTerm :: Source src => Term src ss ts vs t -> Type src vs t
typeOfTerm (Term q ty _) = q #> ty

-- ** Type 'TermT'
-- | 'Term' with existentialized 'Type'.
data TermT src ss ts vs = forall t. TermT (Term src ss ts vs t)
-- Shown exactly like the wrapped 'Term'.
instance Source src => Show (TermT src ss ts vs) where
        showsPrec d (TermT te) = showsPrec d te

-- ** Type 'TermVT'
-- | 'Term' with existentialized 'Var's and 'Type'.
data TermVT src ss ts = forall vs t. TermVT (Term src ss ts vs t)
-- Equality first passes both 'Term's through 'appendVars'
-- (presumably to put them over a common list of variables),
-- then compares their qualified 'Type's with 'eqTypeKi'.
instance Source src => Eq (TermVT src ss ts) where
        TermVT a == TermVT b =
                case appendVars a b of
                 (Term qa ta _, Term qb tb _) ->
                        isJust (eqTypeKi (qa #> ta) (qb #> tb))
instance Source src => Show (TermVT src ss ts) where
        showsPrec d (TermVT te) = showsPrec d te
type instance SourceOf (TermVT src ss ts) = src
instance Source src => Sourceable (TermVT src ss ts) where
        sourceOf (TermVT te) = sourceOf te
        setSource (TermVT te) s = TermVT (setSource te s)

-- | Lift a 'TermVT' whose symantic needs an empty 'CtxTe'
-- into any context @ts@: the context it is given is ignored
-- and the original symantic is run on 'CtxTeZ'.
liftTermVT :: TermVT src ss '[] -> TermVT src ss ts
liftTermVT (TermVT (Term q ty (TeSym sym))) =
        TermVT (Term q ty (TeSym (\_ctx -> sym CtxTeZ)))

-- ** Type 'TermAVT'
-- | Like 'TermVT', but 'CtxTe'-free:
-- the wrapped 'Term' is polymorphic in its context @ts@.
data TermAVT src ss = forall vs t. TermAVT (forall ts. Term src ss ts vs t)
type instance SourceOf (TermAVT src ss) = src
instance Source src => Sourceable (TermAVT src ss) where
        sourceOf  (TermAVT te)   = sourceOf te
        setSource (TermAVT te) s = TermAVT (setSource te s)
-- Same equality as for 'TermVT': 'appendVars', then 'eqTypeKi'
-- on the qualified 'Type's.
instance Source src => Eq (TermAVT src ss) where
        TermAVT a == TermAVT b =
                case appendVars a b of
                 (Term qa ta _, Term qb tb _) ->
                        isJust (eqTypeKi (qa #> ta) (qb #> tb))
instance Source src => Show (TermAVT src ss) where
        showsPrec d (TermAVT te) = showsPrec d te

-- * Type 'TeSym'
-- | Symantic of a 'Term'.
--
-- For any interpreter @term@ satisfying every symantic of @ss@ ('Syms'),
-- 'Sym_Lambda', and the qualification of @t@ ('QualOf'),
-- it turns an interpreting context ('CtxTe') for @ts@
-- into a @term@ of the unqualified type ('UnQualOf') of @t@.
newtype TeSym ss ts (t::K.Type)
 = TeSym
 ( forall term.
   Syms ss term =>
   Sym_Lambda term =>
   QualOf t =>
   CtxTe term ts -> term (UnQualOf t)
 )

-- | Build a 'TeSym' on @ss@ from a term written against the single symantic @s@.
--
-- Unlike 'TeSym', the given term does not receive any 'CtxTe'
-- (the context is discarded), and only 'Sym'@ s@ is required inside:
-- 'symInj' then widens the result from @'[Proxy s]@ to @ss@.
teSym ::
 forall s ss ts t.
 SymInj ss s =>
 (forall term. Sym s term => Sym_Lambda term => QualOf t => term (UnQualOf t)) ->
 TeSym ss ts t
teSym sym = symInj @s (TeSym (\_ctx -> sym))

-- ** Type family 'QualOf'
-- | Qualification: the constraint @q@ of a qualified type @q #> t@,
-- or the empty constraint @()@ when the type is not qualified.
--
-- This is a closed family, so its equations are tried in order:
-- the catch-all equation must stay last.
type family QualOf (t::K.Type) :: Constraint where
        -- Only the outermost qualification is returned;
        -- the trailing note records a considered alternative that would also
        -- collect nested ones. NOTE(review): confirm whether that is still wanted.
        QualOf (q #> t) = q -- (q # QualOf t)
        QualOf t = (()::Constraint)

-- ** Type family 'UnQualOf'
-- | Unqualification: the type @t@ of a qualified type @q #> t@,
-- or the type itself when it is not qualified.
--
-- Closed family, equations tried in order: the catch-all must stay last.
type family UnQualOf (t::K.Type) :: K.Type where
        -- Only the outermost qualification is removed
        -- (the trailing note shows the recursive alternative).
        UnQualOf (q #> t) = t -- UnQualOf t
        UnQualOf t = t

-- | Return 'K.Constraint' and 'K.Type' part of given 'Type'.
--
-- A 'Type' of the form @q #> t@ yields @(q, t)@.
-- Any other 'Type' yields the empty constraint (built by 'noConstraintLen'
-- with the 'lenVars' of the 'Type') paired with the 'Type' itself.
unQualTy ::
 Source src =>
 Type src vs (t::K.Type) ->
 ( TypeK src vs K.Constraint
 , TypeK src vs K.Type )
-- The head constant is projected onto (#>); the resulting HRefl
-- is what refines the kind of @q@ to 'K.Constraint'.
unQualTy (TyApp _ (TyApp _ c q) t)
 | Just HRefl <- proj_ConstKiTy @(K (#>)) @(#>) c
 = (TypeK q, TypeK t)
unQualTy t = (TypeK $ noConstraintLen (lenVars t), TypeK t)

-- | Remove 'K.Constraint's from given 'Type'.
--
-- Every @q #> t@ encountered is replaced by @t@ (recursively), and
-- both sides of every other type application are processed.
-- Any other constructor is returned unchanged.
unQualsTy :: Source src => Type src vs (t::kt) -> TypeK src vs kt
-- Qualified type: drop @q@ and keep stripping inside @t@.
unQualsTy (TyApp _ (TyApp _ c _q) t)
 | Just HRefl <- proj_ConstKiTy @(K (#>)) @(#>) c
 = unQualsTy t
-- Any other application: strip both the function and the argument.
unQualsTy (TyApp src f a)
 | TypeK f' <- unQualsTy f
 , TypeK a' <- unQualsTy a
 = TypeK $ TyApp src f' a'
unQualsTy t = TypeK t

-- * Type 'CtxTe'
-- | GADT for an /interpreting context/:
-- accumulating at each /lambda abstraction/
-- the @term@ of the introduced variable.
--
-- The index lists the types of the variables in scope,
-- the most recently introduced one first.
data CtxTe (term::K.Type -> K.Type) (hs::[K.Type]) where
        -- Empty context: no variable in scope.
        CtxTeZ :: CtxTe term '[]
        -- Push the @term@ of a new variable in front of the context.
        CtxTeS :: term t
               -> CtxTe term ts
               -> CtxTe term (t ': ts)
-- Right-associative, so that contexts can be written as a `CtxTeS` b `CtxTeS` CtxTeZ.
infixr 5 `CtxTeS`

-- ** Type 'TermDef'
-- | Convenient type alias to define a 'Term'.
--
-- The defined 'Term' must work for any source @src@, any context @ts@,
-- and any list of symantics @ss@ into which @s@ can be injected ('SymInj').
type TermDef s vs t = forall src ss ts. Source src => SymInj ss s => Term src ss ts vs t

-- ** Type family 'Sym'
-- | Constraint an interpreter @term@ must satisfy to interpret the symantic @s@.
-- Open type family: an instance is expected for each symantic.
type family Sym (s::k) :: {-term-}(K.Type -> K.Type) -> Constraint

-- ** Type family 'Syms'
-- | Conjunction of the 'Sym' constraints of all the symantics of @ss@,
-- each symantic @s@ being listed as @Proxy s@.
type family Syms (ss::[K.Type]) (term::K.Type -> K.Type) :: Constraint where
        Syms '[] term = ()
        Syms (Proxy s ': ss) term = (Sym s term, Syms ss term)

-- ** Type 'SymInj'
-- | Convenient type synonym wrapping 'SymInjP'
-- applied on the correct 'Index':
-- the position of @Proxy s@ within @ss@.
type SymInj ss s = SymInjP (Index ss (Proxy s)) ss s

-- | Inject a 'TeSym' written against the single /symantic/ @s@
-- into a 'TeSym' against the whole list @ss@.
--
-- Dispatches to 'symInjP' at the position ('Index') of @Proxy s@ in @ss@.
symInj ::
 forall s ss ts t.
 SymInj ss s =>
 TeSym '[Proxy s] ts t ->
 TeSym ss ts t
symInj te = symInjP @(Index ss (Proxy s)) te

-- *** Class 'SymInjP'
-- | Inject a 'TeSym' on the single symantic @s@ into a 'TeSym' on @ss@.
-- @p@ is the position (built from 'Zero' and 'Succ') of @Proxy s@ in @ss@,
-- as computed by 'Index' in 'SymInj'.
class SymInjP p ss s where
        symInjP :: TeSym '[Proxy s] ts t -> TeSym ss ts t
-- Base case: @s@ is at the head of @ss@, so the symantic is rewrapped as is
-- ('Syms' of @Proxy s ': ss@ provides 'Sym'@ s@).
instance SymInjP Zero (Proxy s ': ss) (s::k) where
        symInjP (TeSym te) = TeSym te
-- Inductive case: inject into the tail @ss@, then widen to @Proxy not_s ': ss@.
-- This is sound because 'Syms' of the longer list includes 'Syms' of its tail.
instance SymInjP p ss s => SymInjP (Succ p) (Proxy not_s ': ss) s where
        symInjP (te::TeSym '[Proxy s] ts t) =
                case symInjP @p te :: TeSym ss ts t of
                 TeSym te' -> TeSym te'

-- * Class 'Sym_Lambda'
-- | Symantic of functions: application, abstraction, @let@ and qualification.
--
-- Every method with a @default@ signature defaults to lifting the method
-- of the underlying interpreter @'UnT' term@ through 'Trans'.
class Sym_Lambda term where
        -- | The function application operator @($)@, as a term of its own
        -- (see the 'Eval' and 'View' instances).
        apply :: term ((a -> b) -> a -> b)
        default apply :: Sym_Lambda (UnT term) => Trans term => term ((a -> b) -> a -> b)
        apply = trans apply

        -- | Apply a function term to an argument term.
        app :: term (a -> b) -> (term a -> term b); infixr 0 `app`
        default app :: Sym_Lambda (UnT term) => Trans term => term (arg -> res) -> term arg -> term res
        app = trans2 app

        -- | /Lambda abstraction/, in higher-order abstract syntax:
        -- the body is a Haskell function from the variable's term.
        lam :: (term a -> term b) -> term (a -> b)
        default lam :: Sym_Lambda (UnT term) => Trans term => (term arg -> term res) -> term (arg -> res)
        -- The variable of the underlying abstraction is lifted with 'trans'
        -- before being given to @f@, and the body brought back with 'unTrans'.
        lam f = trans $ lam (unTrans . f . trans)

        -- | Convenient 'lam' and 'app' wrapper.
        let_ :: term var -> (term var -> term res) -> term res
        let_ x f = lam f `app` x

        -- | /Lambda abstraction/ beta-reducable without duplication
        -- (i.e. whose variable is used once at most),
        -- mainly useful in compiled 'Term's
        -- whose symantics are not a single 'term'
        -- but a function between 'term's,
        -- which happens because those are more usable when used as an embedded DSL.
        lam1 :: (term a -> term b) -> term (a -> b)
        -- NOTE(review): this default signature requires 'Trans' although the
        -- default body (@lam1 = lam@) does not use it; hence interpreters that
        -- are not 'Trans' must define 'lam1' explicitly
        -- (as the instances of this module do).
        default lam1 :: Sym_Lambda (UnT term) => Trans term => (term a -> term b) -> term (a -> b)
        lam1 = lam

        -- | /Qualification/.
        --
        -- Workaround used in 'readTermWithCtx'.
        qual :: proxy q -> term t -> term (q #> t)
        default qual :: Sym_Lambda (UnT term) => Trans term => proxy q -> term t -> term (q #> t)
        qual q = trans1 (qual q)

-- | Curried two-argument abstraction: each argument is bound by its own 'lam1'.
lam2 :: Sym_Lambda term => (term a -> term b -> term c) -> term (a -> b -> c)
lam2 f = lam1 (\a -> lam1 (f a))

-- | Curried three-argument abstraction, built from 'lam1' and 'lam2'.
lam3 :: Sym_Lambda term => (term a -> term b -> term c -> term d) -> term (a -> b -> c -> d)
lam3 f = lam1 (\a -> lam2 (f a))

-- | Curried four-argument abstraction, built from 'lam1' and 'lam3'.
lam4 :: Sym_Lambda term => (term a -> term b -> term c -> term d -> term e) -> term (a -> b -> c -> d -> e)
lam4 f = lam1 (\a -> lam3 (f a))

-- Interpreting
instance Sym_Lambda Eval where
        -- Direct evaluation: abstractions become Haskell functions and
        -- application goes through the Applicative '(<*>)' of 'Eval'.
        apply           = Eval ($)
        app fun arg     = fun <*> arg
        lam f           = Eval $ \a -> unEval (f (Eval a))
        lam1            = lam
        qual _ (Eval t) = Eval (Qual t)
        -- No abstraction needed: the body is directly applied to the bound term.
        let_ x body     = body x
instance Sym_Lambda View where
        -- 'apply' is rendered as the section "($)".
        apply = View $ \_po _v -> "($)"
        -- Application is rendered as juxtaposition (function, space, argument)
        -- at precedence 'infixN' 10, the function on 'SideL' and the argument
        -- on 'SideR'; 'pairIfNeeded' decides whether to parenthesize it
        -- with respect to the enclosing operator @po@.
        app (View a1) (View a2) = View $ \po v ->
                pairIfNeeded pairParen po op $
                a1 (op, SideL) v <> " " <> a2 (op, SideR) v
                where op = infixN 10
        -- Abstraction is rendered as \xN -> body at precedence 'infixN' 1.
        -- The variable name is built from the counter @v@, and the body is
        -- rendered with @succ v@, so that nested binders get distinct names.
        lam f = View $ \po v ->
                let x = "x" <> Text.pack (show v) in
                pairIfNeeded pairParen po op $
                "\\" <> x <> " -> " <>
                unView (f (View $ \_po _v -> x)) (op, SideL) (succ v)
                where op = infixN 1
        lam1 = lam
        qual _q (View t) = View t -- TODO: maybe print q
        -- Rendered as: let xN = e in body, at precedence 'infixN' 1.
        -- The bound expression is rendered at precedence 'infixN' 0,
        -- and both it and the body with the counter @succ v@.
        let_ x f =
                View $ \po v ->
                        let x' = "x" <> Text.pack (show v) in
                        pairIfNeeded pairParen po op $
                        "let" <> " " <> x' <> " = "
                         <> unView x (infixN 0, SideL) (succ v) <> " in "
                         <> unView (f (View $ \_po _v -> x')) (op, SideL) (succ v)
                where op = infixN 1
instance (Sym_Lambda r1, Sym_Lambda r2) => Sym_Lambda (Dup r1 r2) where
        apply = dup0 @Sym_Lambda apply
        app   = dup2 @Sym_Lambda app
        -- The body is interpreted once per side. The r1 abstraction gives @f@
        -- a variable whose r2 half is a placeholder that must never be
        -- demanded, and symmetrically for the r2 abstraction.
        --
        -- This relies on the 'Dup' operations used inside @f@ building their
        -- results componentwise (as those defined with 'dup0', 'dup1' and
        -- 'dup2'), so that projecting one side of the body never forces
        -- the other side of the variable.
        -- NOTE(review): this also assumes the fields of 'Dup' are lazy — confirm.
        --
        -- The previous definition (@lam_f = lam f@ in a @where@) called this
        -- very instance's 'lam' again and therefore never terminated.
        lam f = Dup
                (lam (\x1 -> dup_1 (f (Dup x1 (error "Sym_Lambda (Dup r1 r2): lam: r2 half of a variable demanded while interpreting r1")))))
                (lam (\x2 -> dup_2 (f (Dup (error "Sym_Lambda (Dup r1 r2): lam: r1 half of a variable demanded while interpreting r2") x2))))
        lam1 = lam
        qual q = dup1 @Sym_Lambda (qual q)