module Data.TrieMap.WordMap (SNode, WHole, TrieMap(WordMap), Hole(Hole), getWordMap, getHole) where
import Data.TrieMap.TrieKey
import Data.TrieMap.Sized
import Control.Exception (assert)
import Control.Applicative (Applicative(..), (<$>))
import Control.Monad hiding (join)
import Data.Bits
import Data.Foldable
import Data.Maybe hiding (mapMaybe)
import Data.Monoid
import Data.TrieMap.Utils
import GHC.Exts
import Prelude hiding (lookup, null, map, foldl, foldr, foldl1, foldr1)
#include "MachDeps.h"
-- Pattern macros (expanded by CPP) that match an SNode by the constructor of
-- its inner node, ignoring the cached size field.
#define NIL SNode{node = Nil}
#define TIP(args) SNode{node = (Tip args)}
#define BIN(args) SNode{node = (Bin args)}
-- Word-valued aliases that name the role a machine word plays in the tree:
-- a raw bit pattern, a branch prefix, a single-bit branch mask, or a key.
type Nat = Word
type Prefix = Word
type Mask = Word
type Key = Word
-- Element counts, as cached in the sz field of SNode.
type Size = Int
-- | Zipper context recording the route from the root down to a position in
-- the tree.  Each step stores the prefix and mask of the bin that was entered
-- together with the sibling subtree that was not taken.
data Path a = Root
  | LeftBin !Prefix !Mask (Path a) !(SNode a)   -- ^ descended left; right sibling kept
  | RightBin !Prefix !Mask !(SNode a) (Path a)  -- ^ descended right; left sibling kept
