module Network.MQTT.Trie (
Trie (..)
, TrieValue (..)
, null
, empty
, size
, sizeWith
, singleton
, matchTopic
, matchFilter
, lookup
, findMaxBounded
, insert
, insertWith
, insertFoldable
, map
, mapMaybe
, foldl'
, delete
, union
, unionWith
, differenceWith
) where
import Control.Applicative ((<|>))
import qualified Data.Binary as B
import Data.Functor.Identity
import qualified Data.IntSet as IS
import qualified Data.List as L
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.Map.Strict as M
import Data.Maybe hiding (mapMaybe)
import Data.Monoid
import qualified Data.Semigroup as Sg
import Prelude hiding (lookup, map, null)
import Network.MQTT.Message.Topic
-- | A trie keyed by filter levels. Each entry of 'branches' maps one 'Level'
-- to the node that holds that level's optional value and its subtree.
newtype Trie a = Trie
  { branches :: M.Map Level (TrieNode a)
  }
-- | Value types that can be stored in a 'Trie'. Each instance picks its own
-- node representation through the associated data family 'TrieNode'.
class TrieValue a where
  -- | Instance-specific node representation.
  data TrieNode a
  -- | The subtree stored beneath a node.
  nodeTree  :: TrieNode a -> Trie a
  -- | The value stored at a node, if any.
  nodeValue :: TrieNode a -> Maybe a
  -- | Build a node from a subtree and an optional value.
  node      :: Trie a -> Maybe a -> TrieNode a
  -- | Whether a value counts as "empty". 'singleton' and 'insertWith' refuse
  -- to store such values, and 'differenceWith' prunes them.
  nodeNull  :: a -> Bool
-- | Merges two tries branch by branch. Where both store a value at the same
-- filter, the two values are combined with the value type's '<>'.
--
-- Required as the superclass of 'Monoid' since base-4.11. The class is
-- referenced qualified so this also compiles where 'Semigroup' is not
-- exported from the Prelude.
instance (TrieValue a, Sg.Semigroup a) => Sg.Semigroup (Trie a) where
  (<>) = unionWith (Sg.<>)

-- | The empty trie is the identity of the merge above.
instance (TrieValue a, Monoid a) => Monoid (Trie a) where
  mempty  = empty
  -- Kept as 'mappend' rather than '(<>)'. On older base the unqualified '<>'
  -- in scope is Data.Monoid's alias for 'mappend', which would loop.
  mappend = unionWith mappend
-- | Structural equality: both tries have the same number of branches and,
-- pairwise in ascending level order, the same level, the same value and an
-- equal subtree.
instance (TrieValue a, Eq a) => Eq (Trie a) where
  Trie lhs == Trie rhs =
    M.size lhs == M.size rhs
      && all (uncurry sameBranch) (zip (M.toAscList lhs) (M.toAscList rhs))
    where
      sameBranch (lvlL, nL) (lvlR, nR) =
        lvlL == lvlR
          && nodeValue nL == nodeValue nR
          && nodeTree nL == nodeTree nR
-- | Debug rendering of the form
-- @Trie [(level, Node (value) (subtree)), ...]@, in ascending level order.
instance (TrieValue a, Show a) => Show (Trie a) where
  show (Trie m) = "Trie [" ++ L.intercalate ", " (showBranch <$> M.toAscList m) ++ "]"
    where
      -- The trailing "))" closes both the subtree and the surrounding tuple.
      -- Previously only one parenthesis was emitted, leaving the output
      -- unbalanced.
      showBranch (l, n) =
        "(" ++ show l ++ ", Node (" ++ show (nodeValue n) ++ ") (" ++ show (nodeTree n) ++ "))"
-- | Placeholder serialisation. Encoding writes nothing and decoding always
-- yields the empty trie.
-- NOTE(review): this round-trip is lossy; confirm callers rely on that.
instance B.Binary (Trie ()) where
  get = return empty
  put = const (return ())
