{-# LANGUAGE GADTs #-}
module Language.Symantic.Compiling.Read where

import Control.Arrow (left)
import qualified Data.Kind as K

import Language.Symantic.Grammar
import Language.Symantic.Typing

import Language.Symantic.Compiling.Term
import Language.Symantic.Compiling.Module
import Language.Symantic.Compiling.Beta

-- * Type 'ReadTerm'
-- | Convenient type alias for 'readTerm' and related functions.
--
-- A 'ReadTerm' reads an 'AST_Term' within a typing context 'CtxTy'
-- (whose type-level list @ts@ lists the 'Type's of the variables in scope),
-- producing either an 'Error_Term' or a 'TermVT' in that same context.
-- The 'SourceInj' constraints allow wrapping 'TypeVT', 'KindK'
-- and 'AST_Type' values into the source type @src@.
type ReadTerm src ss ts =
 Source src =>
 SourceInj (TypeVT src) src =>
 SourceInj (KindK src) src =>
 SourceInj (AST_Type src) src =>
 Constable (->) =>
 CtxTy src ts ->
 AST_Term src ss ->
 Either (Error_Term src) (TermVT src ss ts)

-- | Read a 'TermVT' from an 'AST_Term'.
--
-- Each 'Token_Term' of @ast@ is first turned into a 'TermVT' by @go@,
-- in the typing context @ctxTy@; the resulting tree of 'TermVT's
-- is then reduced by 'betaTerms', whose errors are injected
-- into 'Error_Term' by 'errorInj'.
--
-- Lambda abstractions and let-bindings read their body with a recursive
-- 'readTerm', after pushing the bound variable onto the 'CtxTy'.
readTerm :: forall src ss ts. ReadTerm src ss ts
readTerm ctxTy ast = do
        ts <- go ctxTy `traverse` ast
        errorInj `left` betaTerms ts
        where
        -- Read a single 'Token_Term' within the given typing context.
        go ::
         forall ts'.
         CtxTy src ts' ->
         Token_Term src ss ->
         Either (Error_Term src) (TermVT src ss ts')
        -- Tokens carrying an already-built term: used directly,
        -- or through 'liftTermVT'.
        go _ts (Token_Term (TermAVT te))  = Right $ TermVT te
        go _ts (Token_TermVT te)          = Right $ liftTermVT te
        -- Variables are looked up by name in the typing context.
        go ts  (Token_Term_Var _src name) = teVar name ts
        -- Explicit application token: the 'teApp' combinator.
        go _ts (Token_Term_App _src)      = Right $ TermVT teApp
        -- Lambda abstraction of @name_arg@, of type @ast_ty_arg@, over @ast_body@.
        go ts  (Token_Term_Abst _src name_arg ast_ty_arg ast_body) = do
                -- Read the argument's type; an 'Error_Type' is injected into 'Error_Term'.
                TypeVT ty_arg <- errorInj `left` readType ast_ty_arg
                -- The argument's type must be of kind 'K.Type'
                -- (NOTE(review): presumably 'when_EqKind' fails otherwise — confirm).
                when_EqKind (kindInj @K.Type) (kindOf ty_arg) $ \Refl ->
                        case lenVars ty_arg of
                         -- Argument types having type variables are rejected.
                         LenS{} -> Left $ Error_Term_polymorphic $ TypeVT ty_arg
                         -- Split the argument type into @qa@ (used as a constraint below)
                         -- and @ta@; only @ta@ is pushed onto the typing context.
                         LenZ | (TypeK qa, TypeK ta) <- unQualTy ty_arg -> do
                                -- Read the body with @name_arg : ta@ at the head of the context.
                                TermVT (Term qr tr (TeSym res)) <- readTerm (CtxTyS name_arg ta ts) ast_body
                                -- Presumably aligns the type variables of the argument
                                -- and of the body so that they can be combined.
                                let (qa', qr') = appendVars qa qr
                                let (ta', tr') = appendVars ta tr
                                Right $
                                        case (proveConstraint qa', proveConstraint qr') of
                                         -- NOTE: remove provable Constraints to keep those smaller.
                                         -- Every branch builds the same function 'lam' whose
                                         -- argument is pushed onto the value context ('CtxTeS');
                                         -- they only differ by the constraint kept on the result.
                                         -- Both provable: no constraint is kept.
                                         (Just Dict, Just Dict) -> TermVT $
                                                Term (noConstraintLen (lenVars tr')) (ta' ~> tr') $
                                                TeSym $ \c -> lam $ \arg -> res (arg `CtxTeS` c)
                                         -- Only the body's constraint is kept.
                                         (Just Dict, Nothing) -> TermVT $
                                                Term qr' (ta' ~> tr') $
                                                TeSym $ \c -> lam $ \arg -> res (arg `CtxTeS` c)
                                         -- Only the argument's constraint is kept.
                                         (Nothing, Just Dict) -> TermVT $
                                                Term qa' (ta' ~> tr') $
                                                TeSym $ \c -> lam $ \arg -> res (arg `CtxTeS` c)
                                         -- Both constraints are kept, combined with '#'.
                                         (Nothing, Nothing) -> TermVT $
                                                Term (qa' # qr') (ta' ~> tr') $
                                                TeSym $ \c -> lam $ \arg -> res (arg `CtxTeS` c)
        -- Let-binding of @name@ to @ast_arg@ within @ast_body@.
        go ts (Token_Term_Let _src name ast_arg ast_body) = do
                -- Read the bound term in the current context.
                TermVT (Term qa ta (TeSym arg)) <- readTerm ts ast_arg
                case lenVars ta of
                 -- Bound terms whose type has type variables are rejected.
                 LenS{} -> Left $ Error_Term_polymorphic $ TypeVT (qa #> ta)
                 LenZ -> do
                        -- Read the body with @name : ta@ at the head of the context;
                        -- the constraint @qa@ is carried by the result term instead.
                        TermVT (Term qr tr (TeSym res)) <- readTerm (CtxTyS name ta ts) ast_body
                        let (qa', qr') = appendVars qa qr
                        -- NOTE(review): @lenVars ta@ is 'LenZ' in this branch, so this
                        -- presumably leaves the variables of @tr@ unchanged — confirm.
                        let tr' = allocVarsL (lenVars ta) tr
                        Right $
                                case (proveConstraint qa', proveConstraint qr') of
                                 -- NOTE: remove provable Constraints to keep those smaller.
                                 -- Every branch builds the same 'let_' whose bound value
                                 -- is pushed onto the value context ('CtxTeS');
                                 -- they only differ by the constraint kept on the result.
                                 -- Both provable: no constraint is kept.
                                 (Just Dict, Just Dict) -> TermVT $
                                        Term (noConstraintLen (lenVars tr')) tr' $
                                        TeSym $ \c -> let_ (arg c) $ \a -> res (a `CtxTeS` c)
                                 -- Only the body's constraint is kept.
                                 (Just Dict, Nothing) -> TermVT $
                                        Term qr' tr' $
                                        TeSym $ \c -> let_ (arg c) $ \a -> res (a `CtxTeS` c)
                                 -- Only the bound term's constraint is kept.
                                 (Nothing, Just Dict) -> TermVT $
                                        Term qa' tr' $
                                        TeSym $ \c -> let_ (arg c) $ \a -> res (a `CtxTeS` c)
                                 -- Both constraints are kept, combined with '#'.
                                 (Nothing, Nothing) -> TermVT $
                                        Term (qa' # qr') tr' $
                                        TeSym $ \c -> let_ (arg c) $ \a -> res (a `CtxTeS` c)

-- | Look up a variable by name in a typing context.
--
-- The context is searched from its head, i.e. from the innermost binding
-- outward, so the first binding whose name matches wins: inner bindings
-- shadow outer ones. The resulting 'Term' has the bound variable's 'Type'
-- and projects the corresponding value out of the 'CtxTe' at that depth.
--
-- Fails with 'Error_Term_unknown' when no binding matches.
teVar ::
 forall ss src ts.
 Source src =>
 NameTe -> CtxTy src ts -> Either (Error_Term src) (TermVT src ss ts)
teVar name ctx =
        case ctx of
         CtxTyZ -> Left (Error_Term_unknown name)
         CtxTyS n ty rest
          | n == name ->
                -- Found: project the head of the value context.
                Right $ TermVT $ Term noConstraint ty $
                        TeSym $ \(here `CtxTeS` _) -> here
          | otherwise -> do
                -- Not here: search deeper, then skip the head
                -- of the value context when running the projection found.
                TermVT (Term q t (TeSym te)) <- teVar @ss name rest
                Right $ TermVT $ Term q t $
                        TeSym $ \(_ `CtxTeS` outer) -> te outer

-- | Application as a 'Term': the function of type @(a -> b) -> a -> b@,
-- over two type variables @a@ and @b@, without constraint,
-- whose symantic is 'apply'.
teApp ::
 Source src => Constable (->) =>
 Term src ss ts '[Proxy (a::K.Type), Proxy (b::K.Type)] (() #> ((a -> b) -> a -> b))
teApp = Term noConstraint ((tyA ~> tyB) ~> tyA ~> tyB) (TeSym (const apply))
        where
        -- The first type variable, @a@ (index 'varZ').
        tyA :: Source src => Type src '[Proxy a, Proxy b] (a::K.Type)
        tyA = tyVar "a" varZ
        -- The second type variable, @b@ (index @VarS varZ@).
        tyB :: Source src => Type src '[Proxy a, Proxy b] (b::K.Type)
        tyB = tyVar "b" (VarS varZ)

-- | Reduce number of 'Token_Term_App' in given 'AST_Term' by converting them into 'BinTree2'.
--
-- NOTE: 'Token_Term_App' exists only to handle unifix operators applied to arguments.
--
-- When the (already reduced) left child of an application is itself
-- a 'Token_Term_App' leaf applied to some @f@, that 'Token_Term_App'
-- is dropped and @f@ (reduced once more) becomes the applied function.
-- The body of a 'Token_Term_Abst', and both the bound term and the body
-- of a 'Token_Term_Let', are reduced recursively; other leaves are kept as is.
reduceTeApp :: AST_Term src ss -> AST_Term src ss
reduceTeApp (BinTree2 x y) =
        -- The reduced left child is computed once and reused by both
        -- alternatives. Recomputing @reduceTeApp x@ in the fallback branch
        -- re-walked the left spine at every level. That cost compounds with
        -- the spine's depth whenever GHC does not happen to CSE it away.
        case reduceTeApp x of
         BinTree0 Token_Term_App{} `BinTree2` te_fun -> reduceTeApp te_fun `BinTree2` reduceTeApp y
         x_red -> x_red `BinTree2` reduceTeApp y
reduceTeApp (BinTree0 (Token_Term_Abst src n ty te)) = BinTree0 $ Token_Term_Abst src n ty (reduceTeApp te)
reduceTeApp (BinTree0 (Token_Term_Let  src n bo te)) = BinTree0 $ Token_Term_Let  src n (reduceTeApp bo) (reduceTeApp te)
reduceTeApp x@BinTree0{} = x

-- ** Type 'ReadTermCF'
-- | Like 'ReadTerm', but 'CtxTe'-free.
--
-- Useful in 'readTermWithCtxPush' (itself used by 'readTermWithCtx')
-- to help GHC's type solver, which
-- "Cannot instantiate unification variable with a type involving foralls":
-- wrapping the reader polymorphic in @ts@ into a newtype
-- lets it be the accumulator of a 'foldr'.
newtype ReadTermCF src ss
 =      ReadTermCF
 {    unReadTermCF :: forall ts. ReadTerm src ss ts }

-- | Like 'readTerm', but within a context made only of the given
-- named closed terms: @env@ is pushed onto an empty context
-- (see 'readTermWithCtxPush' and 'readTermWithCtxClose').
readTermWithCtx ::
 Foldable f =>
 Source src =>
 SourceInj (TypeVT src) src =>
 SourceInj (KindK src) src =>
 SourceInj (AST_Type src) src =>
 Constable (->) =>
 f (NameTe, TermT src ss '[] '[]) ->
 AST_Term src ss ->
 Either (Error_Term src) (TermVT src ss '[])
readTermWithCtx env ast =
        readTermWithCtxClose (readTermWithCtxPush env readTerm) ast

-- | Like 'readTerm' but with given context.
--
-- Every @(name, term)@ of @env@ is pushed around @readTe@ by
-- 'readTermWithCtxPush1'. Because this is a right fold, the last entry
-- of @env@ ends up at the head of the 'CtxTy' seen by @readTe@, hence
-- a later entry shadows an earlier one having the same name.
readTermWithCtxPush ::
 Foldable f =>
 f (NameTe, TermT src ss '[] '[]) ->
 (forall ts'. ReadTerm src ss ts') ->
 ReadTerm src ss ts
readTermWithCtxPush env readTe =
        unReadTermCF (foldr pushOne (ReadTermCF readTe) env)
        where
        -- 'ReadTermCF' keeps the accumulator polymorphic in the context,
        -- which a bare rank-N accumulator in 'foldr' would not allow.
        pushOne ::
         (NameTe, TermT src' ss' '[] '[]) ->
         ReadTermCF src' ss' ->
         ReadTermCF src' ss'
        pushOne binding (ReadTermCF inner) =
                ReadTermCF (readTermWithCtxPush1 binding inner)

-- | Like 'readTerm' but with given 'TermT' pushed onto 'CtxTy' and 'CtxTe'.
--
-- The variable @n@ is given the qualified type @qn #> tn@ in the 'CtxTy'
-- passed to @readTe@; in the 'CtxTe', its value is the closed term @te_n@
-- (run in the empty context 'CtxTeZ') wrapped by 'qual' @qn@.
--
-- Fails with 'Error_Term_proofless' when 'proveConstraint' cannot prove @qn@.
-- NOTE(review): that check is done only after @ast@ has been read,
-- hence an error from @readTe@ takes precedence over it.
readTermWithCtxPush1 ::
 (NameTe, TermT src ss '[] '[]) ->
 (forall ts'. ReadTerm src ss ts') ->
 ReadTerm src ss ts
readTermWithCtxPush1 (n, TermT (Term qn tn (TeSym te_n))) readTe ctxTy ast = do
        TermVT (Term q t (TeSym te)) <- readTe (CtxTyS n (qn #> tn) ctxTy) ast
        case proveConstraint qn of
         Nothing -> Left $ Error_Term_proofless $ TypeVT qn
         -- The 'Dict' brings the proof of @qn@ into scope, as needed by 'qual'.
         Just Dict ->
                Right $ TermVT $ Term q t $ TeSym $ \c ->
                        let cte = qual qn $ te_n CtxTeZ in
                        te $ cte `CtxTeS` c

-- | Close a 'ReadTerm' context: run the reader within
-- the empty typing context 'CtxTyZ'.
readTermWithCtxClose ::
 (forall ts'. ReadTerm src ss ts') ->
 Source src =>
 SourceInj (TypeVT src) src =>
 SourceInj (KindK src) src =>
 SourceInj (AST_Type src) src =>
 Constable (->) =>
 AST_Term src ss ->
 Either (Error_Term src) (TermVT src ss '[])
readTermWithCtxClose readTe ast = readTe CtxTyZ ast

-- * Type 'CtxTy'
-- | /Typing context/
-- accumulating at each /lambda abstraction/ (and /let/ binding)
-- the 'Type' of the introduced variable.
-- It is built top-down: its head is the closest
-- enclosing binder, its last element the farthest.
-- It determines the 'Type's of 'CtxTe'.
data CtxTy src (ts::[K.Type]) where
        -- | The empty context.
        CtxTyZ :: CtxTy src '[]
        -- | Push a variable, with its name and its 'Type'
        -- (whose variable list is empty), onto a context.
        CtxTyS :: NameTe
               -> Type  src '[] t
               -> CtxTy src ts
               -> CtxTy src (t ': ts)
infixr 5 `CtxTyS`

-- | Concatenate two typing contexts: the bindings of the first one
-- are kept in order, in front of those of the second one.
appendCtxTy ::
 CtxTy src ts0 ->
 CtxTy src ts1 ->
 CtxTy src (ts0 ++ ts1)
appendCtxTy front back =
        case front of
         CtxTyZ -> back
         CtxTyS n t rest -> CtxTyS n t (appendCtxTy rest back)

-- * Type 'Error_Term'
-- | Errors raised while reading a term.
data Error_Term src
 =   Error_Term_unknown NameTe
     -- ^ A variable name not bound in the typing context (see 'teVar').
 |   Error_Term_polymorphic (TypeVT src)
     -- ^ A lambda argument type, or a let-bound term's type,
     -- which has type variables.
 |   Error_Term_qualified   (TypeVT src)
     -- ^ NOTE(review): not raised in this module; meaning to confirm.
 |   Error_Term_proofless   (TypeVT src)
     -- ^ A constraint which 'proveConstraint' could not prove
     -- (see 'readTermWithCtxPush1').
 |   Error_Term_Type (Error_Type src)
     -- ^ An error from reading a type (e.g. by 'readType').
 |   Error_Term_Beta (Error_Beta src)
     -- ^ An error from beta-reduction (by 'betaTerms').
 {-   Error_Term_Con_Type (Con_Type src ss) -}
 {-   Error_Term_Con_Kind (Con_Kind src) -}
 deriving (Eq, Show)
instance ErrorInj (Error_Type src) (Error_Term src) where
        errorInj = Error_Term_Type
instance ErrorInj (Error_Beta src) (Error_Term src) where
        errorInj = Error_Term_Beta
-- | A 'Con_Kind' is first injected into an 'Error_Type'.
instance ErrorInj (Con_Kind src) (Error_Term src) where
        errorInj = Error_Term_Type . errorInj

-- * Type 'SrcTe'
-- | A 'Source' usable with 'readTerm'.
--
-- Each constructor records one kind of origin; the 'Source' and
-- 'SourceInj' instances for 'SrcTe' wrap the corresponding value into it.
data SrcTe inp ss
 =   SrcTe_Less
     -- ^ No source at all (the 'noSource' of 'SrcTe').
 |   SrcTe_Input    (Span inp)
     -- ^ A 'Span' of the input.
 |   SrcTe_AST_Term (AST_Term (SrcTe inp ss) ss)
     -- ^ A term AST.
 |   SrcTe_AST_Type (AST_Type (SrcTe inp ss))
     -- ^ A type AST.
 |   SrcTe_Kind     (KindK    (SrcTe inp ss))
     -- ^ A kind.
 |   SrcTe_Type     (TypeVT   (SrcTe inp ss))
     -- ^ A type.
 |   SrcTe_Term
     -- ^ NOTE(review): carries no payload and is not injected by any
     -- 'SourceInj' instance here; meaning to confirm.
 deriving (Eq, Show)

type instance Source_Input (SrcTe inp ss) = inp

-- | Absence of source is 'SrcTe_Less'.
instance Source (SrcTe inp ss) where
        noSource = SrcTe_Less
-- Each injection below wraps its payload into the matching constructor.
instance SourceInj (Span inp) (SrcTe inp ss) where
        sourceInj sp = SrcTe_Input sp
instance SourceInj (AST_Term (SrcTe inp ss) ss) (SrcTe inp ss) where
        sourceInj te = SrcTe_AST_Term te
instance SourceInj (AST_Type (SrcTe inp ss)) (SrcTe inp ss) where
        sourceInj ty = SrcTe_AST_Type ty
instance SourceInj (KindK (SrcTe inp ss)) (SrcTe inp ss) where
        sourceInj ki = SrcTe_Kind ki
instance SourceInj (TypeVT (SrcTe inp ss)) (SrcTe inp ss) where
        sourceInj vt = SrcTe_Type vt