-- | A tree node paired with its cached element count.
data SNode a = SNode {sz :: !Size, node :: (Node a)}
-- | Patricia tree node: empty, a single key/value leaf, or a branch holding the
-- shared prefix, the branching bit mask, and the two subtrees (the bin and
-- bin' smart constructors below keep both subtrees non-empty).
data Node a = Nil | Tip !Key a | Bin !Prefix !Mask !(SNode a) !(SNode a)
-- | An SNode reports the element count cached in its sz field, so its size
-- is read directly rather than recomputed.
instance Sized (SNode a) where
  getSize# t = unbox (sz t)
-- | Size of a bare node: zero when empty, the size of the stored value for a
-- tip, and the sum of the two cached child sizes for a branch.
instance Sized a => Sized (Node a) where
  getSize# t = unbox (count t)
    where
      count Nil = 0
      count (Tip _ a) = getSize a
      count (Bin _ _ l r) = getSize l + getSize r
-- | Wrap a node together with its computed size.  The node is forced first.
sNode :: Sized a => Node a -> SNode a
sNode n = n `seq` SNode {sz = getSize n, node = n}
-- | Contents of a hole: the focused key and the zipper path from the root to
-- the position where that key lives or would be inserted.
data WHole a = WHole !Key (Path a)
-- | Build a hole focused on the given key at the given path.
hole :: Key -> Path a -> Hole Word a
hole k = Hole . WHole k
-- Pattern macro matching a hole by its key and path.
#define HOLE(args) (Hole (WHole args))
-- | Word keys are stored in a big-endian Patricia tree whose nodes cache
-- their sizes.  A hole is the focused key plus the zipper path leading to it.
instance TrieKey Word where
  newtype TrieMap Word a = WordMap {getWordMap :: SNode a}
  newtype Hole Word a = Hole {getHole :: WHole a}

  emptyM = WordMap nil
  singletonM k a = WordMap (singleton k a)
  -- Classify by the root node constructor only; the cached size is ignored.
  getSimpleM (WordMap (SNode _ n)) = case n of
    Nil -> Null
    Tip _ a -> Singleton a
    _ -> NonSimple
  sizeM (WordMap t) = getSize t
  lookupM k (WordMap m) = lookup k m
  traverseM f (WordMap m) = WordMap <$> traverse f m
  fmapM f (WordMap m) = WordMap (map f m)
  mapMaybeM f (WordMap m) = WordMap (mapMaybe f m)
  mapEitherM f (WordMap m) = both WordMap WordMap (mapEither f) m
  unionM f (WordMap m1) (WordMap m2) = WordMap (unionWith f m1 m2)
  isectM f (WordMap m1) (WordMap m2) = WordMap (intersectionWith f m1 m2)
  diffM f (WordMap m1) (WordMap m2) = WordMap (differenceWith f m1 m2)
  isSubmapM (<=) (WordMap m1) (WordMap m2) = isSubmapOfBy (<=) m1 m2

  singleHoleM k = hole k Root
  -- The part of the tree before or after the hole, optionally also keeping a
  -- value stored at the hole key itself.
  beforeM HOLE(_ path) = WordMap (before nil path)
  beforeWithM a HOLE(k path) = WordMap (before (singleton k a) path)
  afterM HOLE(_ path) = WordMap (after nil path)
  afterWithM a HOLE(k path) = WordMap (after (singleton k a) path)
  searchMC !k (WordMap t) = mapSearch (hole k) (searchC k t)

  -- Positional lookup: descend left while the index is below the size of the
  -- left subtree; otherwise descend right with that size subtracted.
  indexM i (WordMap m) = indexT i m Root where
    indexT !i TIP(kx x) path = (# i, x, hole kx path #)
    indexT !i BIN(p m l r) path
      | i < sl = indexT i l (LeftBin p m path r)
      | otherwise = indexT (i - sl) r (RightBin p m l path)
      where !sl = getSize l
    indexT _ NIL _ = indexFail ()

  -- Produce a hole at every tip, left subtree before right subtree.
  extractHoleM (WordMap m) = extractHole Root m where
    extractHole _ (SNode _ Nil) = mzero
    extractHole path TIP(kx x) = return (x, hole kx path)
    extractHole path BIN(p m l r) =
      extractHole (LeftBin p m path r) l `mplus`
        extractHole (RightBin p m l path) r

  clearM HOLE(_ path) = WordMap (assign nil path)
  assignM v HOLE(kx path) = WordMap (assign (singleton kx v) path)
  unifierM k' k a = Hole <$> unifier k' k a
-- | Continuation-passing search for key k.  On a hit, found receives the
-- value and the path leading to it.  On a miss, notfound receives the path at
-- which a tip for k would be inserted.
searchC :: Key -> SNode a -> SearchCont (Path a) a r
searchC !k t0 notfound found = go Root t0
  where
    go path t = case node t of
      Bin p m l r
        | nomatch k p m -> notfound (branchHole k p path t)
        | zero k m -> go (LeftBin p m path r) l
        | otherwise -> go (RightBin p m l path) r
      Tip ky y
        | k == ky -> found y path
        | otherwise -> notfound (branchHole k ky path t)
      Nil -> notfound path
-- | Rebuild the portion of the tree lying before (resp. after) a hole.  Walk
-- up the path starting from t.  Keep only the siblings on the left (resp.
-- right); bin drops empty halves.
before, after :: SNode a -> Path a -> SNode a
before !t path = case path of
  Root -> t
  LeftBin _ _ up _ -> before t up
  RightBin p m l up -> before (bin p m l t) up
after !t path = case path of
  Root -> t
  RightBin _ _ _ up -> after t up
  LeftBin p m up r -> after (bin p m t r) up
-- | Plug a subtree into the hole described by a path and rebuild to the root.
-- An empty replacement removes the parent bin, promoting the sibling in its
-- place.  An empty replacement at the root yields the canonical empty tree.
assign :: Sized a => SNode a -> Path a -> SNode a
assign t path = case node t of
  Nil -> case path of
    Root -> nil
    LeftBin _ _ up r -> assign' r up
    RightBin _ _ l up -> assign' l up
  _ -> assign' t path
-- | Rebuild from a non-empty subtree up to the root, re-creating each bin on
-- the path with the recorded prefix, mask and sibling.
assign' :: Sized a => SNode a -> Path a -> SNode a
assign' !t path = case path of
  Root -> t
  LeftBin p m up r -> assign' (bin' p m t r) up
  RightBin p m l up -> assign' (bin' p m l t) up
-- | Extend a path with a fresh bin that separates key k from the existing
-- subtree t (whose keys share p), splitting on their highest differing bit.
-- The side taken by k is decided by that bit; t becomes the sibling.
branchHole :: Key -> Prefix -> Path a -> SNode a -> Path a
branchHole !k !p path t =
  let m = branchMask k p
      p' = mask k m
  in if zero k m then LeftBin p' m path t else RightBin p' m t path
