module Data.Set.BKTree
(
BKTree
,Metric(..)
,null,size,empty
,fromList,singleton
,insert
,member,memberDistance
,delete
,union,unions
,elems,elemsDistance
,closest
#ifdef DEBUG
,runTests
#endif
)where
import qualified Data.IntMap as M
import qualified Data.List as L hiding (null)
import Prelude hiding (null)
import Data.Array.IArray (Array,array,listArray,(!),assocs)
import Data.Array.Unboxed (UArray)
#ifdef DEBUG
import qualified Prelude
import Test.QuickCheck
import Text.Printf
import System.Exit
#endif
-- | Types equipped with a distance function usable by a 'BKTree'.
--
-- The searches in this module prune subtrees using the triangle
-- inequality and treat distance 0 as equality, so 'distance' should be a
-- metric on the stored values: non-negative, zero exactly for equal
-- values, symmetric, and satisfying the triangle inequality.  Instances
-- breaking these laws may make 'member', 'closest' and friends give wrong
-- answers.
class (Eq a) => Metric a where
  -- | Distance between two values.
  distance :: a -> a -> Int
-- | Absolute difference of two 'Int's.
--
-- NOTE(review): @i - j@ overflows for operands far apart (e.g. near
-- 'minBound' and 'maxBound'), giving a negative distance that breaks the
-- metric laws; keep stored values well inside the 'Int' range.
instance Metric Int where
  distance i j = abs (i - j)
-- | Absolute difference of two 'Integer's, clamped to @maxBound :: Int@.
--
-- Without the clamp, 'fromInteger' silently wraps large differences, so
-- two distinct values could end up at distance 0 (or a negative
-- distance).  Truncating a metric at a constant keeps it a metric, so the
-- tree's pruning stays valid.
instance Metric Integer where
  distance i j = fromInteger (min (toInteger (maxBound :: Int)) (abs (i - j)))
-- | Distance between the code points of two characters.
instance Metric Char where
  distance i j = abs (fromEnum i - fromEnum j)
-- | Levenshtein (edit) distance between two lists: the least number of
-- single-element insertions, deletions and substitutions (each costing 1)
-- that turn the first list into the second.
--
-- Despite its name this is the Wagner-Fischer dynamic programme, keeping
-- only one row of the table at a time; it takes
-- O(length xs * length ys) time and computes the distance only, not an
-- alignment.
hirschberg :: Eq a => [a] -> [a] -> Int
hirschberg xs [] = length xs
hirschberg xs ys = L.foldl' nextRow firstRow (zip [1 ..] xs) ! lys
  where
    lys = length ys

    -- Row 0 of the table: reaching ys[1..j] from the empty prefix costs j.
    firstRow :: UArray Int Int
    firstRow = listArray (1, lys) [1 .. lys]

    -- Row i from row i-1.  Column 0 of row i is i, so the scan starts with
    -- (D[i-1][0], D[i][0]) = (i - 1, i) and carries
    -- (D[i-1][j-1], D[i][j-1]) from column to column.
    nextRow prevRow (i, x) =
      array (1, lys) (snd (L.mapAccumL step (i - 1, i) (zip (assocs prevRow) ys)))
        :: UArray Int Int
      where
        step (diag, left) ((j, up), y) =
          let cost = if x == y then 0 else 1
              cur = minimum [diag + cost, up + 1, left + 1]
          in ((up, cur), (j, cur))
-- | Lists are compared by edit distance ('hirschberg'): insertions,
-- deletions and substitutions each cost 1.
instance Eq a => Metric [a] where
  distance xs ys = hirschberg xs ys
-- | A Burkhard-Keller tree: a multiset of values organised by a 'Metric'.
--
-- Every element in the subtree stored under key @k@ of a node holding @x@
-- is at 'distance' exactly @k@ from @x@; the search functions combine this
-- with the triangle inequality to skip whole subtrees.
data BKTree a
  = Node
      a                     -- element stored at this node
      !Int                  -- number of elements in this tree ('size')
      (M.IntMap (BKTree a)) -- children, keyed by distance to this element
  | Empty                   -- the tree with no elements
#ifdef DEBUG
  deriving Show
#endif
-- | /O(1)/. Does the tree contain no elements?
null :: BKTree a -> Bool
null t = case t of
  Empty -> True
  Node {} -> False
