{-# LANGUAGE RecordWildCards, DisambiguateRecordFields, GADTs, KindSignatures, GeneralizedNewtypeDeriving, TypeSynonymInstances #-}

module DerivationTrees.CPTS where

import Prelude hiding (abs, pi)
import Data.List
import DerivationTrees
import DerivationTrees.Basics
import Data.Monoid
import qualified DerivationTrees.ManualPTS as M
import Debug.Trace

-- Types whose values can be compared up to unification: '===' is a
-- Boolean test that instances may make looser than plain equality.
class Unifyable a where
    (===) :: a -> a -> Bool

-- Types that can be rendered as a 'TeX' fragment.
class TeXable a where
    texify :: a -> TeX

-- Functions have no printable representation; show a fixed placeholder so
-- that data types containing functions can still derive 'Show'.
instance Show (a -> b) where
    show = const "<fct>"


-- The variable part of a binding: either a named variable or 'Unbound',
-- which stands for "no variable" (rendered as "?" by 'interpV').
data V = V Name | Unbound
 deriving (Eq, Show)

-- Identifier of a variable, constant or base binding.
type Name = String
-- Name of a sort; a leading '~' selects the 'M.tso' rendering in 'texify'.
type Sort = String
-- Colour tag carried by 'Lam', 'Pi' and 'App'; "" means colourless.
type Colour = String


-- A binding, as used in binders ('Lam', 'Pi'), substitutions ('Sub') and
-- environments ('Env').
data Binding :: * where
   -- A variable (possibly 'Unbound') together with a term; 'bToT' extracts
   -- that term.
   (:-) :: V -> Term -> Binding
   -- A binding made of a bare name only.
   Base :: Name -> Binding
   -- Wrapper around another binding; rendered through 'M.many'.
   Mult :: Binding -> Binding
 deriving (Eq, Show)

-- Terms of the calculus.  'Lam', 'Pi' and 'App' carry a 'Colour' that
-- selects the coloured variant of the TeX macro (see 'col').
data Term :: * where
  -- Abstraction: colour, binder, body.
  Lam :: Colour -> Binding -> Term -> Term
  -- Product: colour, binder, body.  Rendered as an arrow when the binder is
  -- 'unbound', and as a quantifier otherwise.
  Pi  :: Colour -> Binding -> Term -> Term
  -- Application: colour, function, argument.
  App :: Colour -> Term -> Term -> Term
  -- A term followed by a substitution, rendered as the term with the
  -- binding in brackets.
  Sub :: Term -> Binding -> Term
  -- Variable.
  Var :: Name -> Term
  -- Constant.
  Con :: Name -> Term
  -- Sort.
  Sor :: Sort -> Term
  -- Wrapper around another term; rendered through 'M.many'.
  Many :: Term -> Term
 deriving (Eq, Show)

-- Colourless shorthands: each one builds its term with the empty colour "".
lam :: Binding -> Term -> Term
lam binder body = Lam "" binder body

(~~>) :: Binding -> Term -> Term
binder ~~> body = Pi "" binder body

($$) :: Term -> Term -> Term
f $$ a = App "" f a

-- Substitution, in prefix ('subs') and infix ('|=>') spelling.
subs :: Term -> Binding -> Term
subs t b = Sub t b

var :: Name -> Term
var name = Var name

(|=>) :: Term -> Binding -> Term
t |=> b = Sub t b

infixl |=>


-- TODO: BIG FAT HACK!
-- Undo a substitution: every occurrence of the given term is replaced by
-- the given variable, and an explicit 'Sub' node is replaced by the term it
-- wraps.  Only 'App', the body of 'Pi', and 'Many' are traversed.
dS :: (Name,Term) -> Term -> Term
dS sub@(name, target) term
  | target == term = Var name
  | otherwise = case term of
      Sub inner _ -> inner
      App k f a   -> App k (dS sub f) (dS sub a)
      Pi k b body -> Pi k b (dS sub body)
      Many inner  -> Many (dS sub inner)
      _           -> term

-- Non-dependent arrow: a colourless product whose binder has no variable.
-- A 'Many' domain is turned into a 'Mult' binder over the wrapped term.
(-->) :: Term -> Term -> Term
dom --> cod = case dom of
    Many t -> Pi "" (Mult (Unbound :- t)) cod
    t      -> Pi "" (Unbound :- t) cod

-- A judgement: a term ('value'), the term given as its type ('typ') and the
-- environment ('env') in which it is stated.
data Jug = Jug {value :: Term, typ :: Term, env :: Env}
 deriving Show

-- Skeleton of a derivation: which rule is applied at each node, together
-- with the sub-derivations for its premises.  'stp' matches a 'Drv' node
-- against the judgement it has to conclude.
data Drv :: * where
  -- Axiom; no premises.
  Ax :: Drv
  -- Start rule ("st"): sort of the bound type, derivation of that type.
  St :: Sort -> Drv -> Drv
  -- Weakening ("wk"): index of the environment entry to drop, sort of its
  -- type, derivation of the remaining judgement, derivation of the type.
  Wk :: Int -> Sort -> Drv -> Drv -> Drv
  -- Abstraction ("abs"): sort of the product, derivation of the body,
  -- derivation of the product type.
  Ab :: Sort -> Drv -> Drv -> Drv
  -- Application ("app"): binder of the function's product type, derivation
  -- of the function, derivation of the argument.
  Ap :: Binding -> Drv -> Drv -> Drv
  -- Product: colour, sorts of domain and codomain, and their derivations.
  Pr :: Colour -> (Sort,Sort) -> Drv -> Drv -> Drv
  -- Conversion ("conv"): sort, intermediate type, derivation at that type,
  -- derivation that the target type has the sort.
  Co :: Sort -> Term -> Drv -> Drv -> Drv
  An :: (Derivation -> Derivation) -> Drv -> Drv -- Arbitrary annotation / change
  -- Attach a specific link between this derivation and its parent.
  Ln :: Link -> Drv -> Drv
  -- Ma :: Drv -> Drv
  -- Re :: TeX -> Jug -> Drv -> Drv
 deriving Show

-- Weakening of the environment entry at index 0.
wk :: Sort -> Drv -> Drv -> Drv
wk s rest d = Wk 0 s rest d

-- Environment of a judgement.  Rules push new bindings at the head, and
-- 'interpE' reverses the list before rendering.
type Env = [Binding]

-- Tell whether a binding introduces no variable, looking through any
-- number of 'Mult' wrappers.
unbound :: Binding -> Bool
unbound (Mult x) = unbound x
unbound (Unbound :- _) = True
unbound (_ :- _) = False
-- A 'Base' binding always carries a name, so it is never unbound.  Handling
-- it keeps the function total for every 'Binding' constructor.
unbound (Base _) = False

-- Negation of 'unbound': the binding does introduce a variable.
bound :: Binding -> Bool
bound b = not (unbound b)


-- 'Unbound' unifies with anything; two named variables must be equal.
instance Unifyable V where
    a === b = case (a, b) of
        (Unbound, _) -> True
        (_, Unbound) -> True
        _            -> a == b

-- Terms unify only when they are structurally equal.
instance Unifyable Term where
    s === t = s == t

-- Bindings unify when they have the same shape and their components unify.
instance Unifyable Binding where
    (x :- t) === (x' :- t') = (x === x') && (t === t')
    Mult x === Mult x' = x === x'
    -- Two 'Base' bindings unify when they carry the same name; this makes a
    -- 'Base' binding unify with itself.
    Base n === Base n' = n == n'
    -- Bindings of different shapes never unify.
    _ === _ = False

-- Render a name: a leading '~' is stripped and the rest is passed through
-- 'M.tso'; any other string is emitted verbatim.
instance TeXable Sort where
    texify s = case s of
        '~':rest -> M.tso (TeX rest)
        _        -> TeX s


-- Render the variable part of a binding; 'Unbound' becomes "?".
interpV :: V -> TeX
interpV v = case v of
    V name  -> texify name
    Unbound -> TeX "?"

-- Render a binding, using the "IS" macro between variable and term.
interpB :: Binding -> TeX
interpB binding = interpB' separator binding
  where separator = tex "IS" [mempty]

-- Render a binding with a caller-chosen separator between the variable and
-- its term.  'Base' bindings show only their name; 'Mult' wraps the
-- rendering of the inner binding in 'M.many'.
interpB' :: TeX -> Binding -> TeX
interpB' ass binding = case binding of
    x :- t -> interpV x <> ass <> interpT ColonRhs t
    Base x -> texify x
    Mult x -> M.many (interpB' ass x)


-- Invoke a TeX macro, switching to its coloured variant (macro name with a
-- "C" suffix, colour as extra first argument) when a colour is given.
col :: Colour -> String -> [TeX] -> TeX
col c macro args
  | null c    = tex macro args
  | otherwise = tex (macro ++ "C") (TeX c : args)

-- Syntactic position in which a term is rendered; 'havePrn' uses it to
-- decide whether the term needs parentheses.
data Ctx = ArrowLhs | TopLvl | ColonRhs | BinderRhs | ApplLhs | ApplRhs | SubsLhs

-- Decide whether a term must be parenthesised in the given context.
havePrn :: Ctx -> Term -> Bool
havePrn ctx term = case ctx of
    -- Never parenthesise at top level or in the body of a binder.
    TopLvl    -> False
    BinderRhs -> False
    -- Substitutions chain without parentheses on their left.
    SubsLhs   -> not (isSub term)
    -- Applications and substitutions may head an application unwrapped.
    ApplLhs   -> not (isSub term || isApp term)
    -- After a colon only abstractions and named products need wrapping.
    ColonRhs  -> case term of
        Lam _ _ _ -> True
        Pi _ b _  -> bound b
        _         -> False
    -- Every remaining context parenthesises everything.
    _         -> True
  where
    isSub (Sub _ _) = True
    isSub _         = False
    isApp (App _ _ _) = True
    isApp _           = False

-- Wrapper to apply to the rendering of a term: 'paren' when the context
-- requires parentheses, the identity otherwise.
parn context t
  | havePrn context t = paren
  | otherwise         = id

-- Render a term for the given context, adding parentheses as decided by
-- 'parn'.
interpT :: Ctx -> Term -> TeX
interpT ctx term = case term of
    -- Applying a variable whose name starts with a backslash calls the TeX
    -- macro of that name on the argument, without parentheses.
    App _ (Var ('\\':macroName)) arg -> tex macroName [interpT TopLvl arg]
    Lam c a b -> wrap $ col c "LAMBDA" [interpB a, interpT BinderRhs b]
    Pi c a b
      -- A product without a bound variable is shown as an arrow.
      | unbound a -> wrap $ col c "ARROW"  [interpT ArrowLhs (bToT a), interpT BinderRhs b]
      | otherwise -> wrap $ col c "FORALL" [interpB a, interpT BinderRhs b]
    App c f a -> wrap $ col c "APP" [interpT ApplLhs f, interpT ApplRhs a]
    Sub t b   -> wrap $ interpT SubsLhs t <> brack (interpB' (tex "mapsto" [mempty]) b)
    Var x     -> texify x
    Con x     -> texify x
    Many t    -> M.many (interpT ctx t)
    Sor s     -> texify s
  where
    wrap = parn ctx term

-- Terms are rendered as if they stood to the right of a colon.
instance TeXable Term where
    texify = interpT ColonRhs

-- Extend an environment with a binding, skipping bindings that introduce no
-- variable.
(+:) :: Binding -> Env -> Env
b +: e = if unbound b then e else b : e

-- Render every binding of an environment, in reverse list order.
interpE :: Env -> [TeX]
interpE = map interpB . reverse

-- Render one side of a 'Many' assertion: sorts are shown as they are, any
-- other term gets the index "i" attached with '!'.
renam :: Term -> TeX
renam term = case term of
    Sor x -> texify x
    _     -> interpT ApplRhs term ! TeX "i"

-- Interpret the non-context part of an assertion: render the term and its
-- type.  When the term is wrapped in 'Many' and the type is either 'Many'
-- or a sort, the wrapped parts are rendered through 'renam'.
interpA :: Term -> Term -> (TeX,TeX)
interpA val ty = case (val, ty) of
    (Many x, Many t) -> (renam x, renam t)
    (Many x, Sor s)  -> (renam x, texify s)
    _                -> (texify val, texify ty)

-- Render a judgement: environment, term and type are handed to 'M.assert'.
interpJ :: Jug -> TeX
interpJ (Jug val ty ctx) = uncurry (M.assert (interpE ctx)) (interpA val ty)

-- Assertion helper: the identity when the condition holds, otherwise a
-- function that fails with the given message.
check :: Bool -> String -> (a -> a)
check ok msg = if ok then id else error msg

-- Extract the term of a binding as a single term; a 'Mult' binding yields
-- a 'Many' term.  Fails on bindings that carry no term.
bToT :: Binding -> Term
bToT binding = case binding of
    Mult b -> Many (bToT b)
    _ :- t -> t
    _      -> error "invalid binding for extraction"

-- Rebuild a binding from a variable-like term and its type.  'Many' on the
-- term side becomes 'Mult' on the binding; anything that is not (a 'Many'
-- of) a variable is rejected with an error.
rebind :: Term -> Term -> Binding
rebind lhs ty = case (lhs, ty) of
    (Many x, Many t) -> Mult (rebind x t)
    (Many x, Sor s)  -> Mult (rebind x (Sor s))
    (Var x, _)       -> V x :- ty
    _                -> error $ "Expected variable, but got: " ++ show lhs

-- Distribute 'Many' over an application, recursively on both sides; any
-- other term is returned unchanged.
push term = case term of
    Many (App c f a) -> App c (push (Many f)) (push (Many a))
    _                -> term

-- Pair the variable name of a binding with a term, looking through matching
-- 'Mult' / 'Many' wrappers.  Without a named variable, a placeholder pair
-- is returned.
subsOf :: (Binding, Term) -> (Name, Term)
subsOf pair = case pair of
    (V x :- _, t)    -> (x, t)
    (Mult b, Many t) -> subsOf (b, t)
    _                -> ("no variable", Con "no term")


-- Judgement stating that a term has the given sort.  For a 'Many' term the
-- sort is wrapped in 'Many' as well.
jug' :: Term -> Sort -> Env -> Jug
jug' term s env = case term of
    Many _ -> Jug term (Many (Sor s)) env
    _      -> Jug term (Sor s) env

-- Rule whose label is the given plain string.
rtext name = rule (text name)

-- Remove the element at the given index; the list is returned unchanged
-- when the index does not designate an element.
rm n list = case list of
    []     -> []
    (y:ys) | n == 0    -> ys
           | otherwise -> y : rm (n - 1) ys

-- Perform one derivation step.  Given the rule applied at a node and the
-- judgement that node concludes, return
--   * a function building the 'Rule' from the rendered conclusion, and
--   * the premises, each paired with the sub-derivation that proves it.
-- Fails with 'error' when the rule does not fit the shape of the judgement
-- or when one of the consistency checks does not hold.
stp :: Drv -> Jug -> (TeX -> Rule (), [(Drv,Jug)])
-- Axiom: unlabelled rule without premises.
stp Ax _ = (\c -> (rtext "" c) {-DerivationTrees.style = None-}, [])
-- Start: the judged variable must be the head of the environment; the
-- premise states that its type has sort s in the remaining environment.
stp (St s d) (Jug v t (b':env)) = check (b === b') ("Start: bindings do not match: " ++ show (b, b'))
   $  (rtext "st", [(d, Jug t (Sor s) env)])
  where b = rebind v t
-- Weakening: drop the environment entry at position idx; the premises are
-- the same judgement in the smaller environment and the sorting of the
-- dropped entry's type.
stp (Wk idx s rest d)  (Jug v t env) = (rtext label, [(rest,Jug v t env'), (d, jug' (bToT b) s env')])
   where env' = rm idx env
         -- Bounds-checked lookup, so that a bad index is reported as such.
         b = case drop idx env of
               (x:_) | idx >= 0 -> x
               _ -> error $ "Weakening: index " ++ show idx
                         ++ " out of range for an environment of "
                         ++ show (length env) ++ " bindings"
         label | idx == 0 = "wk"
               | otherwise = "wk (" ++ show idx ++ ")" 

-- Conversion: the term is judged at the intermediate type, and the target
-- type must have sort s.
stp (Co s typ d e) (Jug v t env) = 
   (rtext "conv", [(d,Jug v typ env), (e, jug' t s env)])
-- Abstraction: colours and binders of the lambda and of its product type
-- must agree; the body is judged with the binder pushed on the environment.
stp (Ab s d e) (Jug (Lam k a bo) (Pi k' a' b) env) = check (k == k') ("Abstraction: colours do not match" ++ show (k,k')) $
                                                     check (a === a') ("Abstraction: bindings do not match: " ++ show (a, a')) $
  (rtext "abs", [(d,Jug bo b (a:env)), (e,Jug (Pi k a' b) (Sor s) env)])
-- Product: the rule label is a TeX macro applied to the sorts involved
-- ("ptsrule" when the codomain sort equals the sort of the product,
-- "prodrule" otherwise), with a "c" suffix and the colour as extra first
-- argument when the product is coloured.
stp (Pr k (s1,s2) d e) (Jug (Pi k' a b) (Sor s3) env) = check (k == k') ("Product: colours do not match" ++ show (k,k')) $
 (rule name, [(d,jug' (bToT a) s1 env),(e,jug' b s2 (a+:env))])
 where (cname,cargs) = if s2 == s3 then ("ptsrule",[s1,s2]) else ("prodrule",[s1,s2,s3])
       (fname,fargs) = if null k then (id,id) else ((++ "c"), (k:))
       name = tex (fname cname) (map texify $ fargs cargs)
       
-- stp (Ma d) (Jug v t env) = (rule (text "many"),  [(d,Jug (push v) (push t) env)])

-- if a substitution is there, eat it.
-- Application: the function is judged at the product over binder b, whose
-- body is the result type with the argument substitution undone by 'dS'.
stp (Ap b d e) (Jug (App k f a) t env) =
    (rtext "app", [(d,Jug f (Pi k b (dS subst t)) env),(e,Jug a (bToT b) env)])

    where subst = subsOf (b, a)
stp (Ap (Mult b) d e) (Jug (Many (App k f a)) (Many t) env) = 
    (rtext "app", [(d,Jug (Many f) (Many (Pi k b t)) env),(e,Jug (Many a) (Many (bToT b)) env)])
    -- Used only once in the product rule (many (f x))



-- stp (An f d) j = (rtext "usr", [(d,j)])
-- stp (Re t j d) j' = (rule t, [(d,j)])


-- Any other combination of rule and judgement is rejected.
stp r j = error $ "Unsupported: rule " ++ show r ++ " on judgement " ++ show j

-- Build the full derivation tree for a judgement by following the rule
-- skeleton.  'An' post-processes the tree of its child, 'Ln' is transparent
-- here (its link is picked up by the parent node), and every other rule is
-- expanded with 'stp'.
interp :: Drv -> Jug -> Derivation
interp drv j = case drv of
    An f inner -> f (interp inner j)
    Ln _ inner -> interp inner j
    _          -> Node (mkRule (interpJ j)) [premise d' j' | (d', j') <- subgoals]
  where
    (mkRule, subgoals) = stp drv j
    -- A premise is attached with the link requested by its own 'Ln'
    -- wrapper, or with the default link.
    premise d' j' = linkOf d' ::> interp d' j'
    linkOf (Ln l _) = l
    linkOf _        = defaultLink

-------------------

-- Attach a derivation through a 'Delayed' link, aligned left, centre or
-- right respectively.
dL, dC, dR :: Drv -> Drv
dL = Ln Delayed {align = LeftA}
dC = Ln Delayed {align = CenterA}
dR = Ln Delayed {align = RightA}

-- Attach a derivation through a dotted 'Link'.  The remaining fields of the
-- link are filled by the record wildcard from the arguments and the local
-- empty 'label'.
delay :: Alignment -> Int -> Drv -> Drv
delay align steps =
    let label = mempty
    in Ln Link {DerivationTrees.Basics.style = Dotted, ..}

-- Attach a derivation through a 'Detached' link whose label is the given
-- string in parentheses; the record wildcard fills further fields from the
-- arguments in scope.
detach l ident = Ln Detached {label = text ("(" ++ l ++ ")"), ..}

-- Leaf derivation: an axiom node transformed by 'haltDrv' with the given
-- text.
halt :: String -> Drv
halt txt = An stopWith Ax
  where stopWith = haltDrv (text txt)

-- Annotate a derivation by setting the 'delimiter' of its root rule to the
-- given text.
named :: String -> Drv -> Drv
-- named txt = An (\(Node r ps) -> Node r [defaultLink ::> Node (rule mempty (text txt)) {DerivationTrees.style = Double} ps])
named txt = An relabel
  where relabel (Node r ps) = Node (r {delimiter = text txt}) ps

-- Leaf derivation: an axiom node transformed by 'abortDrv'.
abort :: Drv
abort = cutOff Ax -- (error "abort: Should stop")
  where cutOff = An abortDrv