-- | The trie with no branches.
empty :: Trie a
empty = Trie M.empty
-- | 'True' when the trie has no top-level branches. Only the top level is
-- inspected.
null :: Trie a -> Bool
null = M.null . branches
-- | Number of values stored in the trie.
size :: TrieValue a => Trie a -> Int
size = sizeWith (\_ -> 1)
-- | Sum of @weigh v@ over every stored value @v@. Nodes without a value
-- contribute nothing. The running total is kept strict.
sizeWith :: TrieValue a => (a -> Int) -> Trie a -> Int
sizeWith weigh = goTrie 0
  where
    goTrie !total (Trie m) = M.foldl' goNode total m
    goNode !total n = goTrie (total + maybe 0 weigh (nodeValue n)) (nodeTree n)
-- | A trie holding exactly one value at the given filter. The result is
-- 'empty' when 'nodeNull' holds for the value.
singleton :: TrieValue a => Filter -> a -> Trie a
singleton tf a
  | nodeNull a = empty
  | otherwise  = build (filterLevels tf)
  where
    -- One single-branch map per level; only the last level carries the value.
    build (lvl :| rest) = Trie $ M.singleton lvl $ case rest of
      []       -> node empty (Just a)
      (l : ls) -> node (build (l :| ls)) Nothing