-- | /O(1)/. Number of elements in the tree, duplicates included.
size :: BKTree a -> Int
size t = case t of
  Node _ n _ -> n
  Empty -> 0
-- | The tree holding no elements.
empty :: BKTree a
empty = Empty
-- | A tree holding exactly one element (and no children).
singleton :: a -> BKTree a
singleton x = Node x 1 M.empty
-- | Add an element to the tree.  Duplicates are kept: an element at
-- distance 0 from a node goes into that node's key-0 subtree.
insert :: Metric a => a -> BKTree a -> BKTree a
insert new Empty = singleton new
insert new (Node root n children) =
  Node root (n + 1) (M.insertWith descend key (singleton new) children)
  where
    key = distance new root
    -- An existing subtree at this distance absorbs the new element.
    descend _ existing = insert new existing
-- | Is the element in the tree?  Only one path is followed: an equal
-- element can only live under the key equal to its distance from each
-- node on the way down.
member :: Metric a => a -> BKTree a -> Bool
member _ Empty = False
member a (Node b _ children) =
  let d = distance a b
  in d == 0 || maybe False (member a) (M.lookup d children)
-- | @memberDistance n a t@: does @t@ contain an element within distance
-- @n@ of @a@?
memberDistance :: Metric a => Int -> a -> BKTree a -> Bool
memberDistance _ _ Empty = False
memberDistance n a (Node b _ children)
  | d <= n = True
  | otherwise = any (memberDistance n a) (M.elems window)
  where
    d = distance a b
    -- Only subtrees keyed in [d - n, d + n] can hold a match (triangle
    -- inequality); M.split's bounds are exclusive, hence the -1 / +1.
    (_, above) = M.split (d - n - 1) children
    (window, _) = M.split (d + n + 1) above
