{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE BangPatterns #-}

module NLP.Adict.Trie
( TrieD
, Trie (..)
, unTrie
, child
, anyChild
, mkTrie
, setValue
, substChild
, insert

, size
, follow
, lookup
, fromLang
, fromList
, toList

, serialize
, deserialize
, implicitDAWG
) where

import Prelude hiding (lookup)
import Control.Applicative ((<$>), (<*>))
import Control.Monad ((>=>))
import Data.List (foldl')
import Data.Binary (Binary, get, put)
import qualified Data.Map as M

import NLP.Adict.DAWG.Node

-- | A 'Trie' whose nodes carry 'Maybe' values: 'Nothing' marks a node
-- without an entry, 'Just' a node that stores one (see 'insert' and
-- 'lookup', which work on this type).
type TrieD a b = Trie a (Maybe b)

-- | A trie over words built from characters of type @a@, holding a value
-- of type @b@ in every node.  Conceptually it is a map of type @[a] -> b@.
data Trie a b = Trie
    { valueIn :: b
      -- ^ Value stored in this node.
    , edgeMap :: M.Map a (Trie a b)
      -- ^ Outgoing edges: each character leads to a subtrie.
    } deriving (Show, Eq, Ord)

-- | Apply a function to the value of every node; the shape of the trie
-- (its edges) is left untouched.
instance Functor (Trie a) where
    fmap f (Trie v es) = Trie (f v) (M.map (fmap f) es)

-- | Binary encoding of a trie: the node value followed by the edge map.
instance (Ord a, Binary a, Binary b) => Binary (Trie a b) where
    put (Trie v es) = put v >> put es
    -- Fields are read back in the same order in which 'put' writes them.
    get = Trie <$> get <*> get

-- | Decompose a trie into the value of its root and the list of outgoing
-- edges, as produced by 'M.toList' on the edge map.
unTrie :: Trie a b -> (b, [(a, Trie a b)])
unTrie t = (here, kids)
  where
    -- Field accessors (rather than a pattern match) keep the function
    -- lazy in its argument.
    here = valueIn t
    kids = M.toList (edgeMap t)
{-# INLINE unTrie #-}

-- | Subtrie reached from the given node by following the edge labelled
-- with the given character, or 'Nothing' when there is no such edge.
child :: Ord a => a -> Trie a b -> Maybe (Trie a b)
child x t = M.lookup x (edgeMap t)
{-# INLINE child #-}

-- | All outgoing edges of a node, paired with the subtries they lead to.
anyChild :: Trie a b -> [(a, Trie a b)]
anyChild = M.toList . edgeMap
{-# INLINE anyChild #-}

-- | Build a node from its value and a list of labelled subtries.  Both
-- arguments are forced to weak head normal form.  The edge map is built
-- with 'M.fromList'.
mkTrie :: Ord a => b -> [(a, Trie a b)] -> Trie a b
mkTrie !val !kids = Trie { valueIn = val, edgeMap = M.fromList kids }
{-# INLINE mkTrie #-}

-- | A trie with no entries: a single node holding 'Nothing' and having
-- no outgoing edges.
empty :: Ord a => TrieD a b
empty = Trie Nothing M.empty
{-# INLINE empty #-}

-- | Replace the value stored in the root node, keeping its edges.  The
-- new value is forced to weak head normal form.
setValue :: b -> Trie a b -> Trie a b
setValue !x (Trie _ es) = Trie x es
{-# INLINE setValue #-}

-- | Attach the third argument as the child of the given trie under the
-- given character, replacing any subtrie previously reachable through
-- that character.
substChild :: Ord a => a -> Trie a b -> Trie a b -> Trie a b
substChild !x !trie !newChild = trie { edgeMap = edges' }
  where
    -- The updated edge map is forced before the node is rebuilt.
    !edges' = M.alter (const (Just newChild)) x (edgeMap trie)
{-# INLINABLE substChild #-}
{-# SPECIALIZE substChild
    :: Char
    -> Trie Char b
    -> Trie Char b
    -> Trie Char b #-}

-- | Insert a word with its value into the trie.  A value already stored
-- for the same word is overwritten.
insert :: Ord a => [a] -> b -> TrieD a b -> TrieD a b
insert word v t = case word of
    -- End of the word: the value belongs in the current node.
    []     -> setValue (Just v) t
    -- Otherwise descend along the first character, starting from an
    -- 'empty' subtrie when the edge does not exist yet.
    x : xs ->
        let sub = maybe empty id (child x t)
        in  substChild x t (insert xs v sub)
{-# INLINABLE insert #-}
{-# SPECIALIZE insert
    :: String -> b
    -> TrieD Char b
    -> TrieD Char b #-}

-- | Number of nodes in the trie, the root included.  Note that this
-- counts nodes, not stored entries.
size :: Trie a b -> Int
size t = 1 + sum [ size sub | (_, sub) <- anyChild t ]

-- | Follow the path labelled with the given characters, starting at the
-- root.  Returns the subtrie at the end of the path, or 'Nothing' as soon
-- as a required edge is missing.
follow :: Ord a => [a] -> Trie a b -> Maybe (Trie a b)
follow []       t = return t
follow (x : xs) t = (child x >=> follow xs) t

-- | Value stored for the given word, if any.  The result is 'Nothing'
-- both when the path does not exist and when the node at its end holds
-- no value.
lookup :: Ord a => [a] -> TrieD a b -> Maybe b
lookup xs t = case follow xs t of
    Nothing   -> Nothing
    Just node -> valueIn node

-- | Build a trie from a list of (word, value) pairs by inserting them
-- one after another, left to right, into the 'empty' trie.
fromList :: Ord a => [([a], b)] -> TrieD a b
fromList = foldl' step empty
  where
    step acc (word, v) = insert word v acc

-- | List all (word, value) entries of the trie.  The entry of a node (if
-- present) comes before the entries of its subtries, and subtries are
-- visited in the order given by 'anyChild'.
toList :: TrieD a b -> [([a], b)]
toList t = here ++ below
  where
    -- Entry stored in this very node, reached by the empty word.
    here  = maybe [] (\y -> [([], y)]) (valueIn t)
    -- Entries of the subtries, each prefixed with its edge character.
    below = [ (c : cs, y) | (c, sub) <- anyChild t, (cs, y) <- toList sub ]

-- | Build a trie representing a set of words: every word is stored with
-- the unit value.
fromLang :: Ord a => [[a]] -> TrieD a ()
fromLang xs = fromList (zip xs (repeat ()))

-- | Eliminate common subtries by round-tripping the trie through
-- 'serialize' and 'deserialize'.  The result is algebraically a trie
-- but is represented as a DAWG in memory.
implicitDAWG :: (Ord a, Ord b) => Trie a b -> Trie a b
implicitDAWG t = deserialize (serialize t)

-- | Serialize the trie into a list of nodes, eliminating all common
-- subtries along the way: every distinct subtrie found by 'collect'
-- yields exactly one node.  Nodes are listed in ascending order of the
-- identifiers assigned by 'collect', and edges refer to subtries by
-- those identifiers.
--
-- TODO: perhaps the function name should be different?
serialize :: (Ord a, Ord b) => Trie a b -> [Node a b]
serialize root = map toNode triesById
  where
    -- Identifier of every distinct subtrie.
    ids = collect root
    -- The same subtries, ordered by their identifiers.
    triesById = M.elems $ M.fromList [ (i, t) | (t, i) <- M.toList ids ]
    toNode t = mkNode (valueIn t) (map edge (anyChild t))
    -- Replace the target subtrie of an edge with its identifier.
    edge (c, sub) = (c, ids M.! sub)

-- | Rebuild a trie from a list of nodes.  Nodes are numbered by their
-- position in the list, starting from 0, and every edge of a node must
-- refer to the number of a node occurring earlier in the list (the
-- lookup with 'M.!' fails otherwise).  The last node of the list becomes
-- the root of the resulting trie.
--
-- An empty node list contains no root; since the result type offers no
-- way to signal this, the function calls 'error' with a message naming
-- this function instead of failing inside 'M.findMax'.
deserialize :: Ord a => [Node a b] -> Trie a b
deserialize [] =
    error "NLP.Adict.Trie.deserialize: empty node list"
deserialize nodes =
    snd . M.findMax $ foldl' update M.empty nodes
  where
    -- Build the trie for the next node from the already built tries of
    -- its children and register it under the next free number.
    update m n =
        let t = mkTrie (nodeValue n) [(c, m M.! k) | (c, k) <- nodeEdges n]
        in  M.insert (M.size m) t m

-- | Collect the distinct subtries of a trie (the trie itself included)
-- and assign an integer identifier to each of them.
collect :: (Ord a, Ord b) => Trie a b -> M.Map (Trie a b) Int
collect = collect' M.empty

-- | Worker of 'collect': extend the given map with the distinct subtries
-- of the given trie.  Children are registered before their parent, and
-- a trie which is already present keeps its identifier.
collect' :: (Ord a, Ord b) => M.Map (Trie a b) Int
         -> Trie a b -> M.Map (Trie a b) Int
collect' seen t = M.alter keepOrAssign t visited
  where
    -- Map after all subtries of @t@ have been registered.
    !visited = foldl' collect' seen (M.elems (edgeMap t))
    -- Identifier for @t@ itself, used only when @t@ is new.
    !nextId  = M.size visited
    keepOrAssign = Just . maybe nextId id