{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE InstanceSigs          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLists       #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE UndecidableInstances  #-}
module Data.RadixTree
  ( RadixTree (..)
  , RadixNode (..)
  , CompressedRadixTree
    -- * Construction
  , fromFoldable
  , compressBy
    -- * Parsing with radix trees
  , RadixParsing (..)
  , search
  ) where
import           Control.Applicative
import           Control.DeepSeq
import           Data.Data               (Data, Typeable)
import           Data.Foldable           (asum, foldr', toList)
import           Data.Map.Strict         (Map)
import qualified Data.Map.Strict         as M
import           Data.Monoid
import           Data.Sequence           (Seq)
import qualified Data.Sequence           as Seq
import           Data.Store              ()
import           Data.Store.TH           (makeStore)
import           Data.Text               (Text)
import qualified Data.Text               as T
import qualified Data.Text.Array         as TI (Array)
import qualified Data.Text.Internal      as TI (Text (..), text)
import           Data.Vector             (Vector)
import qualified Data.Vector             as V
import           Lens.Micro
import           Text.Parser.Char        (CharParsing (anyChar, text))
import           Text.Parser.Combinators (Parsing (try))

--------------------------------------------------------------------------------
-- Stuff to help construct RadixTrees
--
-- I'm not clever enough to write a function to go directly from a 'Foldable' to
-- a fully-optimised RadixTree. Instead, I generate a prefix-tree using a 'Map'
-- directly ('Trie'), and then gradually compress that ('CompressedTrie') before
-- packing the final result into an efficient structure using 'Text' nodes.
--
-- TODO:
-- - generate RadixTree directly, instead of going through 'Trie'/'CompressedTrie'
-- - use compact regions?

-- | A node of the intermediate prefix trees ('Trie' / 'CompressedTrie'),
-- carrying a payload of type @a@ (in this module, the node's child map).
--
-- * 'Accept' marks a node at which an inserted 'Text' ends. The stored 'Text'
--   is the label that becomes a 'RadixAccept' value in 'fromTrie'.
-- * 'Skip' marks a node that is only passed through.
data PrefixNode a = Accept !Text !a | Skip !a
  deriving (Show, Eq)

-- | An uncompressed prefix tree in which every edge is labelled by a single
-- 'Char'. It is built with 'insert' / 'leaf' and then re-keyed by
-- 'makeCompressable' before compression.
newtype Trie = Trie (PrefixNode (Map Char Trie))
  deriving (Show, Eq)

-- | A prefix tree whose edges are labelled by character sequences. This lets
-- 'compress' merge chains of single-child 'Skip' nodes into one edge.
newtype CompressedTrie = CompressedTrie (PrefixNode (Map (Seq Char) CompressedTrie))
  deriving (Show, Eq)

{-# INLINE node #-}
-- | Focus on the payload of a 'PrefixNode', whether or not it is accepting.
-- Writing through the lens keeps the constructor and, for 'Accept', the
-- label.
node :: Lens (PrefixNode a) (PrefixNode b) a b
node = lens getter setter
  where
    getter (Accept _ payload) = payload
    getter (Skip payload)     = payload

    setter (Accept label _) new = Accept label new
    setter (Skip _)         new = Skip new

-- | Build a single-path 'Trie' that spells out the characters of the second
-- argument. Every intermediate node is a 'Skip' with exactly one child. The
-- final node is an 'Accept' labelled with the first argument and has no
-- children.
leaf :: Text -> Text -> Trie
leaf ft t = T.foldr step (Trie (Accept ft M.empty)) t
  where
    step c rest = Trie (Skip (M.singleton c rest))

-- | Add the characters of the second argument as a path in the 'Trie'. The
-- node at the end of the path becomes an 'Accept' labelled with the first
-- argument, replacing any previous label. Existing children along the path
-- are reused; missing ones are created with 'leaf'.
insert :: Text -> Text -> Trie -> Trie
insert ft text' (Trie n) = Trie $ case T.uncons text' of
  Nothing      -> Accept ft (n ^. node)
  Just (c, cs) ->
    let descend _ existing = Data.RadixTree.insert ft cs existing
    in  over node (M.insertWith descend c (leaf ft cs)) n

-- | Recursively re-key a 'Trie' so that every edge is labelled by a
-- one-element 'Seq' instead of a bare 'Char'. The shape of the tree and all
-- 'Accept' labels are unchanged. This prepares the tree for 'compress', which
-- concatenates edge labels.
makeCompressable :: Trie -> CompressedTrie
makeCompressable (Trie n) = CompressedTrie (n & node %~ rekey)
  where
    rekey children =
      M.map makeCompressable (M.mapKeysMonotonic Seq.singleton children)

-- | Path-compress a 'Trie'.
--
-- Every edge is first re-keyed by 'makeCompressable'. Then, below every node
-- (accepting or not), each chain of 'Skip' nodes with exactly one child is
-- collapsed into a single edge. That edge's key is the concatenation of the
-- chain's keys.
--
-- 'Accept' nodes are never merged away, because they carry a label. Their
-- outgoing edges are compressed exactly like those of 'Skip' nodes.
compress :: Trie -> CompressedTrie
compress = go . makeCompressable
  where
    -- Rebuild a node, compressing each of its outgoing edges. Going through
    -- 'node' keeps the constructor (and the 'Accept' label) intact.
    go :: CompressedTrie -> CompressedTrie
    go (CompressedTrie n) =
      CompressedTrie (n & node %~ M.foldMapWithKey compress1)

    -- Follow the edge labelled @k@ into @c@. While @c@ is a 'Skip' node with a
    -- single child, absorb that child's key into the edge label. Otherwise
    -- stop and compress below @c@. Always yields a one-element map. The
    -- results for sibling edges are unioned by 'M.foldMapWithKey'; they cannot
    -- collide because sibling keys start with distinct characters.
    compress1 :: Seq Char -> CompressedTrie -> Map (Seq Char) CompressedTrie
    compress1 k c@(CompressedTrie n) = case n of
      Skip children
        | Just ((k', sub), rest) <- M.minViewWithKey children
        , M.null rest
        -> compress1 (k <> k') sub
      _ -> M.singleton k (go c)

--------------------------------------------------------------------------------

-- | A node (edge) in a radix tree. To advance past it, a parser must match the
-- 'Text' prefix stored here. The 'RadixTree' is the subtree reached after that
-- prefix. The parsers in this module skip the match when the prefix is empty.
data RadixNode = RadixNode {-# UNPACK #-} !Text !RadixTree
  deriving (Eq, Show, Typeable, Data)

-- | A radix tree. Construct it with 'fromFoldable' (optionally shrink it with
-- 'compressBy'), and use it with 'parse' or 'search'. The parsers in this
-- module try the subtrees in vector order.
data RadixTree
  = -- | Can terminate a parser successfully, returning the 'Text' value given.
    RadixAccept
    {-# UNPACK #-} !Text -- the final value to return
    {-# UNPACK #-} !(Vector RadixNode) -- possible subtrees beyond this point
    -- | Cannot terminate a parser; parsing must continue into a subtree.
  | RadixSkip
    {-# UNPACK #-} !(Vector RadixNode) -- possible subtrees beyond this point
  deriving (Eq, Show, Typeable, Data)

-- | Fully evaluates the prefix and the whole subtree below it.
instance NFData RadixNode where
  {-# INLINE rnf #-}
  rnf (RadixNode prefix subtree) = prefix `deepseq` rnf subtree

-- | Forces the accept label to WHNF and fully evaluates every subtree.
instance NFData RadixTree where
  {-# INLINE rnf #-}
  rnf tree = case tree of
    RadixAccept label subtrees -> seq label (rnf subtrees)
    RadixSkip subtrees         -> rnf subtrees

-- | Turn a totally unoptimised 'Trie' into an easily parsable 'RadixTree'.
-- The trie is path-compressed with 'compress'. Each node is then rebuilt with
-- its outgoing edges in a 'Vector', in ascending key order as given by
-- 'M.toList', and each edge label is packed into a 'Text'.
fromTrie :: Trie -> RadixTree
fromTrie = build . compress
  where
    build :: CompressedTrie -> RadixTree
    build (CompressedTrie (Accept l edges)) = RadixAccept l (children edges)
    build (CompressedTrie (Skip edges))     = RadixSkip (children edges)

    -- One 'RadixNode' per outgoing edge; the empty map becomes 'V.empty'.
    children :: Map (Seq Char) CompressedTrie -> Vector RadixNode
    children edges
      | M.null edges = V.empty
      | otherwise    = V.fromListN (M.size edges)
          [ RadixNode (T.pack (toList key)) (build sub)
          | (key, sub) <- M.toList edges
          ]

-- | A window into the corpus array of a 'CompressedRadixTree'. It is turned
-- back into a 'Text' with 'TI.text'.
--
-- NOTE(review): the field names say "Word16", which matches the UTF-16
-- representation used by @text < 2.0@. The values are copied from 'TI.Text'
-- offsets and lengths, so they are in whatever code units the installed
-- @text@ uses (bytes for @text >= 2.0@). Confirm against the text version in
-- use.
data TextSlice = TextSlice
  { tsOffset16 :: {-# UNPACK #-} !Int -- ^ offset into the corpus array (in its code units)
  , tsLength16 :: {-# UNPACK #-} !Int -- ^ length (in code units of the corpus array)
  }

-- | Probably dangerous magic
--
-- Look for the second argument inside the first. When it is found, return a
-- 'TextSlice' that points at that occurrence inside the /first/ argument's
-- array. The second argument's own array can then be garbage collected. This
-- improves locality and memory use. Returns 'Nothing' when the second argument
-- does not occur in the first.
--
-- The empty 'Text' is always "found" and maps to a zero-length slice. This
-- case is handled up front because 'T.breakOn' rejects an empty needle with an
-- error.
--
-- The returned offset is the remainder's own 'TI.Text' offset. This relies on
-- 'T.breakOn' returning its remainder as a view into @full@'s array.
magicallySaveSpaceSometimes :: Text -> Text -> Maybe TextSlice
magicallySaveSpaceSometimes full s@(TI.Text _ _ slen)
  | T.null s  = Just TextSlice{tsOffset16 = 0, tsLength16 = 0}
  | otherwise = case T.breakOn s full of
      (_, r@(TI.Text _ remoffs _))
        -- An empty remainder means the needle was not found.
        | T.null r  -> Nothing
        | otherwise -> Just TextSlice{tsOffset16 = remoffs, tsLength16 = slen}

-- | A normal 'RadixTree' stores a new 'Text' at every node. In contrast, a
-- 'CompressedRadixTree' holds the array of a single corpus 'Text', and its
-- nodes index into that array with 'TextSlice's. Construct one with
-- 'compressBy'. This can save a lot of memory (e.g., using the radix trees
-- from the parsing benchmarks in this package, the 'CompressedRadixTree'
-- version is 254032 bytes, whereas the ordinary 'RadixTree' is a rotund
-- 709904 bytes) at no runtime cost.
data CompressedRadixTree
  = CompressedRadixTree {-# UNPACK #-} !TI.Array !CompressedRadixTree1

-- | The tree stored inside a 'CompressedRadixTree'. It mirrors 'RadixTree',
-- but every 'Text' is replaced by a 'TextSlice' into the corpus array.
data CompressedRadixTree1
  = CompressedRadixAccept
    {-# UNPACK #-} !TextSlice -- the final value to return, as a slice
    {-# UNPACK #-} !(Vector CompressedRadixNode) -- possible subtrees beyond this point
  | CompressedRadixSkip {-# UNPACK #-} !(Vector CompressedRadixNode) -- possible subtrees beyond this point

-- | Mirrors 'RadixNode': an edge whose prefix is a 'TextSlice' into the corpus
-- array, leading to the given subtree.
data CompressedRadixNode
  = CompressedRadixNode {-# UNPACK #-} !TextSlice !CompressedRadixTree1

-- | 'TextSlice' has only strict unpacked 'Int' fields, so WHNF suffices for it.
-- The subtree is fully evaluated.
instance NFData CompressedRadixNode where
  {-# INLINE rnf #-}
  rnf (CompressedRadixNode slice subtree) = seq slice (rnf subtree)

-- | Only the corpus array's WHNF is forced; the tree below it is fully
-- evaluated.
instance NFData CompressedRadixTree where
  {-# INLINE rnf #-}
  rnf (CompressedRadixTree corpus root) = seq corpus (rnf root)

-- | Forces the accept slice to WHNF and fully evaluates every subtree.
instance NFData CompressedRadixTree1 where
  {-# INLINE rnf #-}
  rnf tree = case tree of
    CompressedRadixAccept slice subtrees -> seq slice (rnf subtrees)
    CompressedRadixSkip subtrees         -> rnf subtrees

-- | Compress a 'RadixTree' given a corpus. Every 'Text' stored in the tree
-- (edge prefixes and accept labels) must be findable within the corpus. The
-- corpus does not have to be the direct source of the tree. Returns 'Nothing'
-- if any of them cannot be found.
--
-- The result keeps only the corpus's underlying array. Every node refers to a
-- slice of that array instead of holding its own 'Text'.
compressBy :: Text -> RadixTree -> Maybe CompressedRadixTree
compressBy full@(TI.Text arr _ _) rt = fmap (CompressedRadixTree arr) (tree rt)
  where
    slice = magicallySaveSpaceSometimes full

    subtrees = V.mapM edge

    edge :: RadixNode -> Maybe CompressedRadixNode
    edge (RadixNode prefix sub) =
      liftA2 CompressedRadixNode (slice prefix) (tree sub)

    tree :: RadixTree -> Maybe CompressedRadixTree1
    tree (RadixAccept label v) =
      liftA2 CompressedRadixAccept (slice label) (subtrees v)
    tree (RadixSkip v) = fmap CompressedRadixSkip (subtrees v)

-- | Build a 'RadixTree' containing every 'Text' in the container. Each term is
-- its own accept label, so a successful 'parse' returns the matched term.
--
-- *Slow*: this first builds an uncompressed, 'Map'-based 'Trie' and then
-- compresses it.
fromFoldable :: Foldable f => f Text -> RadixTree
fromFoldable terms = fromTrie (foldr' addTerm emptyTrie terms)
  where
    emptyTrie = Trie (Skip M.empty)
    addTerm term = insert term term

-- Template Haskell: derive 'Data.Store' serialisation instances for the
-- uncompressed tree types. Only 'RadixNode' and 'RadixTree' are listed here.
makeStore ''RadixNode
makeStore ''RadixTree

--------------------------------------------------------------------------------
-- Parsers from 'RadixTree's

-- | Radix-tree structures that can be run as a parser.
class RadixParsing radixtree where
  -- | Build a parser that matches one of the tree's terms at the current
  -- position. The instances in this module return the accept label of the
  -- matched term.
  parse :: CharParsing m => radixtree -> m Text

{-# INLINE search #-}
-- | Find all occurrences of the terms in a radix tree from this point on. At
-- each position the tree's parser is tried first; if it fails, one character
-- is skipped. This will consume the entire remaining input. Results can be
-- produced lazily (but this depends on your parser).
search :: (Monad m, CharParsing m, RadixParsing radixtree)
       => radixtree -> m [Text]
search r = loop
  where
    loop = hit <|> skipOne <|> return []

    -- A match here: record it and keep scanning after it.
    hit = do
      x <- parse r
      (x :) <$> loop

    -- No match here: drop one character and retry.
    skipOne = anyChar >> loop

instance RadixParsing RadixTree where
  {-# INLINE parse #-}
  -- | Parse from a 'RadixTree'.
  --
  -- Subtrees are tried in vector order before accepting at the current node,
  -- so a match that continues through a child wins over stopping here. Each
  -- non-empty edge is wrapped in 'try', so a partial match backtracks. An
  -- accepting node with an empty label never succeeds on its own, since a
  -- zero-width match would make 'search' loop without consuming input. Its
  -- children are still tried; previously they were skipped, so a corpus
  -- containing "" matched nothing at all.
  parse :: CharParsing m => RadixTree -> m Text
  parse = go
    where
      go r = case r of
        RadixAccept l nodes
          | T.null l  -> children nodes
          | otherwise -> children nodes <|> pure l
        RadixSkip nodes -> children nodes

      -- Try each outgoing edge in turn.
      children nodes = asum (V.map parseRadixNode nodes)

      {-# INLINE parseRadixNode #-}
      parseRadixNode (RadixNode prefix tree)
        | T.null prefix = go tree
        | otherwise     = try (text prefix *> go tree)

instance RadixParsing CompressedRadixTree where
  {-# INLINE parse #-}
  -- | Parse from a 'CompressedRadixTree'.
  --
  -- Behaves like the 'RadixTree' parser, except that every prefix and label is
  -- read out of the shared corpus array. Subtrees are tried before accepting
  -- at the current node, and non-empty edges are wrapped in 'try'. An empty
  -- accept label never succeeds on its own (to avoid zero-width matches), but
  -- its children are still tried.
  parse :: CharParsing m => CompressedRadixTree -> m Text
  parse (CompressedRadixTree arr crt) = go crt
    where
      -- Turn a slice of the corpus array back into a 'Text'.
      fromSlice (TextSlice offs len) = TI.text arr offs len

      go r = case r of
        CompressedRadixAccept ts nodes -> case fromSlice ts of
          l | T.null l  -> children nodes
            | otherwise -> children nodes <|> pure l
        CompressedRadixSkip nodes -> children nodes

      -- Try each outgoing edge in turn.
      children nodes = asum (V.map parseRadixNode nodes)

      {-# INLINE parseRadixNode #-}
      parseRadixNode (CompressedRadixNode ts tree) = case fromSlice ts of
        prefix | T.null prefix -> go tree
               | otherwise     -> try (text prefix *> go tree)