-- | Remove one occurrence of an element, if it is present.
--
-- Deleting a node's own element rebuilds that node from its children with
-- 'unions'.  Otherwise only the path towards the element is rebuilt, and
-- a tree that does not contain the element is returned unchanged.
delete :: Metric a => a -> BKTree a -> BKTree a
delete _ Empty = Empty
delete a t@(Node b n children)
  | d == 0 = unions (M.elems children)
  | otherwise = case M.lookup d children of
      -- No subtree at this distance, so the element is absent.
      Nothing -> t
      Just child ->
        let child' = delete a child
            -- Drop a subtree that became empty rather than keeping an
            -- 'Empty' placeholder under its key.
            children'
              | null child' = M.delete d children
              | otherwise = M.insert d child' children
        in Node b (n - size child + size child') children'
  where
    d = distance a b
-- | All elements, duplicates listed separately, in pre-order: a node's
-- element comes before those of its children, and children are visited
-- by increasing key.
--
-- An accumulating parameter keeps this linear in the number of elements;
-- nested concatenation would copy each element once per tree level.
elems :: BKTree a -> [a]
elems tree = go tree []
  where
    go Empty rest = rest
    go (Node x _ children) rest = x : foldr go rest (M.elems children)
-- | @elemsDistance n a t@: the elements of @t@ within distance @n@ of @a@,
-- in the same pre-order as 'elems'.
--
-- Like 'elems', it builds the result with an accumulating parameter, so
-- the output is not re-copied at every tree level.
elemsDistance :: Metric a => Int -> a -> BKTree a -> [a]
elemsDistance n a tree = go tree []
  where
    go Empty rest = rest
    go (Node b _ children) rest =
      let d = distance a b
          -- Only subtrees keyed in [d - n, d + n] can hold a match
          -- (triangle inequality); M.split's bounds are exclusive.
          (_, above) = M.split (d - n - 1) children
          (window, _) = M.split (d + n + 1) above
          fromChildren = foldr go rest (M.elems window)
      in if d <= n then b : fromChildren else fromChildren
-- | Build a tree by inserting the list's elements from left to right.
fromList :: Metric a => [a] -> BKTree a
fromList xs = L.foldl' (\tree x -> insert x tree) empty xs
-- | A tree holding every element of all the given trees.
--
-- The largest input is reused as-is and the elements of the others are
-- inserted into it, instead of rebuilding every input from scratch.  The
-- resulting tree's shape (and so the order of 'elems') is unspecified.
unions :: Metric a => [BKTree a] -> BKTree a
unions trees =
  case L.sortBy (\s t -> compare (size t) (size s)) trees of
    [] -> empty
    largest : rest -> L.foldl' (flip insert) largest (concatMap elems rest)
-- | A tree holding the elements of both trees; same as @'unions' [l, r]@.
union :: Metric a => BKTree a -> BKTree a -> BKTree a
union l r = unions [l, r]
-- | An element of the tree nearest to the query, paired with its
-- distance.  'Nothing' only for the empty tree.  Which element is
-- returned when several are equally close is unspecified.
closest :: Metric a => a -> BKTree a -> Maybe (a, Int)
closest _ Empty = Nothing
closest a tree@(Node b _ _) = Just (closeLoop a (b, distance a b) tree)

-- | Branch-and-bound search behind 'closest': improve @candidate@ with any
-- element of the given tree that is strictly closer to @a@.
closeLoop :: Metric a => a -> (a, Int) -> BKTree a -> (a, Int)
closeLoop _ candidate Empty = candidate
closeLoop a candidate@(_, best) (Node x _ children) =
  L.foldl' (closeLoop a) candidate' (M.elems window)
  where
    j = distance a x
    candidate'
      | j < best = (x, j)
      | otherwise = candidate
    best' = snd candidate'
    -- Everything under key k is at distance k from x, so by the triangle
    -- inequality it is at least |j - k| from a.  Only keys with
    -- |j - k| < best' can hold a strictly closer element.  M.split's bounds
    -- are exclusive, which yields exactly that open interval.  The window
    -- must use the updated bound best' and be centred on j.
    (_, above) = M.split (j - best') children
    (window, _) = M.split (j + best') above
-- | Apply a binary relation after projecting both arguments (same contract
-- as 'Data.Function.on').
on :: (b -> b -> c) -> (a -> b) -> a -> a -> c
on rel f x y = f x `rel` f y
#ifdef DEBUG
-- The distance to the empty list is the other list's length.
prop_naiveEmpty xs =
  distance [] xs == length xs && distance xs [] == length (xs :: [Int])

-- A shared head costs nothing.
prop_naiveCons x xs ys = distance (x : xs) (x : ys) == distance xs (ys :: [Int])

-- Differing heads cost one edit plus the best of: inserting y, substituting
-- (the middle term equals distance xs ys by prop_naiveCons), or deleting x.
prop_naiveDiff x y xs ys =
  x /= y ==>
    distance (x : xs) (y : ys)
      == 1 + minimum [ distance (x : xs) (ys :: [Int])
                     , distance (x : xs) (x : ys)
                     , distance xs (y : ys)
                     ]
-- Observable meaning of a tree: its elements as a sorted list.
sem tree = L.sort (elems tree) :: [Int]

-- Apply a tree transformation to the tree built from a list, then observe.
trans f xs = sem (f (fromList xs))

-- Structural invariant: every node is at the recorded distance from each
-- ancestor, where the recorded distance is the key taken below that
-- ancestor on the way down.
invariant t = inv [] t

inv _ Empty = True
inv dict (Node a _ imap) =
  all (\(d, b) -> distance a b == d) dict
    && all (\(d, t) -> inv ((d, a) : dict) t) (M.toList imap)
prop_empty n = not (member (n :: Int) empty)

prop_null xs = null (fromList xs) == Prelude.null (xs :: [Int])

prop_singleton n = elems (fromList [n]) == [n :: Int]

prop_fromList xs = sem (fromList xs) == L.sort xs

prop_fromListInv xs = invariant (fromList (xs :: [Int]))

prop_insert n xs = trans (insert n) xs == L.sort (n : xs)

prop_insertInv n xs = invariant (insert n (fromList (xs :: [Int])))

prop_member n xs = member n (fromList xs) == L.elem (n :: Int) xs

-- The radius is reduced mod 5 to keep it small; `collect` reports how
-- often the reference answer was True.
prop_memberDistance dist n xs =
  collect expected $ memberDistance d n (fromList xs) == expected
  where
    d = dist `mod` 5
    expected = L.any (\e -> distance n e <= d) (xs :: [Int])
-- Deleting removes exactly the first matching occurrence.
prop_delete n xs = trans (delete n) xs == L.sort (removeFirst (xs :: [Int]))
  where
    removeFirst ys = case break (== n) ys of
      (before, _ : after) -> before ++ after
      (before, []) -> before

prop_deleteInv n xs = invariant (delete n (fromList (xs :: [Int])))

prop_elems xs = L.sort (elems (fromList xs)) == L.sort (xs :: [Int])

prop_elemsDistance dist n xs =
  L.sort (elemsDistance d n (fromList xs))
    == L.sort (filter (\e -> distance n e <= d) (xs :: [Int]))
  where
    d = dist `mod` 5

prop_unions xss =
  sem (unions (map fromList xss)) == L.sort (concat (xss :: [[Int]]))

prop_unionsInv xss = invariant (unions (map fromList (xss :: [[Int]])))

prop_union xs ys =
  sem (union (fromList xs) (fromList ys)) == L.sort (xs ++ (ys :: [Int]))

prop_unionInv xs ys =
  invariant (union (fromList (xs :: [Int])) (fromList (ys :: [Int])))

-- Inputs are kept small, presumably so Int distances cannot overflow.
prop_closest n xs =
  all (\x -> abs x < 100000) xs ==>
    case closest n (fromList xs) of
      Nothing -> Prelude.null xs
      Just (_, d) -> d == minimum (map (distance n) (xs :: [Int]))

prop_insertDelete n xs = trans (delete n . insert n) xs == L.sort (xs :: [Int])
prop_sizeEmpty = size empty == 0

prop_sizeFromList xs = size (fromList xs) == length (xs :: [Int])

-- Inserting always grows the tree by one (duplicates included).
prop_sizeSucc n xs = size (insert (n :: Int) tree) == size tree + 1
  where
    tree = fromList xs
-- Deleting shrinks the tree by one exactly when the element was present.
prop_sizeDelete n xs =
  size (delete (n :: Int) tree)
    == size tree - (if n `member` tree then 1 else 0)
  where
    tree = fromList xs
prop_sizeUnion xs ys = size (union treeXs treeYs) == size treeXs + size treeYs
  where
    treeXs = fromList xs
    treeYs = fromList (ys :: [Int])

prop_sizeUnions xss = size (unions trees) == sum (map size trees)
  where
    trees = map fromList (xss :: [[Int]])

prop_unionsMember xss = all (`member` tree) (concat (xss :: [[Int]]))
  where
    tree = unions (map fromList xss)

prop_fromListMember xs = all (`member` tree) (xs :: [Int])
  where
    tree = fromList xs
-- A named QuickCheck property; the existential lets properties of
-- different types share one list.
data TestCase = forall prop. Testable prop => Tc String prop

-- Every property run by runTests, in execution order.
tests :: [TestCase]
tests =
  [ Tc "empty" prop_empty
  , Tc "null" prop_null
  , Tc "singleton" prop_singleton
  , Tc "fromList" prop_fromList
  , Tc "fromList inv" prop_fromListInv
  , Tc "insert" prop_insert
  , Tc "insert inv" prop_insertInv
  , Tc "member" prop_member
  , Tc "memberDistance" prop_memberDistance
  , Tc "delete" prop_delete
  , Tc "delete inv" prop_deleteInv
  , Tc "elems" prop_elems
  , Tc "elemsDistance" prop_elemsDistance
  , Tc "unions" prop_unions
  , Tc "unions inv" prop_unionsInv
  , Tc "union" prop_union
  , Tc "union inv" prop_unionInv
  , Tc "closest" prop_closest
  , Tc "size/empty" prop_sizeEmpty
  , Tc "size/fromList" prop_sizeFromList
  , Tc "size/succ" prop_sizeSucc
  , Tc "size/delete" prop_sizeDelete
  , Tc "size/union" prop_sizeUnion
  , Tc "size/unions" prop_sizeUnions
  , Tc "insert/delete" prop_insertDelete
  , Tc "fromList/member" prop_fromListMember
  , Tc "unions/member" prop_unionsMember
  , Tc "naiveEmpty" prop_naiveEmpty
  , Tc "naiveCons" prop_naiveCons
  , Tc "naiveDiff" prop_naiveDiff
  ]
-- Run every property in 'tests', printing its name first.  Exits with a
-- failure status on the first property that fails; a property that gave
-- up is not treated as a failure.
runTests :: IO ()
runTests = mapM_ runTest tests
  where
    runTest (Tc s prop) = do
      printf "%-25s :" s
      result <- quickCheckResult prop
      -- Record patterns keep working whatever fields these constructors
      -- carry in the installed QuickCheck version.
      case result of
        Success {} -> return ()
        GaveUp {} -> return ()
        _ -> exitFailure
#endif