{-# LANGUAGE ScopedTypeVariables #-}

-- | A simple k-d tree supporting nearest-neighbour queries over points
-- represented with the vector types of the @linear@ package.
module Data.KdTree
  ( KdTree
    -- * Construction
  , fromVector
    -- * Queries
  , nearest
  , toList
    -- * Diagnostics
  , isValid
  , showKdTree
  ) where

import Prelude hiding (sort)

import Data.List (minimumBy)
import Data.Maybe (maybeToList)
import Data.Ord (comparing)

import Linear hiding (point)
import Control.Lens
import qualified Data.Vector.Generic as V
import Data.Vector.Algorithms.Intro (sortBy)

-- | The k-d tree is a data structure capable of efficiently answering
-- nearest neighbor search queries in low-dimensional spaces. As a rule
-- of thumb, for efficient lookups the number of points in @k@ dimensions
-- should greatly exceed @2^k@.
data KdTree f a
  = KdNode
      { point :: !(f a)
        -- ^ The pivot point stored at this node
      , axis  :: E f
        -- ^ The axis along which this node splits its children
      , left  :: KdTree f a
        -- ^ Points whose coordinate along 'axis' is @<=@ the pivot's
      , right :: KdTree f a
        -- ^ Points whose coordinate along 'axis' is @>=@ the pivot's
      }
  | KdEmpty

-- | Construct a @KdTree@ from a vector of points.
--
-- The splitting axes are taken from @basis@ in order, cycling back to the
-- start once exhausted. At every level the points are sorted along the
-- current axis and the element at the middle index becomes the pivot; the
-- elements before it form the left subtree and the elements after it the
-- right subtree. Points that tie with the pivot along the axis may
-- therefore end up on either side.
--
-- An empty point vector always yields an empty tree. Supplying an empty
-- @basis@ together with a non-empty point vector is an error.
fromVector :: (Ord a, V.Vector v (f a)) => [E f] -> v (f a) -> KdTree f a
fromVector basis points = build axes points
  where
    -- Guard against @cycle []@, which would fail with an unhelpful message.
    axes
      | null basis = []
      | otherwise  = cycle basis

    build _ pts
      | V.null pts = KdEmpty
    build [] _ =
      error "Data.KdTree.fromVector: empty basis given for a non-empty point set"
    build (ax : rest) pts =
      let sorted   = V.modify (sortBy (comparing (^. el ax))) pts
          pivotIdx = V.length sorted `div` 2
      in KdNode
           { point = sorted V.! pivotIdx
           , axis  = ax
           , left  = build rest $ V.take pivotIdx sorted
           , right = build rest $ V.drop (pivotIdx + 1) sorted
           }

-- | Squared Euclidean distance between two points.
quadranceTo :: (Num a, Metric f) => f a -> f a -> a
quadranceTo a b = quadrance (a ^-^ b)

-- | Find the point nearest to the given point, or 'Nothing' if the tree is
-- empty.
--
-- The search first descends into the subtree on the query point's side of
-- each splitting plane. The opposite subtree is only visited when the
-- splitting plane is no farther from the query than the best candidate found
-- so far; otherwise nothing on that side can be closer.
nearest :: forall f a. (Ord a, Num a, Metric f)
        => f a -> KdTree f a -> Maybe (f a)
nearest pt tree = go tree
  where
    go :: KdTree f a -> Maybe (f a)
    go KdEmpty = Nothing
    go (KdNode nodePt ax l r)
      | pt ^. el ax <= nodePt ^. el ax = descend nodePt ax l r
      | otherwise                      = descend nodePt ax r l

    descend :: f a          -- ^ The point of the node we are sitting at
            -> E f          -- ^ The splitting axis of the node
            -> KdTree f a   -- ^ The subnode the query point sits in
            -> KdTree f a   -- ^ The other subnode
            -> Maybe (f a)
    descend nodePt ax near far =
      let distTo = quadranceTo pt
          -- Best candidate among the near subtree and the node itself.
          -- On ties the earlier list element wins.
          nearBest = minimumBy (comparing distTo)
                               (maybeToList (go near) ++ [nodePt])
          -- Signed distance from the query to the splitting plane.
          planeDist = pt ^. el ax - nodePt ^. el ax
          -- Every point in the far subtree is at least @planeDist@ away along
          -- this axis, so it can only compete if the plane lies within the
          -- current best distance.
          farBest
            | planeDist * planeDist <= distTo nearBest = maybeToList (go far)
            | otherwise                                = []
      in Just $ minimumBy (comparing distTo) (nearBest : farBest)

-- | List all points in a tree (node first, then left subtree, then right
-- subtree).
toList :: KdTree f a -> [f a]
toList KdEmpty = []
toList (KdNode p _ l r) = p : (toList l ++ toList r)

-- | Verify that the node is well-formed: every point in the left subtree
-- lies at or below the pivot along the node's axis, and every point in the
-- right subtree lies at or above it.
--
-- Equality is permitted on both sides because 'fromVector' splits a sorted
-- vector by index, so points tying with the pivot may land in either
-- subtree.
nodeIsValid :: Ord a => KdTree f a -> Bool
nodeIsValid KdEmpty = True
nodeIsValid (KdNode p ax l r) =
     all (\q -> q ^. el ax <= p ^. el ax) (toList l)
  && all (\q -> q ^. el ax >= p ^. el ax) (toList r)

-- | Verify that the tree is well-formed (recursively).
isValid :: Ord a => KdTree f a -> Bool
isValid KdEmpty = True
isValid node@(KdNode _ _ l r) = nodeIsValid node && isValid l && isValid r

-- | Apply a binary function to the coordinates of two points along the
-- given axis.
onAxis :: E f -> (a -> a -> b) -> f a -> f a -> b
onAxis (E l) f a b = f (a ^. l) (b ^. l)

-- | Given names for the axes show the tree.
--
-- Each node is printed on its own line, indented by two spaces per level of
-- depth, with a blank line separating its left and right subtrees.
showKdTree :: Show (f a) => f String -> KdTree f a -> String
showKdTree axisNames tree = unlines $ fmt 0 tree
  where
    --fmt :: Int -> KdTree f a -> [String]
    fmt depth node =
      case node of
        KdEmpty -> [indent "KdEmpty"]
        KdNode p ax l r ->
          [ indent $ "KdNode ("++show p++") "++show (axisNames ^. el ax) ]
            ++ fmt (depth + 2) l
            ++ [""]
            ++ fmt (depth + 2) r
      where
        indent = (replicate depth ' ' ++)