-- | Find the value stored under a key.  Branches are chosen by the key bit at
-- each mask without checking the prefix; a mismatch is detected at the tip.
lookup :: Key -> SNode a -> Lookup a
lookup !k = go
  where
    go t = case node t of
      Bin _ m l r -> go (if zeroN k m then l else r)
      Tip kx x | k == kx -> some x
      _ -> none
-- | A tree holding exactly one key/value pair.
singleton :: Sized a => Key -> a -> SNode a
singleton k = sNode . Tip k
-- | A one-element tree when a value is present, otherwise the empty tree.
singletonMaybe :: Sized a => Key -> Maybe a -> SNode a
singletonMaybe k mx = case mx of
  Nothing -> nil
  Just x -> singleton k x
-- | Effectful map over the stored values, left subtree before right subtree.
-- The tree shape is kept and sizes are recomputed for the new values.
traverse :: (Applicative f, Sized b) => (a -> f b) -> SNode a -> f (SNode b)
traverse f = go
  where
    go t = case node t of
      Nil -> pure nil
      Tip kx x -> singleton kx <$> f x
      Bin p m l r -> bin' p m <$> go l <*> go r
-- | Folds visit values left subtree first, then right subtree.
instance Foldable SNode where
  foldMap f t = case node t of
    Nil -> mempty
    Tip _ x -> f x
    Bin _ _ l r -> foldMap f l `mappend` foldMap f r

  foldr f z t = case node t of
    Nil -> z
    Tip _ x -> f x z
    Bin _ _ l r -> foldr f (foldr f z r) l

  foldl f z t = case node t of
    Nil -> z
    Tip _ x -> f z x
    Bin _ _ l r -> foldl f (foldl f z l) r

  foldr1 f t = case node t of
    Nil -> foldr1Empty
    Tip _ x -> x
    Bin _ _ l r -> foldr f (foldr1 f r) l

  foldl1 f t = case node t of
    Nil -> foldl1Empty
    Tip _ x -> x
    Bin _ _ l r -> foldl f (foldl1 f l) r
-- | Folds over a word map delegate to the underlying tree.
instance Foldable (TrieMap Word) where
  foldMap f = foldMap f . getWordMap
  foldr f z = foldr f z . getWordMap
  foldl f z = foldl f z . getWordMap
  foldr1 f = foldr1 f . getWordMap
  foldl1 f = foldl1 f . getWordMap
-- | Apply a function to every value, keeping the tree shape and recomputing
-- the cached sizes.
map :: Sized b => (a -> b) -> SNode a -> SNode b
map f = go
  where
    go t = case node t of
      Bin p m l r -> bin' p m (go l) (go r)
      Tip kx x -> singleton kx (f x)
      Nil -> nil
-- | Map values and drop those for which the function yields Nothing.
-- Branches left with an empty half are collapsed by bin.
mapMaybe :: Sized b => (a -> Maybe b) -> SNode a -> SNode b
mapMaybe f = go
  where
    go t = case node t of
      Bin p m l r -> bin p m (go l) (go r)
      Tip kx x -> singletonMaybe kx (f x)
      Nil -> nil