-- | Store a value at the given filter, replacing any value already stored at
-- exactly that filter.
insert :: TrieValue a => Filter -> a -> Trie a -> Trie a
insert tf a = insertWith (\new _old -> new) tf a
-- | Store a value at the given filter, combining it with any existing value.
--
-- If a value @old@ is already stored at exactly this filter, the result holds
-- @f new old@. Values stored at prefixes of the filter (for example at @a@
-- when inserting at @a/b@) are preserved. Inserting a value for which
-- 'nodeNull' holds leaves the trie unchanged.
insertWith :: TrieValue a => (a -> a -> a) -> Filter -> a -> Trie a -> Trie a
insertWith f tf a = insertWith' (filterLevels tf)
  where
    insertWith' (x:|xs) (Trie m)
      | nodeNull a = Trie m
      | otherwise  = Trie $ M.alter g x m
      where
        g mn = Just $ case xs of
          [] -> case mn of
            Nothing -> node empty (Just a)
            Just n  -> node (nodeTree n) $ (f a <$> nodeValue n) <|> Just a
          -- Descending further: keep the value this intermediate node already
          -- carries. Previously it was replaced by 'Nothing' and lost.
          (y:ys) -> node (insertWith' (y:|ys) $ maybe empty nodeTree mn) (mn >>= nodeValue)
-- | Insert every @(filter, value)@ pair into the trie. The pairs are folded
-- with 'foldr', so for repeated filters the earliest pair in the container
-- is inserted last and wins.
insertFoldable :: (TrieValue a, Foldable t) => t (Filter, a) -> Trie a -> Trie a
insertFoldable entries trie = foldr (uncurry insert) trie entries
-- | Remove the value stored at exactly the given filter. Nodes on the path
-- that are left with neither a value nor children are pruned.
delete :: TrieValue a => Filter -> Trie a -> Trie a
delete tf = go (filterLevels tf)
  where
    go (lvl :| rest) (Trie m) = Trie (M.update prune lvl m)
      where
        prune n = case rest of
          -- Target level: drop its value, keep its children.
          []       -> keep (nodeTree n) Nothing
          -- Intermediate level: recurse and keep this node's own value.
          (l : ls) -> keep (go (l :| ls) (nodeTree n)) (nodeValue n)
    -- Rebuild a node, or remove it when it holds nothing at all.
    keep sub Nothing | null sub = Nothing
    keep sub val = Just (node sub val)
-- | Apply a function to every stored value, keeping the trie's shape.
map :: (TrieValue a, TrieValue b) => (a -> b) -> Trie a -> Trie b
map f (Trie m) = Trie (relabel <$> m)
  where
    relabel n = node (map f (nodeTree n)) (fmap f (nodeValue n))
-- | Apply a partial function to every stored value. Values mapped to
-- 'Nothing' are dropped, and branches left with neither a value nor children
-- are pruned.
--
-- A result value for which 'nodeNull' holds (an empty set, for example) is
-- treated like 'Nothing', as 'differenceWith' does, so no dead leaf is
-- retained.
mapMaybe :: (TrieValue a, TrieValue b) => (a -> Maybe b) -> Trie a -> Trie b
mapMaybe f (Trie m) = Trie (M.mapMaybe g m)
  where
    g n
      | isNothing v' && null t' = Nothing
      | otherwise               = Just (node t' v')
      where
        v' = nodeValue n >>= f >>= nonNull
        t' = mapMaybe f $ nodeTree n
    nonNull v
      | nodeNull v = Nothing
      | otherwise  = Just v
-- | Strict left fold over all stored values. Branches are visited in
-- ascending level order, and each node's own value comes before its subtree.
-- The accumulator is forced to WHNF after every value.
-- NOTE(review): newer base (4.20+) also exports foldl' from the Prelude; the
-- module's Prelude import may need to hide it there.
foldl' :: (TrieValue b) => (a -> b -> a) -> a -> Trie b -> a
foldl' f z (Trie m) = M.foldl' step z m
  where
    step acc n =
      let acc1 = maybe acc (f acc) (nodeValue n)
      in acc1 `seq` foldl' f acc1 (nodeTree n)
-- | Merge two tries. Where both have a node at the same level, the subtrees
-- are merged recursively and the optional values are combined with '<>'.
union :: (TrieValue a, Monoid a) => Trie a -> Trie a -> Trie a
union (Trie l) (Trie r) = Trie $ M.unionWith combine l r
  where
    combine a b =
      let sub = union (nodeTree a) (nodeTree b)
          val = nodeValue a <> nodeValue b
      in node sub val
-- | Merge two tries. Where both store a value at the same filter, the result
-- holds @f left right@; otherwise whichever value exists is kept.
unionWith :: (TrieValue a) => (a -> a -> a) -> Trie a -> Trie a -> Trie a
unionWith f (Trie l) (Trie r) = Trie (M.unionWith combine l r)
  where
    combine a b =
      node (unionWith f (nodeTree a) (nodeTree b)) (mergeValues (nodeValue a) (nodeValue b))
    mergeValues (Just x) (Just y) = Just (f x y)
    mergeValues (Just x) Nothing  = Just x
    mergeValues Nothing  other    = other
-- | Subtract the second trie from the first. Branches only in the first trie
-- are kept unchanged. Where both tries have a value at the same filter,
-- @f left right@ decides the surviving value ('Nothing' removes it). A first-
-- trie value with no counterpart is kept. A node is dropped once it has no
-- children and no value, or only a value for which 'nodeNull' holds.
differenceWith :: (TrieValue a, TrieValue b) => (a -> b -> Maybe a) -> Trie a -> Trie b -> Trie a
differenceWith f (Trie lhs) (Trie rhs) = Trie (M.differenceWith diffNode lhs rhs)
  where
    diffNode a b =
      rebuild (differenceWith f (nodeTree a) (nodeTree b)) (diffValue (nodeValue a) (nodeValue b))
    diffValue mx my = case (mx, my) of
      (Just x, Just y) -> f x y
      (Just x, _)      -> Just x
      _                -> Nothing
    rebuild sub mv
      | null sub && maybe True nodeNull mv = Nothing
      | otherwise                          = Just (node sub mv)
-- | Combine, with '<>', the values of every filter in the trie that matches
-- the topic. Returns 'mempty' when nothing matches.
--
-- A filter ending in @#@ also matches its parent level, so @a/#@ matches @a@
-- and @+/#@ matches @a@. When the first topic level satisfies
-- 'startsWithDollar', that level is matched only literally, never by a
-- wildcard.
lookup :: (TrieValue a, Monoid a) => Topic -> Trie a -> a
lookup tf = fromMaybe mempty . lookupHead (topicLevels tf)
  where
    lookupHead (x:|xs) t@(Trie m)
      | startsWithDollar x = case xs of
          []     -> M.lookup x m >>= terminalValue
          (y:ys) -> M.lookup x m >>= lookupTail y ys . nodeTree
      | otherwise = lookupTail x xs t
    -- Last topic level: a literal or '+' node matches via its own value or a
    -- trailing '#' child; a '#' at this level matches as well.
    lookupTail x [] (Trie m) =
      matchSingleLevelWildcard <> matchMultiLevelWildcard <> matchComponent
      where
        matchSingleLevelWildcard = M.lookup singleLevelWildcard m >>= terminalValue
        matchMultiLevelWildcard  = M.lookup multiLevelWildcard m >>= nodeValue
        matchComponent           = M.lookup x m >>= terminalValue
    -- Inner topic level: '+' and the literal level descend; '#' ends here.
    lookupTail x (y:ys) (Trie m) =
      matchSingleLevelWildcard <> matchMultiLevelWildcard <> matchComponent
      where
        matchSingleLevelWildcard = M.lookup singleLevelWildcard m >>= lookupTail y ys . nodeTree
        matchMultiLevelWildcard  = M.lookup multiLevelWildcard m >>= nodeValue
        matchComponent           = M.lookup x m >>= lookupTail y ys . nodeTree
    -- Value of a node reached by the last topic level: its own value combined
    -- with that of a '#' child, since '#' also matches the parent level.
    terminalValue n =
      nodeValue n <> (M.lookup multiLevelWildcard (branches (nodeTree n)) >>= nodeValue)
-- | The greatest value among all filters in the trie that match the topic,
-- or 'Nothing' if none match. Search short-circuits once 'maxBound' is seen.
--
-- A filter ending in @#@ also matches its parent level. This now applies to
-- a single-level topic whose level satisfies 'startsWithDollar' as well, so
-- @$X/#@ matches @$X@. That level is only matched literally.
findMaxBounded :: (TrieValue a, Ord a, Bounded a) => Topic -> Trie a -> Maybe a
findMaxBounded topic = findHead (topicLevels topic)
  where
    findHead (x:|xs) t@(Trie m)
      | startsWithDollar x = case xs of
          []     -> M.lookup x m >>= terminalValue
          (y:ys) -> M.lookup x m >>= findTail y ys . nodeTree
      | otherwise = findTail x xs t
    findTail x [] (Trie m) =
      matchMultiLevelWildcard `maxBounded` matchSingleLevelWildcard `maxBounded` matchComponent
      where
        matchMultiLevelWildcard  = M.lookup multiLevelWildcard m >>= nodeValue
        matchSingleLevelWildcard = M.lookup singleLevelWildcard m >>= terminalValue
        matchComponent           = M.lookup x m >>= terminalValue
    findTail x (y:ys) (Trie m) =
      matchMultiLevelWildcard `maxBounded` matchSingleLevelWildcard `maxBounded` matchComponent
      where
        matchMultiLevelWildcard  = M.lookup multiLevelWildcard m >>= nodeValue
        matchSingleLevelWildcard = M.lookup singleLevelWildcard m >>= findTail y ys . nodeTree
        matchComponent           = M.lookup x m >>= findTail y ys . nodeTree
    -- Node reached by the last topic level: its own value or that of a '#'
    -- child, since '#' also matches the parent level.
    terminalValue n =
      nodeValue n `maxBounded` (nodeValue =<< M.lookup multiLevelWildcard (branches $ nodeTree n))
    -- 'max' on Maybe, but skips the right operand once the left is maxBound.
    maxBounded :: (Ord a, Bounded a) => Maybe a -> Maybe a -> Maybe a
    maxBounded a b
      | a == Just maxBound = a
      | otherwise          = max a b
-- | Whether some filter in the trie matches the topic. A node counts when it
-- holds a value. A @#@ branch counts whether or not it holds a value, and it
-- also matches its parent level, so @a/#@ and @+/#@ both match @a@. When the
-- first topic level satisfies 'startsWithDollar', that level is matched only
-- literally.
matchTopic :: TrieValue a => Topic -> Trie a -> Bool
matchTopic tf = matchTopicHead (topicLevels tf)
  where
    matchTopicHead (x:|xs) t@(Trie m)
      | startsWithDollar x = case xs of
          []     -> matchExact x m
          (y:ys) -> maybe False (matchTopicTail y ys . nodeTree) (M.lookup x m)
      | otherwise = matchTopicTail x xs t
    matchTopicTail x [] (Trie m) =
      matchExact x m || matchPlus || matchHash
      where
        -- '+' on the last level, including a '+/#' filter.
        matchPlus = matchExact singleLevelWildcard m
        matchHash = M.member multiLevelWildcard m
    matchTopicTail x (y:ys) (Trie m) =
      M.member multiLevelWildcard m || case M.lookup x m of
        Nothing -> matchPlus
        Just n  -> matchTopicTail y ys (nodeTree n) || matchPlus
      where
        matchPlus = maybe False (matchTopicTail y ys . nodeTree) (M.lookup singleLevelWildcard m)
    -- The branch at key @x@ ends a match: it has a value or a '#' child.
    matchExact x m = case M.lookup x m of
      Nothing -> False
      Just n  -> isJust (nodeValue n) || M.member multiLevelWildcard (branches (nodeTree n))
-- | Whether some filter in the trie covers the given filter level by level.
-- A trie @#@ covers everything below it, and a trie @+@ covers a literal or
-- @+@ level. On the last level, a covering node must hold a value or have a
-- @#@ child, since @#@ also matches its parent level. This applies to the
-- @+@ branch too, so a trie @+/#@ covers @x@ and @+@.
matchFilter :: TrieValue a => Filter -> Trie a -> Bool
matchFilter tf = matchFilter' (filterLevels tf)
  where
    matchFilter' (x:|[]) (Trie m)
      | x == multiLevelWildcard  = matchMultiLevelWildcard
      | x == singleLevelWildcard = matchMultiLevelWildcard || matchSingleLevelWildcard
      | otherwise                = matchMultiLevelWildcard || matchSingleLevelWildcard || matchExact
      where
        matchMultiLevelWildcard  = M.member multiLevelWildcard m
        matchSingleLevelWildcard = maybe False terminates (M.lookup singleLevelWildcard m)
        matchExact               = maybe False terminates (M.lookup x m)
    matchFilter' (x:|y:zs) (Trie m)
      | x == multiLevelWildcard  = matchMultiLevelWildcard
      | x == singleLevelWildcard = matchMultiLevelWildcard || matchSingleLevelWildcard
      | otherwise                = matchMultiLevelWildcard || matchSingleLevelWildcard || matchExact
      where
        matchMultiLevelWildcard  = M.member multiLevelWildcard m
        matchSingleLevelWildcard = maybe False (matchFilter' (y:|zs) . nodeTree) (M.lookup singleLevelWildcard m)
        matchExact               = maybe False (matchFilter' (y:|zs) . nodeTree) (M.lookup x m)
    -- A node ends a covering filter: it has a value or a '#' child.
    terminates n =
      isJust (nodeValue n) || M.member multiLevelWildcard (branches (nodeTree n))
-- | IntSet values: an empty set counts as "no value". Absent values are
-- stored as the empty set and reported as 'Nothing'.
instance TrieValue IS.IntSet where
  data TrieNode IS.IntSet = IntSetTrieNode !(Trie IS.IntSet) !IS.IntSet
  node t mv = IntSetTrieNode t (fromMaybe IS.empty mv)
  nodeNull = IS.null
  nodeTree (IntSetTrieNode t _) = t
  nodeValue (IntSetTrieNode _ v) = if IS.null v then Nothing else Just v
-- | Identity-wrapped values: the optional value is stored as-is, and no value
-- counts as empty.
instance TrieValue (Identity a) where
  data TrieNode (Identity a) = IdentityNode !(Trie (Identity a)) !(Maybe (Identity a))
  node = IdentityNode
  nodeNull _ = False
  nodeTree (IdentityNode t _) = t
  nodeValue (IdentityNode _ mv) = mv
-- | Unit values: presence is encoded in an Int tag (0 = absent, any other
-- tag = present).
instance TrieValue () where
  data TrieNode () = UnitNode !Int !(Trie ())
  node t mv = UnitNode (maybe 0 (const 1) mv) t
  nodeNull _ = False
  nodeTree (UnitNode _ t) = t
  nodeValue (UnitNode tag _) = if tag == 0 then Nothing else Just ()
-- | Bool values encoded in an Int tag: 0 = absent, 1 = False, 2 = True. Any
-- other tag reads back as absent.
instance TrieValue Bool where
  data TrieNode Bool = BoolNode !Int !(Trie Bool)
  node t mv = BoolNode tag t
    where
      tag = case mv of
        Nothing    -> 0
        Just False -> 1
        Just True  -> 2
  nodeNull _ = False
  nodeTree (BoolNode _ t) = t
  nodeValue (BoolNode tag _) = case tag of
    1 -> Just False
    2 -> Just True
    _ -> Nothing