{-# LANGUAGE StandaloneDeriving #-}

module GLL.Types.Grammar where

import qualified    Data.Map as M
import qualified    Data.IntMap as IM
import qualified    Data.Set as S 
import qualified    Data.IntSet as IS 
import              Data.List (delete, (\\), elemIndices, findIndices)
import GLL.Types.Abstract

-- | Number of input characters a token spans.
--
-- 'Char' and 'EOS' count as one character and 'Epsilon' as zero. Tokens that
-- carry an optional payload ('Int', 'Bool', 'String', 'Token') report the
-- length of the payload's textual form and call 'error' when the payload is
-- 'Nothing'.
token_length :: Token -> Int
token_length (Char _)      = 1
token_length EOS           = 1
token_length Epsilon       = 0
-- This case used to raise an error unconditionally. The width of a known
-- integer is taken to be the length of its decimal rendering (a leading '-'
-- is counted).
token_length (Int mi)      = maybe (error "no length for int tokens")
                                   (length . show) mi
token_length (Bool mb)     = maybe (error "no length for bool tokens")
                                   (\b -> if b then 4 else 5) mb -- supposing "True" and "False"
token_length (String ms)   = maybe (error "no length for string tokens") length ms
token_length (Token _ str) = maybe (error "no length of tokens") length str

-- NOTE: tokens are meant to compare equal independent of their
-- character-level value; see the hand-written 'Eq' and 'Ord' instances for
-- 'Token' in this module, which ignore the payload of 'Int', 'Bool' and
-- 'String' tokens and the second field of 'Token'.

-- Pairs of a grammar object and an 'Int' (the left extent, per the comments
-- on each alias).
type SlotL      = (Slot, Int)                   -- slot with left extent
type PrL        = (Alt, Int)                     -- Production rule with left extent
type NtL        = (Nt, Int)                     -- Nonterminal with left extent

-- SPPF

-- | The parse forest as a 5-tuple of lookup structures, in order: symbol
-- nodes, intermediate nodes, packed nodes, edges and node identifiers.
type SPPF       =   (SymbMap, ImdMap, PackMap, EdgeMap, IDMap)
-- | Packed nodes. 'pMapInsert' and 'pNodeLookup' index this by @l@, then @r@,
-- then the slot position @j@ (number of symbols before the dot), then the
-- alternative; the stored set collects the @k@ components.
-- (presumably l/r are the left/right extents and k the pivot -- TODO confirm)
type PackMap    =   IM.IntMap (IM.IntMap (IM.IntMap (M.Map Alt IS.IntSet)))
-- | Symbol nodes: indexed by @l@, then @r@, giving the set of symbols
-- recorded for that pair (see 'sNodeInsert' / 'sNodeLookup').
type SymbMap    =   IM.IntMap (IM.IntMap (S.Set Symbol))
-- | Intermediate nodes: indexed by @l@, then @r@, giving the set of slots
-- recorded for that pair (see 'iNodeInsert' / 'iNodeLookup').
type ImdMap     =   IM.IntMap (IM.IntMap (S.Set Slot))
-- | Edges: maps a node to the set of nodes it points to (see 'eMapInsert').
type EdgeMap    =   M.Map SPPFNode (S.Set SPPFNode)
-- | Node numbering in both directions (see 'idMapInsert').
type IDMap      =   (IDFMap,IDTMap)
-- | Identifier to node.
type IDFMap     =   IM.IntMap SPPFNode
-- | Node to identifier.
type IDTMap     =   M.Map SPPFNode Int
-- | A node of the forest. 'pMapInsert' reads the 'PNode' payload as
-- @(slot, l, k, r)@; 'Dummy' carries no data.
data SPPFNode   =   SNode (Symbol, Int, Int) 
                |   INode (Slot, Int, Int)
                |   PNode (Slot, Int, Int, Int)
                |   Dummy
    deriving (Ord, Eq)
-- The type synonyms below share their names with the 'SNode' / 'PNode' data
-- constructors above; they live in the type namespace and are used by
-- 'showSPPF'.
type SNode      = (Symbol, Int, Int)
type PNode      = (Alt, [Int])
type SEdge      = M.Map SNode (S.Set PNode)
type PEdge      = M.Map PNode (S.Set SNode)

-- | The forest with no symbol nodes, intermediate nodes, packed nodes,
-- edges or node identifiers.
emptySPPF :: SPPF
emptySPPF = (symbols, intermediates, packed, edges, identifiers)
  where
    symbols       = IM.empty
    intermediates = IM.empty
    packed        = IM.empty
    edges         = M.empty
    identifiers   = (IM.empty, M.empty)

-- | Look up the @k@ values stored for alternative @alt@ at slot position @j@
-- under the keys @l@ and @r@. Returns 'Nothing' as soon as any level of the
-- packed-node map has no entry, otherwise the stored set as an ascending
-- list.
pNodeLookup :: SPPF -> ((Alt, Int), Int, Int) -> Maybe [Int]
pNodeLookup (_, _, packed, _, _) ((alt, j), l, r) =
    fmap IS.toList stored
  where
    -- walk the four nesting levels; the Maybe monad stops at the first miss
    stored = IM.lookup l packed
                >>= IM.lookup r
                >>= IM.lookup j
                >>= M.lookup alt

-- | Record a packed node in the packed-node map.
--
-- Only the first node is inspected: when it is @PNode (Slot x alpha beta, l,
-- k, r)@ the value @k@ is added to the set stored under @l@, @r@, the slot
-- position @length alpha@ and the alternative @Alt x (alpha ++ beta)@. Any
-- other first node leaves the forest unchanged. The second node is ignored.
pMapInsert :: SPPFNode -> SPPFNode -> SPPF -> SPPF
pMapInsert f _ (sMap, iMap, pMap, eMap, idMap) = (sMap, iMap, packed, eMap, idMap)
  where
    packed = case f of
        PNode (Slot x alpha beta, l, k, r) ->
            record (Alt x (alpha ++ beta)) (length alpha) l r k
        _ -> pMap

    -- Merge a one-entry nested map into the existing one. The three IntMap
    -- levels are merged key-wise; at the innermost level the alternative is
    -- inserted into the map that is already there.
    record alt j l r k =
        IM.insertWith (IM.unionWith (IM.unionWith mergeAlts)) l fresh pMap
      where
        ks                = IS.singleton k
        fresh             = IM.singleton r (IM.singleton j (M.singleton alt ks))
        mergeAlts _ known = M.insertWith IS.union alt ks known


-- | Is symbol @s@ recorded in the symbol-node map under the keys @l@ and @r@?
sNodeLookup :: SPPF -> (Symbol, Int, Int) -> Bool
sNodeLookup (symbols, _, _, _, _) (s, l, r) =
    case IM.lookup l symbols >>= IM.lookup r of
        Just found -> s `S.member` found
        Nothing    -> False

-- | Register both nodes in the symbol-node map. Each argument that is an
-- 'SNode' is added (the first node before the second); arguments of any
-- other shape are skipped.
sNodeInsert :: SPPFNode -> SPPFNode -> SPPF -> SPPF
sNodeInsert f t (sMap, iMap, pMap, eMap, idMap) =
    (register t (register f sMap), iMap, pMap, eMap, idMap)
  where
    -- add the symbol of an 'SNode' to the set stored under (l, r)
    register (SNode (s, l, r)) =
        IM.insertWith (IM.unionWith S.union) l (IM.singleton r (S.singleton s))
    register _ = id
 
-- | Delete symbol @s@ from the set stored under @l@ and @r@ in the
-- symbol-node map. Missing keys are left alone, and the (possibly empty) set
-- stays in the map.
sNodeRemove :: SPPF -> (Symbol, Int, Int) -> SPPF
sNodeRemove (sm, iMap, pMap, eMap, idMap) (s, l, r) =
    (pruned, iMap, pMap, eMap, idMap)
  where
    pruned  = IM.adjust dropAtR l sm
    dropAtR = IM.adjust (S.delete s) r

-- | Is slot @s@ recorded in the intermediate-node map under the keys @l@
-- and @r@?
iNodeLookup :: SPPF -> (Slot, Int, Int) -> Bool
iNodeLookup (_, intermediates, _, _, _) (s, l, r) =
    maybe False (S.member s) (IM.lookup l intermediates >>= IM.lookup r)

-- | Register both nodes in the intermediate-node map. Each argument that is
-- an 'INode' is added (the first node before the second); arguments of any
-- other shape are skipped.
iNodeInsert :: SPPFNode -> SPPFNode -> SPPF -> SPPF
iNodeInsert f t (sMap, iMap, pMap, eMap, idMap) =
    (foldr register iMap [t, f], iMap', pMap, eMap, idMap) `seqSnd` ()
  where
    -- the fold applies 'register' to @f@ first and to @t@ second
    iMap' = foldr register iMap [t, f]
    -- keep only the updated intermediate map in the result tuple
    seqSnd (a, b, c, d, e) _ = (a `asTypeOf` sMap, b, c, d, e) `replaceFirst` sMap
    replaceFirst (_, b, c, d, e) a = (a, b, c, d, e)
    -- add the slot of an 'INode' to the set stored under (l, r)
    register (INode (s, l, r)) acc =
        IM.insertWith (IM.unionWith S.union) l (IM.singleton r (S.singleton s)) acc
    register _ acc = acc
 
-- | Delete slot @s@ from the set stored under @l@ and @r@ in the
-- intermediate-node map. Missing keys are left alone, and the (possibly
-- empty) set stays in the map.
iNodeRemove :: SPPF -> (Slot, Int, Int) -> SPPF
iNodeRemove (sMap, iMap, pMap, eMap, idMap) (s, l, r) =
    (sMap, IM.adjust (IM.adjust (S.delete s) r) l iMap, pMap, eMap, idMap)

-- | Add an edge from node @f@ to node @t@ to the edge map; all other
-- components of the forest are passed through untouched.
eMapInsert :: SPPFNode -> SPPFNode -> SPPF -> SPPF
eMapInsert f t (sMap, iMap, pMap, eMap, idMap) =
    (sMap, iMap, pMap, edges, idMap)
  where
    target = S.singleton t
    -- the new target is the left operand of the union, as before
    edges  = M.insertWith S.union f target eMap

-- | Make sure both nodes have an integer identifier and return the updated
-- forest together with the identifiers of @f@ and @t@ (in that order).
--
-- A node that is already numbered keeps its identifier. A new @f@ receives
-- @m + 1@ and a new @t@ receives @m + 2@, where @m@ is the largest
-- identifier in use before the call (0 when there is none).
idMapInsert :: SPPFNode -> SPPFNode -> SPPF -> (SPPF, Int, Int)
idMapInsert f t (sMap, iMap, pMap, eMap, (idfMap, idtMap)) =
    ((sMap, iMap, pMap, eMap, (byId2, byNode2)), fkey, tkey)
  where
    base = if IM.null idfMap then 0 else fst (IM.findMax idfMap)

    (fkey, byId1, byNode1) = assign f (base + 1) idfMap idtMap
    (tkey, byId2, byNode2) = assign t (base + 2) byId1 byNode1

    -- reuse the known identifier of a node, or bind it to the candidate
    assign :: SPPFNode -> Int -> IDFMap -> IDTMap -> (Int, IDFMap, IDTMap)
    assign node candidate byId byNode =
        maybe (candidate, IM.insert candidate node byId, M.insert node candidate byNode)
              (\known -> (known, byId, byNode))
              (M.lookup node byNode)
-- helpers for Ucal
-- | Does the two-level map @u@ contain @slot@ in the set stored under the
-- keys @l@ and @i@?
inU :: Ord a => (a, Int, Int) -> IM.IntMap (IM.IntMap (S.Set a)) -> Bool
inU (slot, l, i) u =
    case IM.lookup l u >>= IM.lookup i of
        Nothing    -> False
        Just slots -> S.member slot slots

-- | Add @slot@ to the set stored under the keys @l@ and @i@ of the two-level
-- map @u@, creating the missing levels on the way.
toU :: Ord a
    => (a, Int, Int)
    -> IM.IntMap (IM.IntMap (S.Set a))
    -> IM.IntMap (IM.IntMap (S.Set a))
toU (slot, l, i) u = IM.insertWith (IM.unionWith S.union) l fresh u
  where
    -- the one-entry inner map that is merged into whatever @l@ already holds
    fresh = IM.singleton i (S.singleton slot)


-- | Render a map from a key to a list of targets as one @key --> target@
-- line per target, in ascending key order.
showD :: (Show k, Show v) => M.Map k [v] -> String
showD dv = unlines (concatMap edgesOf (M.toList dv))
  where
    edgesOf (f, ts) = map (\t -> show f ++ " --> " ++ show t) ts
-- | Render a map from a key to a list of targets as one @key --> target@
-- line per target, in ascending key order.
showG :: (Show k, Show v) => M.Map k [v] -> String
showG dv = unlines $ do
    (f, ts) <- M.toList dv
    t       <- ts
    return (show f ++ " --> " ++ show t)
-- | Render a four-level packed-node style map, one line per innermost
-- entry, as @((a,j),l,r) --> value@, where @l@, @r@, @j@ and @a@ are the
-- keys from the outermost to the innermost level.
showP :: (Show a, Show k)
      => IM.IntMap (IM.IntMap (IM.IntMap (M.Map a k)))
      -> String
showP pMap = unlines $ do
    (l, r2j)  <- IM.assocs pMap
    (r, j2a)  <- IM.assocs r2j
    (j, a2k)  <- IM.assocs j2a
    (a, kset) <- M.assocs a2k
    return (show ((a, j), l, r) ++ " --> " ++ show kset)
-- | Render a two-level map, one line per inner entry, as
-- @(l,r) --> value@ with @l@ the outer and @r@ the inner key.
showS :: Show s => IM.IntMap (IM.IntMap s) -> String
showS sMap = unlines (concatMap row (IM.assocs sMap))
  where
    row (l, r2s) =
        [ show (l, r) ++ " --> " ++ show sset | (r, sset) <- IM.assocs r2s ]
-- TODO change to Map
-- | Pretty-print a forest given as two edge lists.
--
-- The output is a blank line, one line per packed-to-symbol edge
-- (@pe@), another blank line, and one line per symbol-to-packed edge (@se@).
-- A packed node is shown as @(x ::= symbols,i1,i2,...)@, where the trailing
-- numbers are the integers stored with the node.
showSPPF :: ([(SNode,PNode)],[(PNode,SNode)]) -> String
showSPPF (se,pe) = "\n" ++ unlines (map ppPn pe) ++ "\n" ++
                          unlines (map ppSn se)
     where ppPn ((Alt x alpha, rs), sn) = ppRhs (x,alpha,rs) ++ " --> " ++ show sn
           ppSn (sn, (Alt x alpha, rs)) = show sn ++ " --> " ++ ppRhs (x,alpha,rs)
           ppRhs (x, alpha, rs) =  "(" ++ x ++ " ::= " ++ concatMap ppS alpha ++
                                   concatMap (\i -> "," ++ show i) rs ++ ")"
           -- Rendering of a single right-hand-side symbol. Tokens with an
           -- optional payload fall back to the name of their kind.
           ppS (Nt s)             = s
           ppS (Term Epsilon)     = "''"
           -- EOS had no case and crashed with a pattern-match failure;
           -- "$" is the rendering used by the Show instance of Token.
           ppS (Term EOS)         = "$"
           ppS (Term (Char c))    = [c]
           ppS (Term (Token t _)) = t
           ppS (Term (Int i))     = maybe "Int" show i
           ppS (Term (Bool b))    = maybe "Bool" show b
           ppS (Term (String s))  = maybe "String" id s
           -- Error symbols also had no case; print a placeholder so that a
           -- debugging dump does not abort halfway through.
           ppS (Error _ _)        = "<error>"


-- smart constructors
-- | Wrap a token as a terminal symbol.
tokenT :: Token -> Symbol
tokenT = Term
-- | Terminal symbol for a single character.
charT   = Term . Char
-- | Nonterminal symbol with the given name.
nT      = Nt
-- | Turn every element of a list into a 'Char' token.
charS cs = [ Char c | c <- cs ]
-- | Right-hand side consisting of the single terminal 'Epsilon'.
epsilon = [tokenT Epsilon]

-- The lookup tables produced by 'fixedMaps'.

-- | All alternatives of a nonterminal.
type ProdMap   = M.Map Nt [Alt]
-- | For an alternative and a position in its right-hand side: the tokens
-- that start at that position and, when that run of tokens is ended by a
-- nonterminal, that nonterminal.
type PrefixMap = M.Map (Alt,Int) ([Token], Maybe Nt)
-- | Select set for a nonterminal together with a suffix of one of its
-- right-hand sides.
type SelectMap = M.Map (Nt, [Symbol]) (S.Set Token)
-- | First set of a nonterminal.
type FirstMap  = M.Map Nt (S.Set Token)
-- | Follow set of a nonterminal.
type FollowMap = M.Map Nt (S.Set Token)

-- | Precompute the static lookup tables of a grammar.
--
-- Takes the start nonterminal @s@ and the list of all alternatives @prs@ and
-- returns, in order: the alternatives per nonterminal, the token prefixes per
-- (alternative, position), the first sets, the follow sets and the select
-- sets per (nonterminal, right-hand-side suffix).
--
-- The alternatives of a nonterminal are fetched with 'M.!' in @first_x@ and
-- @nullable_x@, so a nonterminal that occurs in a right-hand side but has no
-- alternative in @prs@ raises an error once the affected entry is forced.
fixedMaps :: Nt -> [Alt] -> (ProdMap, PrefixMap, FirstMap, FollowMap, SelectMap)
fixedMaps s prs = let f = (prodMap, prefixMap, firstMap, followMap, selectMap)
                    in f `seq` f
 where
    -- alternatives grouped by their left-hand side
    prodMap = M.fromListWith (++) [ (x,[pr]) | pr@(Alt x _) <- prs ]

    -- For every alternative and every position j that is 0 or directly after
    -- a nonterminal: the tokens starting at j and, when the run is ended by a
    -- nonterminal, that nonterminal.
    prefixMap :: PrefixMap
    prefixMap = M.fromList
        [ ((pr,j), (tokens,msymb)) | pr@(Alt x alpha) <- prs
                                   , (j,tokens,msymb) <- prefix x alpha ]
     where
        prefix x alpha = map rangePrefix ranges
         where  -- positions directly after each nonterminal
                js          = (map ((+) 1) (findIndices isNt alpha))
                ranges      = zip (0:js) (js ++ [length alpha])
                -- (a,z) is a half-open range of positions; it is empty when
                -- the right-hand side is empty or ends in a nonterminal
                rangePrefix (a,z) | a >= z    = (a,[],Nothing)
                                  | otherwise =
                    -- all symbols before position z-1 are terminals, because
                    -- the range ends at the next nonterminal (or at the end)
                    let toks  = map ((\(Term t) -> t) . (alpha !!)) [a .. (z-2)]
                        final = alpha !! (z-1)
                     in case final of
                           Nt nt     -> (a,toks, Just nt)
                           Term t    -> (a,toks ++ [t], Nothing)

    firstMap = M.fromList [ (x, first_x [] x) | x <- M.keys prodMap ]

    -- first set of a nonterminal; ys lists the nonterminals that must not be
    -- expanded again (this is what prevents endless self-calls)
    first_x :: [Nt] -> Nt -> (S.Set Token)
    first_x ys x           = S.unions [ first_alpha (x:ys) rhs | Alt _ rhs <- prodMap M.! x ]

    -- One entry per alternative and per suffix of its right-hand side that
    -- starts at position 0 or at a nonterminal.
    selectMap :: SelectMap
    selectMap = M.fromList [ ((x,alpha), select alpha x) | Alt x rhs <- prs
                           , alpha <- split rhs ]
     where
        split rhs = foldr op [] js
         where op j acc     = drop j rhs : acc
               js           = 0 : findIndices isNt rhs

        -- TODO store intermediate results
        -- first set of alpha; when alpha can be empty, Epsilon is replaced by
        -- the follow set of x
        select :: [Symbol] -> Nt -> (S.Set Token)
        select alpha x      = res
                where   firsts  = first_alpha [] alpha
                        res     | Epsilon `S.member` firsts     = S.delete Epsilon firsts `S.union` (followMap M.! x)
                                | otherwise                 = firsts

    -- list of symbols to get firsts from + non-terminal to ignore
    -- TODO store in map
    first_alpha :: [Nt] -> [Symbol] -> (S.Set Token)
    first_alpha ys []      = S.singleton Epsilon
    first_alpha ys (x:xs)  =
        case x of
          Term Epsilon    -> first_alpha ys xs
          Term tau        -> S.singleton tau
          Nt x            ->
            let fs | x `elem` ys       = S.empty
                   | otherwise        = first_x (x:ys) x
              in  if x `S.member` nullableSet
                        then (S.delete Epsilon fs) `S.union` first_alpha (x:ys) xs
                        else fs

    followMap :: M.Map Nt (S.Set Token)
    followMap = M.fromList [ (x, follow [] x) | x <- M.keys prodMap ]

    -- follow set of x; ys lists the nonterminals whose follow set is already
    -- being computed further up the call chain
    follow :: [Nt] -> Nt -> (S.Set Token)
    follow ys x = S.unions (map fw (maybe [] id $ M.lookup x localMap))
                            `S.union` (if x == s then S.singleton EOS else S.empty)
             where fw (y,ss) =
                        let ts  = S.delete Epsilon (first_alpha [] ss)
                            fs  = follow (x:ys) y
                         in if nullable_alpha [] ss && not (x `elem` (y:ys))
                               then ts `S.union` fs
                               else ts

    -- for every nonterminal x: the pairs (y, rest) such that some alternative
    -- of y contains x directly followed by the symbols rest
    localMap = M.fromListWith (++)
                [ (x,[(y,rest)]) | x <- M.keys prodMap, (Alt y rhs) <- prs
                                 , rest <- tails x rhs ]
     where
        tails x symbs = [ drop (index + 1) symbs | index <- indices ]
         where indices = elemIndices (Nt x) symbs

    nullableSet :: S.Set Nt
    nullableSet  = S.fromList $ [ x | x <- M.keys prodMap, nullable_x [] x ]

    -- a nonterminal is nullable if any of its alternatives can derive the
    -- empty string; ys lists the nonterminals currently being expanded
    nullable_x :: [Nt] -> Nt -> Bool
    nullable_x ys x      = or [ nullable_alpha (x:ys) rhs
                              | (Alt _ rhs) <- prodMap M.! x ]

    -- TODO store in map
    -- A sequence is nullable when every symbol in it is.
    --
    -- The guard list ys must only contain nonterminals that enclose the
    -- current position (ancestors). It used to be extended with every
    -- nullable sibling as well, so a nonterminal occurring twice in one
    -- right-hand side (A ::= B B, B nullable) was wrongly rejected. An
    -- Epsilon terminal used to end the check with True regardless of the
    -- symbols after it; it is now skipped, as in first_alpha.
    nullable_alpha :: [Nt] -> [Symbol] -> Bool
    nullable_alpha _  []       = True
    nullable_alpha ys (sym:ss) =
        case sym of
            Nt nt | nt `elem` ys -> False --nullable only if some other alternative is nullable
                  | otherwise    -> nullable_x ys nt && nullable_alpha ys ss
            Term Epsilon -> nullable_alpha ys ss
            _            -> False

-- some helpers
-- | True exactly for nonterminal symbols.
isNt :: Symbol -> Bool
isNt sym = case sym of
    Nt _ -> True
    _    -> False

-- | True exactly for terminal symbols.
isTerm :: Symbol -> Bool
isTerm sym = case sym of
    Term _ -> True
    _      -> False

-- | True exactly for 'Char' tokens.
isChar :: Token -> Bool
isChar tok = case tok of
    Char _ -> True
    _      -> False

-- Standalone-derived instances for types that are not declared in this
-- module (presumably they come from GLL.Types.Abstract -- TODO confirm).
-- The derived Eq/Ord instances of Symbol, Alt and Slot compare the tokens
-- they contain with the hand-written Token instances in this module, so they
-- inherit the payload-insensitive token comparison.
deriving instance Show Grammar 
deriving instance Ord Slot
deriving instance Eq Slot
deriving instance Show Rule
deriving instance Show Alt
deriving instance Ord Alt
deriving instance Eq Alt
deriving instance Eq Symbol
deriving instance Ord Symbol

{-
instance Show Symbol where
    show (Nt nt) = "Nt " ++ show nt
    show (Term t) = "Term " ++ show t
    show (Error t1 t2) = "Error " ++ show t1 ++ " " ++ show t2

instance Eq Symbol where
    (Nt nt) == (Nt nt') = nt == nt'
    (Term t) == (Term t') = t == t'
    (Error t1 t2) == (Error t1' t2') = (t1,t2) == (t1',t2')

instance Ord Symbol where
    (Nt nt) `compare` (Nt nt') = nt `compare` nt
    (Nt _)  `compare`  _       = LT
    _  `compare`  (Nt _)       = GT
    (Term t) `compare` (Term t') = t `compare` t'
    (Term _) `compare` _         = LT
    _        `compare` (Term _)   = GT
    (Error t1 t2) `compare` (Error t1' t2') = (t1,t2) `compare` (t1',t2')
-}

-- | Token equality by kind: two 'Token's are equal when their first fields
-- agree, two 'Char's when their characters agree, and any two tokens of the
-- same remaining kind ('EOS', 'Epsilon', 'String', 'Int', 'Bool') are equal
-- regardless of payload. Tokens of different kinds are never equal.
instance Eq Token where
    lhs == rhs = case (lhs, rhs) of
        (Token k _, Token k' _) -> k' == k
        (Char c,    Char c')    -> c' == c
        (EOS,       EOS)        -> True
        (Epsilon,   Epsilon)    -> True
        (String _,  String _)   -> True
        (Int _,     Int _)      -> True
        (Bool _,    Bool _)     -> True
        _                       -> False

-- | Token ordering, consistent with the 'Eq' instance. Kinds are ordered
-- @EOS < Epsilon < String < Int < Bool < Char < Token@; two 'Char's are
-- ordered by character, two 'Token's by their first field, and tokens of
-- the same remaining kind compare as equal.
instance Ord Token where
    compare (Char c)    (Char c2)    = compare c c2
    compare (Token k _) (Token k2 _) = compare k k2
    compare lhs         rhs          = compare (rank lhs) (rank rhs)
      where
        -- position of a token's kind in the ordering above
        rank :: Token -> Int
        rank EOS         = 0
        rank Epsilon     = 1
        rank (String _)  = 2
        rank (Int _)     = 3
        rank (Bool _)    = 4
        rank (Char _)    = 5
        rank (Token _ _) = 6

-- | Short textual form of a token: a quoted character, @$@ for 'EOS', @#@
-- for 'Epsilon', the kind name for 'Int' / 'Bool' / 'String' (the payload is
-- not shown) and the first field of a 'Token'.
instance Show Token where
    show tok = case tok of
        Char c    -> "'" ++ [c] ++ "'"
        EOS       -> "$"
        Epsilon   -> "#"
        Int _     -> "int"
        Bool _    -> "bool"
        String _  -> "string"
        Token t _ -> t

-- | Show a slot as @x ::= alpha.beta@, with the symbols of @alpha@ and
-- @beta@ concatenated without separators. Only 'Term' and 'Nt' symbols are
-- handled.
instance Show Slot where
    show (Slot x alpha beta) = concat [x, " ::= ", render alpha, ".", render beta]
      where
        render syms      = concatMap symbol syms
        symbol (Term t)  = show t
        symbol (Nt name) = name

-- | A nonterminal is shown as its name and a terminal as its token; showing
-- an 'Error' symbol raises an error.
instance Show Symbol where
    show sym = case sym of
        Nt s      -> s
        Term t    -> show t
        Error _ _ -> error ("show Error symbol")