{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies      #-}
{-# LANGUAGE BangPatterns      #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Network.MQTT.Trie
-- Copyright   :  (c) Lars Petersen 2016
-- License     :  MIT
--
-- Maintainer  :  info@lars-petersen.net
-- Stability   :  experimental
--------------------------------------------------------------------------------
module Network.MQTT.Trie (
  -- * Trie
    Trie (..)
  , TrieValue (..)
  -- ** null
  , null
  -- ** empty
  , empty
  -- ** size
  , size
  -- ** sizeWith
  , sizeWith
  -- ** singleton
  , singleton
  -- ** matchTopic
  , matchTopic
  -- ** matchFilter
  , matchFilter
  -- ** lookup
  , lookup
  -- ** findMaxBounded
  , findMaxBounded
  -- ** insert
  , insert
  -- ** insertWith
  , insertWith
  -- ** insertFoldable
  , insertFoldable
  -- ** map
  , map
  -- ** mapMaybe
  , mapMaybe
  -- ** foldl'
  , foldl'
  -- ** delete
  , delete
  -- ** union
  , union
  -- ** unionWith
  , unionWith
  -- ** differenceWith
  , differenceWith
  ) where

import           Control.Applicative        ((<|>))
import qualified Data.Binary                as B
import           Data.Functor.Identity
import qualified Data.IntSet                as IS
import qualified Data.List                  as L
import           Data.List.NonEmpty         (NonEmpty (..))
import qualified Data.Map.Strict            as M
import           Data.Maybe                 hiding (mapMaybe)
import           Data.Monoid
import           Prelude                    hiding (lookup, map, null)

import           Network.MQTT.Message.Topic

-- | The `Trie` is a map-like data structure designed to hold elements
--   that can efficiently be queried according to the matching rules specified
--   by MQTT.
--   The primary purpose is to manage client subscriptions, but it can just
--   as well be used to manage permissions etc.
--
--   The tree consists of nodes that may or may not contain values. The edges
--   are filter components. As some value types have the concept of a null
--   (i.e. an empty set) the `TrieValue` is a class defining the data
--   family `TrieNode`. This is a performance and size optimization to
--   avoid unnecessary boxing and case distinction.
--
--   `branches` maps every outgoing edge (a `Level`) to the child node that is
--   reached through it; a trie without branches is the empty trie.
newtype Trie a = Trie { branches :: M.Map Level (TrieNode a) }

-- | Class of types that can be stored as values in a `Trie`.
--
--   Each instance chooses its own node representation through the associated
--   data family `TrieNode`.
class TrieValue a where
  -- | The node representation used for values of type @a@.
  data TrieNode a
  -- | Build a node from its subtree and an optional value.
  node                 :: Trie a -> Maybe a -> TrieNode a
  -- | Tell whether a value counts as null. `singleton` and `insertWith`
  --   do not insert values for which this returns `True`.
  nodeNull             :: a -> Bool
  -- | The subtree below a node.
  nodeTree             :: TrieNode a -> Trie a
  -- | The value stored at a node, if any.
  nodeValue            :: TrieNode a -> Maybe a

-- | Tries form a monoid: the neutral element is the trie without branches
--   and two tries are combined by merging them with `unionWith`, using the
--   value type's `mappend` where both sides carry a value.
instance (TrieValue a, Monoid a) => Monoid (Trie a) where
  mempty            = Trie M.empty
  mappend left right = unionWith mappend left right

-- | Two tries are equal when they have the same number of branches on every
--   level and their branches agree pairwise (in ascending key order) on
--   level, node value and subtree.
instance (TrieValue a, Eq a) => Eq (Trie a) where
  Trie m1 == Trie m2 =
    M.size m1 == M.size m2 && all sameEntry (zip (M.toAscList m1) (M.toAscList m2))
    where
      sameEntry ((l1, n1), (l2, n2)) =
        l1 == l2 && nodeValue n1 == nodeValue n2 && nodeTree n1 == nodeTree n2

-- | Renders a trie as @Trie [(level, Node (value) (subtrie)), ...]@ with the
--   branches listed in ascending key order.
instance (TrieValue a, Show a) => Show (Trie a) where
  show (Trie m) = "Trie [" ++ L.intercalate ", " (showEntry <$> M.toAscList m) ++ "]"
    where
      -- The trailing "))" closes both the subtrie group and the entry tuple;
      -- the tuple's closing parenthesis used to be missing.
      showEntry (l, n) =
        "(" ++ show l ++ ", Node (" ++ show (nodeValue n) ++ ") (" ++ show (nodeTree n) ++ "))"

-- | Placeholder serialisation: `put` writes nothing and `get` consumes
--   nothing and always yields the empty trie, so the content of a trie does
--   not survive a round trip.
instance B.Binary (Trie ()) where
  put = const (return ())
  get = return empty

-- | The trie without any branches.
empty :: Trie a
empty = Trie M.empty

-- | Tell whether a trie has no branches at all.
null :: Trie a -> Bool
null = M.null . branches

-- | Count the nodes of a trie that carry a value (each such node counts as 1).
size :: TrieValue a => Trie a -> Int
size trie = sizeWith (\_ -> 1) trie

-- | Sum a per-value weight over all nodes of a trie that carry a value.
--
--   The trie is walked depth-first with a strict accumulator so that the
--   additions are performed immediately instead of piling up as thunks.
sizeWith :: TrieValue a => (a -> Int) -> Trie a -> Int
sizeWith weigh = go 0
  where
    -- Fold over all branches of one level, threading the running total.
    go !total (Trie m) = M.foldl' visit total m
    -- Add the weight of the node's own value (if any), then descend.
    visit !total n = go (total + maybe 0 weigh (nodeValue n)) (nodeTree n)

-- | Build a trie that holds exactly one value at the path described by the
--   filter. A value that is `nodeNull` yields the empty trie.
singleton :: TrieValue a => Filter -> a -> Trie a
singleton tf a = case filterLevels tf of
  lvl :| rest
    | nodeNull a -> empty
    | otherwise  -> chain lvl rest
  where
    -- One branch per filter level; only the last node carries the value.
    chain l ls = Trie (M.singleton l child)
      where
        child = case ls of
          []         -> node empty (Just a)
          (l' : ls') -> node (chain l' ls') Nothing

-- | Insert a value at the path described by the filter, replacing any value
--   that is already stored there.
insert :: TrieValue a => Filter -> a -> Trie a -> Trie a
insert tf a trie = insertWith (\new _old -> new) tf a trie

-- | Insert a value at the path described by the filter.
--
--   If the target node already holds a value, the stored value becomes
--   @f new old@; otherwise the new value is stored as is. Values that are
--   `nodeNull` are not inserted and leave the trie unchanged.
--
--   Values stored at nodes along the path are preserved.
insertWith :: TrieValue a => (a -> a -> a) -> Filter -> a -> Trie a -> Trie a
insertWith f tf a = insertWith' (filterLevels tf)
  where
    insertWith' (x:|xs) (Trie m)
      | nodeNull a  = Trie m
      | otherwise   = Trie $ M.alter g x m
      where
        g mn = Just $ case xs of
          -- Last level: create the node or combine with its current value.
          []     -> case mn of
            Nothing -> node empty (Just a)
            Just n  -> node (nodeTree n) $ (f a <$> nodeValue n) <|> Just a
          -- Intermediate level: descend into the (possibly new) subtree and
          -- keep whatever value the existing node already carries. Passing
          -- `Nothing` here would silently drop that value.
          (y:ys) -> node (insertWith' (y:|ys) (maybe empty nodeTree mn)) (mn >>= nodeValue)

-- | Insert all (filter, value) pairs of a foldable container into a trie.
--   The pairs are folded from the right, each one being added with `insert`.
insertFoldable :: (TrieValue a, Foldable t) => t (Filter, a) -> Trie a -> Trie a
insertFoldable entries trie = foldr (uncurry insert) trie entries

-- | Remove the value stored at the path described by the filter.
--
--   Nodes that end up with neither a value nor a subtree are removed from
--   the trie; nodes that still have a subtree or a value are kept.
delete :: TrieValue a => Filter -> Trie a -> Trie a
delete tf = go (filterLevels tf)
  where
    go (l :| ls) (Trie m) = Trie (M.update prune l m)
      where
        prune n = case ls of
          -- Target node: drop its value, keep the subtree.
          []         -> keep (nodeTree n) Nothing
          -- Node on the way: recurse below and keep its own value.
          (l' : ls') -> keep (go (l' :| ls') (nodeTree n)) (nodeValue n)
    -- Rebuild a node unless it would be completely empty.
    keep sub Nothing
      | null sub  = Nothing
      | otherwise = Just (node sub Nothing)
    keep sub mv@(Just _) = Just (node sub mv)

-- | Apply a function to every value in the trie. The shape of the trie
--   (its branches on every level) is left as it is.
map :: (TrieValue a, TrieValue b) => (a -> b) -> Trie a -> Trie b
map f (Trie m) = Trie (convert <$> m)
  where
    convert n = node (map f (nodeTree n)) (fmap f (nodeValue n))

-- | Applies a function to every value of a trie and removes the values for
--   which the function returns `Nothing`.
--
--   A mapped value that is `nodeNull` is treated like `Nothing`. Nodes that
--   are left with neither a value nor a subtree are removed, so the result
--   never contains empty leaf nodes.
mapMaybe :: (TrieValue a, TrieValue b) => (a -> Maybe b) -> Trie a -> Trie b
mapMaybe f (Trie m) = Trie (M.mapMaybe g m)
  where
    g n | isNothing v' && null t' = Nothing
        | otherwise               = Just (node t' v')
      where
        -- Discard null results as well: a node type may store them as
        -- "no value", which would otherwise leave an empty node behind.
        v' = case nodeValue n >>= f of
          Just v | not (nodeNull v) -> Just v
          _                         -> Nothing
        t' = mapMaybe f $ nodeTree n

-- | Strict left fold over all values of a trie. For every branch the node's
--   own value (if any) is folded in first, then the values of its subtree.
foldl'  :: (TrieValue b) => (a -> b -> a) -> a -> Trie b -> a
foldl' f z = M.foldl' visit z . branches
  where
    visit acc n =
      let acc' = maybe acc (f acc) (nodeValue n)
      -- Force the accumulator before descending into the subtree.
      in  acc' `seq` foldl' f acc' (nodeTree n)

-- | Merge two tries. Where both tries have a node at the same path, the
--   subtrees are merged recursively and the values are combined with `<>`.
union     :: (TrieValue a, Monoid a) => Trie a -> Trie a -> Trie a
union left right = Trie (M.unionWith merge (branches left) (branches right))
  where
    merge l r = node subtree value
      where
        subtree = union (nodeTree l) (nodeTree r)
        value   = nodeValue l <> nodeValue r

-- | Merge two tries with a custom combining function. The function is only
--   applied where both tries carry a value at the same path; a value that is
--   present on one side only is taken over unchanged.
unionWith :: (TrieValue a) => (a -> a -> a) -> Trie a -> Trie a -> Trie a
unionWith f (Trie m1) (Trie m2) = Trie (M.unionWith combine m1 m2)
  where
    combine l r = node subtree value
      where
        subtree = unionWith f (nodeTree l) (nodeTree r)
        value   = case (nodeValue l, nodeValue r) of
          (Just a, Just b)  -> Just (f a b)
          (Just a, Nothing) -> Just a
          (Nothing, mb)     -> mb

-- | Subtract the second trie from the first.
--
--   Where both tries carry a value at the same path, the function decides
--   what remains (`Nothing` removes the value). Values that have no
--   counterpart in the second trie are kept. Nodes that both tries share and
--   that end up without a subtree and without a (non-null) value are removed.
differenceWith :: (TrieValue a, TrieValue b) => (a -> b -> Maybe a) -> Trie a -> Trie b -> Trie a
differenceWith f (Trie m1) (Trie m2) = Trie (M.differenceWith reduce m1 m2)
  where
    reduce l r = rebuild subtree remaining
      where
        subtree   = differenceWith f (nodeTree l) (nodeTree r)
        remaining = case nodeValue l of
          Nothing -> Nothing
          Just a  -> maybe (Just a) (f a) (nodeValue r)
    -- Drop the node entirely if nothing worth keeping is left.
    rebuild sub mv = case mv of
      Nothing | null sub               -> Nothing
      Just v  | null sub && nodeNull v -> Nothing
      _                                -> Just (node sub mv)

-- | Collect all values of nodes that match a given topic (according to the
--   matching rules specified by the MQTT protocol).
--
--   The values of all matching nodes are combined with `<>`; if no node
--   matches, the result is `mempty`.
--
--   A multi-level wildcard also matches its parent level, so an entry
--   @x\/#@ (or @+\/#@) contributes when the topic ends at @x@.
lookup :: (TrieValue a, Monoid a) => Topic -> Trie a -> a
lookup tf = fromMaybe mempty . lookupHead (topicLevels tf)
  where
    -- If the first level starts with $ then it must not be matched against + and #.
    lookupHead (x:|xs) t@(Trie m)
      | startsWithDollar x = case xs of
          []     -> M.lookup x m >>= valueWithHashBelow
          (y:ys) -> M.lookup x m >>= lookupTail y ys . nodeTree
      | otherwise = lookupTail x xs t
    -- Last topic level: literal and '+' nodes both terminate here, and a '#'
    -- directly below either of them matches this level as well.
    lookupTail x [] (Trie m) =
      matchSingleLevelWildcard <> matchMultiLevelWildcard <> matchComponent
      where
        matchSingleLevelWildcard = M.lookup singleLevelWildcard m >>= valueWithHashBelow
        matchMultiLevelWildcard  = M.lookup multiLevelWildcard  m >>= nodeValue
        matchComponent           = M.lookup x                   m >>= valueWithHashBelow
    -- Inner topic level: '#' terminates here, '+' and the literal level
    -- continue with the remaining levels.
    lookupTail x (y:ys) (Trie m) =
      matchSingleLevelWildcard <> matchMultiLevelWildcard <> matchComponent
      where
        matchSingleLevelWildcard = M.lookup singleLevelWildcard m >>= lookupTail y ys . nodeTree
        matchMultiLevelWildcard  = M.lookup multiLevelWildcard  m >>= nodeValue
        matchComponent           = M.lookup                   x m >>= lookupTail y ys . nodeTree
    -- The node's own value combined with the value of a '#' child, if any.
    valueWithHashBelow n =
      nodeValue n <> (M.lookup multiLevelWildcard (branches (nodeTree n)) >>= nodeValue)

-- | Find the greatest value in a trie that matches the topic.
--
--   * Stops search as soon as a `maxBound` element has been found.
--   * Doesn't match into `$` topics.
--   * A multi-level wildcard also matches its parent level, so an entry
--     @x\/#@ is considered when the topic ends at @x@ (this includes a
--     single-level topic starting with `$`).
findMaxBounded :: (TrieValue a, Ord a, Bounded a) => Topic -> Trie a -> Maybe a
findMaxBounded topic = findHead (topicLevels topic)
  where
    -- A first level starting with '$' is only looked up literally.
    findHead (x:|xs) t@(Trie m)
      | startsWithDollar x = case xs of
          []     -> M.lookup x m >>= valueOrHashBelow
          (y:ys) -> M.lookup x m >>= findTail y ys . nodeTree
      | otherwise = findTail x xs t
    -- Last topic level.
    findTail x [] (Trie m) =
      matchMultiLevelWildcard `maxBounded` matchSingleLevelWildcard `maxBounded` matchComponent
      where
        matchMultiLevelWildcard  = M.lookup multiLevelWildcard  m >>= nodeValue
        matchSingleLevelWildcard = M.lookup singleLevelWildcard m >>= valueOrHashBelow
        matchComponent           = M.lookup x                   m >>= valueOrHashBelow
    -- Inner topic level.
    findTail x (y:ys) (Trie m) =
      matchMultiLevelWildcard `maxBounded` matchSingleLevelWildcard `maxBounded` matchComponent
      where
        matchMultiLevelWildcard  = M.lookup multiLevelWildcard  m >>= nodeValue
        matchSingleLevelWildcard = M.lookup singleLevelWildcard m >>= findTail y ys . nodeTree
        matchComponent           = M.lookup                   x m >>= findTail y ys . nodeTree

    -- The greater of the node's own value and the value of a '#' child.
    valueOrHashBelow n =
      nodeValue n `maxBounded` (nodeValue =<< M.lookup multiLevelWildcard (branches $ nodeTree n))

    -- Like `max`, but does not look at the second argument when the first
    -- one already is the greatest possible value.
    maxBounded :: (Ord a, Bounded a) => Maybe a -> Maybe a -> Maybe a
    maxBounded a b
      | a == Just maxBound = a
      | otherwise          = max a b

-- | Match a `Topic` against a `Trie`.
--
--   The function returns true iff the tree contains at least one node that
--   matches the topic /and/ contains a value (including nodes that are
--   indirectly matched by wildcard characters like `+` and `#` as described
--   in the MQTT specification).
--
--   A first level that starts with `$` is only matched literally, never
--   against a `+` or `#` on the first level.
matchTopic :: TrieValue a => Topic -> Trie a -> Bool
matchTopic tf = matchTopicHead (topicLevels tf)
  where
    -- The '#' is always a terminal node and therefore does not contain subtrees.
    -- By invariant, a '#' node only exists if it contains a value. For this
    -- reason it does not need to be checked for a value here, but just for
    -- existence.
    -- A '+' node on the other hand may contain subtrees and may not carry a value
    -- itself. This needs to be checked.
    matchTopicHead (x:|xs) t@(Trie m)
      | startsWithDollar x = case xs of
          []     -> matchExact x m
          (y:ys) -> maybe False (matchTopicTail y ys . nodeTree) (M.lookup x m)
      | otherwise = matchTopicTail x xs t
    -- Last topic level: the literal node, the '+' node or a '#' node match.
    -- Both the literal and the '+' node also match through a '#' child
    -- (i.e. "x/#" and "+/#" match the topic "x").
    matchTopicTail x [] (Trie m) =
      matchExact x m || matchPlus || matchHash
      where
        matchPlus = matchExact singleLevelWildcard m
        matchHash = M.member multiLevelWildcard m
    matchTopicTail x (y:ys) (Trie m) =
      M.member multiLevelWildcard m || case M.lookup x m of
        -- Same is true for '#' node here. In case no '#' hash node is present it is
        -- first tried to match the exact topic and then to match any '+' node.
        Nothing -> matchPlus
        Just n  -> matchTopicTail y ys (nodeTree n) || matchPlus
      where
        -- A '+' node matches any topic element.
        matchPlus = maybe False (matchTopicTail y ys . nodeTree) (M.lookup singleLevelWildcard m)
    -- A terminal match is the case if the map contains a node for the key and
    -- the node carries a value _or_ the node's subtree contains a '#' key
    -- (a multi-level wildcard always also matches its parent level).
    matchExact x m = case M.lookup x m of
      Nothing -> False
      Just n  -> isJust (nodeValue n) || M.member multiLevelWildcard (branches (nodeTree n))

-- | Match a `Filter` against a `Trie`.
--
--   The function returns true iff the tree contains a path that is
--   /less or equally specific/ than the filter and the terminal node contains
--   a value that is not `nodeNull`.
--
--   With @v@ being any value that is not `nodeNull`:
--
--   > matchFilter "a"     (singleton "#" v)   = True
--   > matchFilter "+"     (singleton "#" v)   = True
--   > matchFilter "a/+/b" (singleton "#" v)   = True
--   > matchFilter "a/+/#" (singleton "#" v)   = True
--   > matchFilter "a"     (singleton "+" v)   = True
--   > matchFilter "+"     (singleton "+" v)   = True
--   > matchFilter "+/a"   (singleton "+" v)   = False
--   > matchFilter "#"     (singleton "+" v)   = False
--   > matchFilter "a"     (singleton "a" v)   = True
--   > matchFilter "b"     (singleton "a" v)   = False
--   > matchFilter "+"     (singleton "a" v)   = False
--   > matchFilter "#"     (singleton "a" v)   = False
--   > matchFilter "a"     (singleton "+/#" v) = True
matchFilter :: TrieValue a => Filter -> Trie a -> Bool
matchFilter tf = matchFilter' (filterLevels tf)
  where
    -- Last filter level.
    matchFilter' (x:|[]) (Trie m)
      | x == multiLevelWildcard  = matchMultiLevelWildcard
      | x == singleLevelWildcard = matchMultiLevelWildcard || matchSingleLevelWildcard
      | otherwise                = matchMultiLevelWildcard || matchSingleLevelWildcard || matchExact
      where
        matchMultiLevelWildcard  = M.member multiLevelWildcard m
        matchSingleLevelWildcard = matchTerminal singleLevelWildcard
        matchExact               = matchTerminal x
        -- A node terminates a path here if it carries a value or has a '#'
        -- child: a multi-level wildcard also covers its parent level, so
        -- "l/#" is less specific than "l" (for a literal l as well as '+').
        matchTerminal l = case M.lookup l m of
          Nothing -> False
          Just n' -> isJust (nodeValue n') || M.member multiLevelWildcard (branches (nodeTree n'))
    -- Inner filter level: a '#' in the trie covers everything below, '+' and
    -- the literal level continue with the remaining filter levels.
    matchFilter' (x:|y:zs) (Trie m)
      | x == multiLevelWildcard  = matchMultiLevelWildcard
      | x == singleLevelWildcard = matchMultiLevelWildcard || matchSingleLevelWildcard
      | otherwise                = matchMultiLevelWildcard || matchSingleLevelWildcard || matchExact
      where
        matchMultiLevelWildcard  = M.member multiLevelWildcard  m
        matchSingleLevelWildcard = maybe False (matchFilter' (y:|zs) . nodeTree) (M.lookup singleLevelWildcard m)
        matchExact               = maybe False (matchFilter' (y:|zs) . nodeTree) (M.lookup x m)

--------------------------------------------------------------------------------
-- Specialised TrieNode implementations using data families
--------------------------------------------------------------------------------

-- | Integer sets are stored directly in the node: the empty set doubles as
--   "no value", so no `Maybe` wrapper is kept per node.
instance TrieValue IS.IntSet where
  data TrieNode IS.IntSet = IntSetTrieNode !(Trie IS.IntSet) !IS.IntSet
  node t mv = IntSetTrieNode t (fromMaybe IS.empty mv)
  nodeNull = IS.null
  nodeTree (IntSetTrieNode t _) = t
  -- An empty set is reported as the absence of a value.
  nodeValue (IntSetTrieNode _ s) = if IS.null s then Nothing else Just s

-- | Generic instance for arbitrary values wrapped in `Identity`. The optional
--   value is stored as is and no value is ever considered null.
instance TrieValue (Identity a) where
  data TrieNode (Identity a)    = IdentityNode !(Trie (Identity a)) !(Maybe (Identity a))
  -- Evaluate the optional value before building the node.
  node t mv                     = mv `seq` IdentityNode t mv
  nodeNull _                    = False
  nodeTree  (IdentityNode t _)  = t
  nodeValue (IdentityNode _ mv) = mv

-- | Unit values only record presence: the node stores an unpacked flag
--   (0 = no value, anything else = value present) next to the subtree.
instance TrieValue () where
  data TrieNode ()  = UnitNode {-# UNPACK #-} !Int !(Trie ())
  node t mv = UnitNode (maybe 0 (const 1) mv) t
  nodeNull _ = False
  nodeTree (UnitNode _ t) = t
  nodeValue (UnitNode flag _)
    | flag == 0 = Nothing
    | otherwise = Just ()

-- | Booleans are encoded in an unpacked tag next to the subtree:
--   1 = @Just False@, 2 = @Just True@, any other tag = no value.
instance TrieValue Bool where
  data TrieNode Bool = BoolNode {-# UNPACK #-} !Int !(Trie Bool)
  node t mv = BoolNode tag t
    where
      tag = case mv of
        Nothing    -> 0
        Just False -> 1
        Just True  -> 2
  nodeNull _ = False
  nodeTree (BoolNode _ t) = t
  nodeValue (BoolNode tag _) = case tag of
    1 -> Just False
    2 -> Just True
    _ -> Nothing