module Data.BKTree
(
Distance
, BKTree
, empty
, insert
, query
) where
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
-- | A distance function between two elements.
--
-- NOTE(review): the pruning in 'query' is only sound if this behaves like a
-- metric with non-negative results; nothing here enforces that — confirm
-- callers supply a true metric.
type Distance s = s -> s -> Int

-- | A BK-tree: the root node paired with the distance function used both to
-- build the tree and to answer queries against it.
data BKTree s = BKTree
  !(BK s)       -- root node
  (Distance s)  -- distance function shared by insert and query

-- | A (possibly empty) node of the tree.
--
-- A node holds one element and its children. Each child is keyed by the
-- distance between the element that created that edge and this node's
-- element. Every element beneath a child is inserted through that same key.
data BK s
  = EmptyBK
  | BK
      !s                -- element stored at this node
      !(IntMap (BK s))  -- children, keyed by distance to this node's element
-- | Restrict a map to the keys lying in the closed interval @[lo, hi]@.
--
-- The map is split once at each bound. Each split excludes its own key, so
-- that key's entry (if any) is put back explicitly. The result is empty when
-- @lo > hi@.
narrow :: Int -> Int -> IntMap a -> IntMap a
narrow lo hi im =
  let (_, atLo, aboveLo)   = IntMap.splitLookup lo im
      fromLo               = restore lo atLo aboveLo
      (belowHi, atHi, _)   = IntMap.splitLookup hi fromLo
  in restore hi atHi belowHi
  where
    -- Re-insert the entry that splitLookup separated out, if there was one.
    restore key found acc = maybe acc (\v -> IntMap.insert key v acc) found
-- | A tree with no elements, which will use the given function for every
-- later insertion and query.
empty
  :: Distance s  -- ^ distance function between elements
  -> BKTree s
empty metric = BKTree EmptyBK metric
-- | Add an element to the tree, using the tree's own distance function.
--
-- An element at distance 0 from a node on its insertion path is not added
-- again; the tree is returned with that node unchanged.
insert :: s -> BKTree s -> BKTree s
insert s (BKTree root metric) = BKTree root' metric
  where
    root' = insert' s metric root
-- | Worker for 'insert': place @s@ in the subtree rooted at the given node.
--
-- * An empty subtree becomes a leaf holding @s@.
-- * If @s@ is at distance 0 from the node's element, the node is returned
--   unchanged.
-- * Otherwise @s@ is pushed down the edge labelled with that distance. The
--   edge is created as a fresh leaf if it does not exist yet.
insert' :: s -> Distance s -> BK s -> BK s
insert' s _ EmptyBK = BK s IntMap.empty
insert' s f node@(BK s' bks) =
  case f s s' of
    0 -> node
    d -> BK s' (IntMap.insertWith descend d (BK s IntMap.empty) bks)
  where
    -- Only consulted when an edge with this label already exists: recurse
    -- into the existing child rather than replacing it with the fresh leaf.
    descend _fresh child = insert' s f child
-- | Find every element within @maxd@ of @s@ (inclusive), paired with its
-- distance to @s@. Distances are measured with the tree's own distance
-- function.
query
  :: Int         -- ^ maximum distance (inclusive)
  -> s           -- ^ element to search around
  -> BKTree s    -- ^ tree to search
  -> [(s, Int)]  -- ^ matching elements with their distance to @s@
query maxd s tree =
  case tree of
    BKTree root metric -> query' maxd s metric root
-- | Worker for 'query': collect every element of the subtree whose distance
-- to the probe @s@ is at most @maxd@, paired with that distance.
--
-- Let @dist@ be the probe's distance to the current node. Every element under
-- the child edge labelled @k@ is at distance @k@ from this node. Assuming @f@
-- is a metric, the triangle inequality then says such a child can only hold
-- matches when @dist - maxd <= k <= dist + maxd@. 'narrow' restricts the
-- search to exactly that band of edge labels.
query' :: Int -> s -> Distance s -> BK s -> [(s, Int)]
query' _ _ _ EmptyBK = []
query' maxd s f (BK s' bks) = match ++ concatMap (query' maxd s f) children
  where
    dist = f s s'
    match = [(s', dist) | dist <= maxd]
    -- Lower bound of the band is dist - maxd (the original wrote `dist maxd`,
    -- applying an Int as a function, which does not type-check). It is
    -- clamped at 0 on the assumption that f never returns a negative
    -- distance, so no edge label lies below zero.
    children = IntMap.elems $ narrow (max (dist - maxd) 0) (dist + maxd) bks