module Data.Trees.KdTree where

-- Haskell implementation of http://en.wikipedia.org/wiki/K-d_tree
-- by Issac Trotts

import Data.Maybe
import qualified Data.Foldable as F
import qualified Data.List as L
import Test.QuickCheck

-- | A point in a space with a fixed number of 'Double' coordinates.
class Point p where
  -- | dimension returns the number of coordinates of a point.
  dimension :: p -> Int

  -- | coord gets the k'th coordinate, starting from 0.
  coord :: Int -> p -> Double

  -- | dist2 returns the squared distance between two points.
  -- The default sums the squared per-axis differences over the axes
  -- @0 .. dimension a - 1@ of the first argument.
  dist2 :: p -> p -> Double
  dist2 a b = sum . map diff2 $ [0 .. dimension a - 1]
    where
      diff2 i = (coord i a - coord i b) ^ (2 :: Int)

-- | compareDistance p a b compares the distances of a and b to p.
compareDistance :: (Point p) => p -> p -> p -> Ordering
compareDistance p a b = dist2 p a `compare` dist2 p b

-- | A point in three-dimensional space.
data Point3d = Point3d
  { p3x :: Double
  , p3y :: Double
  , p3z :: Double
  }
  deriving (Eq, Ord, Show)

instance Point Point3d where
  dimension _ = 3

  coord 0 p = p3x p
  coord 1 p = p3y p
  coord 2 p = p3z p
  -- Fail with a descriptive message instead of a bare
  -- non-exhaustive-pattern error when asked for a missing axis.
  coord k _ = error ("Point3d.coord: axis " ++ show k ++ " is out of range 0..2")

-- | A k-d tree. Each node stores a pivot point and the axis it splits on.
-- For trees built by 'fromList', 'invariant' holds at every node.
data KdTree point
  = KdNode
      { kdLeft :: KdTree point  -- ^ points at or below the pivot on 'kdAxis'
      , kdPoint :: point        -- ^ the pivot point stored at this node
      , kdRight :: KdTree point -- ^ points at or above the pivot on 'kdAxis'
      , kdAxis :: Int           -- ^ index of the splitting axis
      }
  | KdEmpty
  deriving (Eq, Ord, Show)

-- | Maps every stored point, keeping the tree shape and axes unchanged.
instance Functor KdTree where
  fmap _ KdEmpty = KdEmpty
  fmap f (KdNode l x r axis) = KdNode (fmap f l) (f x) (fmap f r) axis

-- | Folds visit points in order: left subtree, node point, right subtree.
instance F.Foldable KdTree where
  foldr _ z KdEmpty = z
  foldr f z (KdNode l x r _) = F.foldr f (f x (F.foldr f z r)) l

-- | Build a k-d tree from a list of points, splitting on axis 0 at the root.
fromList :: Point p => [p] -> KdTree p
fromList points = fromListWithDepth points 0

-- | Build a k-d tree for the given recursion depth.
-- Select axis based on depth so that axis cycles through all valid values.
-- The points are sorted on that axis and the median becomes the pivot;
-- points before it form the left subtree, points after it the right one.
fromListWithDepth :: Point p => [p] -> Int -> KdTree p
fromListWithDepth [] _ = KdEmpty
fromListWithDepth points depth = node
  where
    -- 'points' is non-empty here, so 'head' and the median index are safe.
    axis = axisFromDepth (head points) depth

    -- Sort point list and choose median as pivot element
    sortedPoints = L.sortBy (\a b -> coord axis a `compare` coord axis b) points
    medianIndex = length sortedPoints `div` 2

    -- Create node and construct subtrees
    node =
      KdNode
        { kdLeft = fromListWithDepth (take medianIndex sortedPoints) (depth + 1)
        , kdPoint = sortedPoints !! medianIndex
        , kdRight = fromListWithDepth (drop (medianIndex + 1) sortedPoints) (depth + 1)
        , kdAxis = axis
        }

-- | The splitting axis used at a given depth: the depth modulo the
-- dimension of the given point.
axisFromDepth :: Point p => p -> Int -> Int
axisFromDepth p depth = depth `mod` k
  where
    k = dimension p

-- | All points of the tree, in fold order (left, node, right).
toList :: KdTree p -> [p]
toList t = F.foldr (:) [] t

-- | Every subtree of the tree in in-order position, including the tree
-- itself; each empty leaf appears as a 'KdEmpty' entry.
subtrees :: KdTree p -> [KdTree p]
subtrees KdEmpty = [KdEmpty]
subtrees t@(KdNode l _ r _) = subtrees l ++ [t] ++ subtrees r

-- | nearestNeighbor tree probe returns a point of the tree with the smallest
-- squared distance ('dist2') to probe, or Nothing for an empty tree.
--
-- The search first descends into the subtree on the probe's side of the
-- splitting plane. It visits the other subtree only when the plane is no
-- farther from the probe than the best candidate found so far. That
-- pruning relies on the tree satisfying 'invariant', as trees from
-- 'fromList' do.
nearestNeighbor :: Point p => KdTree p -> p -> Maybe p
nearestNeighbor KdEmpty _ = Nothing
nearestNeighbor (KdNode KdEmpty p KdEmpty _) _ = Just p
nearestNeighbor (KdNode l p r axis) probe =
  if xProbe <= xp then search l r else search r l
  where
    xProbe = coord axis probe
    xp = coord axis p

    search nearSide farSide = Just (L.minimumBy (compareDistance probe) candidates2)
      where
        -- Best point on the probe's side, followed by this node's pivot.
        -- Order matters: on equal distances 'L.minimumBy' keeps the earlier one.
        candidates1 = maybeToList (nearestNeighbor nearSide probe) ++ [p]
        best1 = L.minimumBy (compareDistance probe) candidates1

        -- Every point on the far side is at least (xProbe - xp)^2 away
        -- (squared), so it can only beat best1 when that bound does not
        -- exceed best1's squared distance. Testing against dist2 probe p
        -- instead would always succeed (it already includes that term) and
        -- would disable pruning entirely.
        sphereIntersectsPlane = (xProbe - xp) ^ (2 :: Int) <= dist2 probe best1

        candidates2 =
          if sphereIntersectsPlane
            then candidates1 ++ maybeToList (nearestNeighbor farSide probe)
            else candidates1

-- |invariant tells whether the KD tree property holds for a given tree and
-- all its subtrees.
-- Specifically, it tests that all points in the left subtree lie to the left
-- of the plane, p is on the plane, and all points in the right subtree lie to
-- the right.
-- Note that this function checks only the root node; 'invariant'' applies it
-- to every subtree.
invariant :: Point p => KdTree p -> Bool
invariant KdEmpty = True
invariant (KdNode l p r axis) = leftIsGood && rightIsGood
  where
    x = coord axis p
    leftIsGood = all ((<= x) . coord axis) (toList l)
    rightIsGood = all ((>= x) . coord axis) (toList r)

-- | Checks 'invariant' at every subtree, i.e. the k-d property for the
-- whole tree.
invariant' :: Point p => KdTree p -> Bool
invariant' = all invariant . subtrees

-- | Generates points with independently generated coordinates.
instance Arbitrary Point3d where
  arbitrary = do
    x <- arbitrary
    y <- arbitrary
    z <- arbitrary
    return (Point3d x y z)