-- | Split a tree in a single traversal.  Each value is mapped to an unboxed
-- pair of optional results.  The first components build one tree and the
-- second components build another; empty halves are collapsed by bin.
-- NOTE(review): assumes the both helper applies its third argument and then
-- post-processes each component of the resulting pair -- confirm in Utils.
mapEither :: (Sized b, Sized c) => (a -> (# Maybe b, Maybe c #)) ->
  SNode a -> (# SNode b, SNode c #)
mapEither f BIN(p m l r) = both (bin p m lL) (bin p m lR) (mapEither f) r
  -- The left subtree is split first; its halves are paired with the halves
  -- produced from the right subtree.
  where !(# lL, lR #) = mapEither f l
mapEither f TIP(kx x) = both (singletonMaybe kx) (singletonMaybe kx) f x
mapEither _ _ = (# nil, nil #)
-- | Union of two trees.  When a key occurs in both, f receives the left value
-- and then the right value.  A result of Nothing drops the key.
unionWith :: Sized a => (a -> a -> Maybe a) -> SNode a -> SNode a -> SNode a
unionWith f n1@(SNode _ t1) n2@(SNode _ t2) = case (t1, t2) of
    (Nil, _) -> n2
    (_, Nil) -> n1
    (Tip k x, _) -> alter (withLeft x) k n2
    (_, Tip k y) -> alter (withRight y) k n1
    (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2)
      | shorter m1 m2 ->
          -- n1 splits higher: n2 is disjoint from it or fits in one half
          if nomatch p2 p1 m1 then join p1 n1 p2 n2
          else if zero p2 m1 then bin p1 m1 (unionWith f l1 n2) r1
          else bin p1 m1 l1 (unionWith f r1 n2)
      | shorter m2 m1 ->
          -- n2 splits higher: n1 is disjoint from it or fits in one half
          if nomatch p1 p2 m2 then join p1 n1 p2 n2
          else if zero p1 m2 then bin p2 m2 (unionWith f n1 l2) r2
          else bin p2 m2 l2 (unionWith f n1 r2)
      | p1 == p2 -> bin p1 m1 (unionWith f l1 l2) (unionWith f r1 r2)
      | otherwise -> join p1 n1 p2 n2
  where
    -- x comes from the left tree: insert it, or combine with an existing value
    withLeft x = maybe (Just x) (f x)
    -- y comes from the right tree: insert it, or combine with an existing value
    withRight y = maybe (Just y) (`f` y)
-- | Update the entry at one key via the class-level alterM, wrapping and
-- unwrapping the WordMap newtype around the tree.
alter :: Sized a => (Maybe a -> Maybe a) -> Key -> SNode a -> SNode a
alter f k = getWordMap . alterM f k . WordMap
-- | Intersection of two trees.  For each key present in both, f combines the
-- left and right values.  A result of Nothing drops the key.
intersectionWith :: Sized c => (a -> b -> Maybe c) -> SNode a -> SNode b -> SNode c
intersectionWith f n1@(SNode _ t1) n2@(SNode _ t2) = case (t1, t2) of
    (Nil, _) -> nil
    (_, Nil) -> nil
    (Tip k x, _) -> option (lookup k n2) nil (singletonMaybe k . f x)
    (_, Tip k y) -> option (lookup k n1) nil (singletonMaybe k . (`f` y))
    (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2)
      | shorter m1 m2 ->
          -- only the half of n1 that could contain n2 contributes
          if nomatch p2 p1 m1 then nil
          else intersectionWith f (if zero p2 m1 then l1 else r1) n2
      | shorter m2 m1 ->
          if nomatch p1 p2 m2 then nil
          else intersectionWith f n1 (if zero p1 m2 then l2 else r2)
      | p1 == p2 -> bin p1 m1 (intersectionWith f l1 l2) (intersectionWith f r1 r2)
      | otherwise -> nil
-- | Difference of two trees.  Keys absent from the right tree are kept.  For
-- keys present in both, f decides the surviving value; Nothing removes the key.
differenceWith :: Sized a => (a -> b -> Maybe a) -> SNode a -> SNode b -> SNode a
differenceWith f n1@(SNode _ t1) n2@(SNode _ t2) = case (t1, t2) of
    (Nil, _) -> nil
    (_, Nil) -> n1
    (Tip k x, _) -> option (lookup k n2) n1 (singletonMaybe k . f x)
    (_, Tip k y) -> alter (>>= (`f` y)) k n1
    (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2)
      | shorter m1 m2 ->
          -- n2 can only affect the half of n1 whose prefix it shares
          if nomatch p2 p1 m1 then n1
          else if zero p2 m1 then bin p1 m1 (differenceWith f l1 n2) r1
          else bin p1 m1 l1 (differenceWith f r1 n2)
      | shorter m2 m1 ->
          if nomatch p1 p2 m2 then n1
          else differenceWith f n1 (if zero p1 m2 then l2 else r2)
      | p1 == p2 -> bin p1 m1 (differenceWith f l1 l2) (differenceWith f r1 r2)
      | otherwise -> n1
-- | Whether every key of the first tree occurs in the second, with the
-- supplied relation holding between the paired values.
isSubmapOfBy :: LEq a b -> LEq (SNode a) (SNode b)
isSubmapOfBy leq t1 t2 = case node t1 of
  Nil -> True
  Tip k x -> option (lookup k t2) False (leq x)
  Bin p1 m1 l1 r1 -> case node t2 of
    Bin p2 m2 l2 r2
      -- t1 spans a wider key range than t2, so it cannot fit inside it
      | shorter m1 m2 -> False
      | shorter m2 m1 ->
          match p1 p2 m2 && isSubmapOfBy leq t1 (if zero p1 m2 then l2 else r2)
      | otherwise ->
          p1 == p2 && isSubmapOfBy leq l1 l2 && isSubmapOfBy leq r1 r2
    _ -> False
-- | True when none of the mask bits are set in the key.  For a single-bit
-- branch mask this means the key belongs in the left subtree.
zero :: Key -> Mask -> Bool
zero i m = 0 == (i .&. m)
-- | Whether a key agrees with a branch prefix on all bits above the mask bit
-- (match), or disagrees somewhere above it (nomatch).
nomatch,match :: Key -> Prefix -> Mask -> Bool
match i p m = p == mask i m
nomatch i p m = not (match i p m)
-- | True when the two words share no set bits.
zeroN :: Nat -> Nat -> Bool
zeroN i = (== 0) . (i .&.)
-- | Keep only the bits of i that lie strictly above the branching bit m,
-- clearing m itself and every lower bit.  m is expected to be a single set
-- bit, as produced by branchMask.  Then (m - 1) .|. m covers m and every bit
-- below it, and the complement of that keeps everything above.
mask :: Nat -> Nat -> Prefix
mask i m
  = i .&. complement ((m - 1) .|. m)
-- | Whether the first mask splits on a higher bit than the second, i.e. its
-- branch covers a wider range of keys.
shorter :: Mask -> Mask -> Bool
shorter m1 m2 = m2 < m1
-- | The highest bit at which two prefixes differ, as a single-bit mask.
branchMask :: Prefix -> Prefix -> Mask
branchMask p1 = highestBitMask . xor p1
-- | Isolate the most significant set bit of a word (zero stays zero).
highestBitMask :: Nat -> Nat
highestBitMask x0 = filled `xor` (filled `shiftR` 1)
  where
    -- Copy the top set bit into every lower position.  Shifts of
    -- 1, 2, 4, 8, 16 and 32 together cover words of up to 64 bits.
    spread x s = x .|. (x `shiftR` s)
    filled = spread (spread (spread (spread (spread (spread x0 1) 2) 4) 8) 16) 32
-- | Combine two non-empty trees with disjoint prefixes under a new bin.  The
-- bin splits on their highest differing bit; the side of each tree is
-- decided by that bit in p1.
join :: Prefix -> SNode a -> Prefix -> SNode a -> SNode a
join p1 t1 p2 t2 =
  if zero p1 m then bin' p m t1 t2 else bin' p m t2 t1
  where
    m = branchMask p1 p2
    p = mask p1 m
-- | The empty tree, with a cached size of zero.
nil :: SNode a
nil = SNode {sz = 0, node = Nil}
-- | Smart branch constructor: an empty child collapses the branch into the
-- other child; otherwise build a bin whose size is the sum of the children.
bin :: Prefix -> Mask -> SNode a -> SNode a -> SNode a
bin p m l@(SNode sl tl) r@(SNode sr tr)
  | isNil tl = r
  | isNil tr = l
  | otherwise = SNode (sl + sr) (Bin p m l r)
  where
    isNil Nil = True
    isNil _ = False
-- | Unchecked branch constructor for children already known to be non-empty.
-- The precondition is only verified when assertions are enabled.
bin' :: Prefix -> Mask -> SNode a -> SNode a -> SNode a
bin' p m l@SNode{sz = sl} r@SNode{sz = sr} =
  assert bothPresent (SNode (sl + sr) (Bin p m l r))
  where
    bothPresent = hasNode l && hasNode r
    hasNode t = case node t of
      Nil -> False
      _ -> True
-- | Hole for key k' inside a one-element tree holding k with value a.
-- Returns Nothing when the two keys are equal.
unifier :: Sized a => Key -> Key -> a -> Maybe (WHole a)
unifier k' k a =
  if k' == k then Nothing else Just (WHole k' path)
  where
    path = branchHole k' k Root (singleton k a)