module Data.Trees.KdTree where
import Data.Maybe
import qualified Data.Foldable as F
import qualified Data.List as L
import Test.QuickCheck
-- | Points with a fixed number of real-valued coordinates, as stored in a
-- 'KdTree'.
class Point p where
  -- | Number of coordinates (axes) of a point.
  dimension :: p -> Int

  -- | @coord i p@ is the @i@-th coordinate of @p@, expected for
  -- @0 <= i < dimension p@.
  coord :: Int -> p -> Double

  -- | Squared Euclidean distance between two points. The default sums the
  -- squared per-axis differences over axes @0 .. dimension a - 1@, taking the
  -- dimension from the first argument. Squared distances suffice for
  -- comparisons and avoid a 'sqrt'.
  dist2 :: p -> p -> Double
  dist2 a b = sum . map diff2 $ [0 .. dimension a - 1]
    where
      diff2 i = (coord i a - coord i b) ^ (2 :: Int)
-- | Order two points by their squared distance ('dist2') from a reference
-- point: @compareDistance origin x y@ is 'LT' when @x@ is closer to @origin@.
compareDistance :: (Point p) => p -> p -> p -> Ordering
compareDistance origin x y = compare (dist2 origin x) (dist2 origin y)
-- | A point in three-dimensional space.
data Point3d = Point3d
  { p3x :: Double -- ^ first coordinate
  , p3y :: Double -- ^ second coordinate
  , p3z :: Double -- ^ third coordinate
  }
  deriving (Eq, Ord, Show)
-- | 'Point3d' is three-dimensional: axis 0 is 'p3x', axis 1 is 'p3y' and
-- axis 2 is 'p3z'. The default 'dist2' is used.
instance Point Point3d where
  dimension _ = 3
  coord 0 p = p3x p
  coord 1 p = p3y p
  coord 2 p = p3z p
  -- Fail with a descriptive message instead of a bare non-exhaustive
  -- pattern error when asked for an axis outside 0..2.
  coord i _ = error ("Point3d.coord: axis " ++ show i ++ " out of range 0..2")
-- | A k-d tree: either empty, or a node holding one point that splits the
-- space on 'kdAxis', with a subtree on each side of the splitting plane
-- (see 'isValid' for the invariant that is checked).
data KdTree point
  = KdNode
      { kdLeft  :: KdTree point -- ^ subtree on the low side of the split
      , kdPoint :: point        -- ^ the splitting point stored at this node
      , kdRight :: KdTree point -- ^ subtree on the high side of the split
      , kdAxis  :: Int          -- ^ index of the coordinate this node splits on
      }
  | KdEmpty
  deriving (Eq, Ord, Show)
-- | Apply a function to every stored point, keeping the tree's shape and
-- each node's axis unchanged.
instance Functor KdTree where
  fmap f tree = case tree of
    KdEmpty -> KdEmpty
    KdNode left x right axis ->
      KdNode (fmap f left) (f x) (fmap f right) axis
-- | Folds visit points in order: the left subtree, then the node's point,
-- then the right subtree.
instance F.Foldable KdTree where
  foldr _ acc KdEmpty = acc
  foldr f acc (KdNode left x right _) =
    F.foldr f (f x (F.foldr f acc right)) left
-- | Build a balanced 'KdTree' from a list of points, starting at depth 0.
fromList :: Point p => [p] -> KdTree p
fromList = (`fromListWithDepth` 0)
-- | Build a balanced k-d tree whose root splits on axis
-- @depth \`mod\` dimension@, taking the dimension from the first point.
--
-- The points are sorted on that axis and the element at index
-- @length \`div\` 2@ becomes the node. Earlier elements form the left subtree
-- and later ones the right subtree, both built at @depth + 1@. Points whose
-- coordinate equals the median's may therefore end up on either side.
fromListWithDepth :: Point p => [p] -> Int -> KdTree p
fromListWithDepth [] _ = KdEmpty
fromListWithDepth points@(firstPoint : _) depth =
  KdNode
    { kdLeft  = fromListWithDepth below nextDepth
    , kdPoint = median
    , kdRight = fromListWithDepth above nextDepth
    , kdAxis  = axis
    }
  where
    axis = depth `mod` dimension firstPoint
    nextDepth = depth + 1
    ordered = L.sortBy (\p q -> compare (coord axis p) (coord axis q)) points
    -- The list is non-empty here, so splitting at length/2 always leaves
    -- at least one element (the median) in the second half.
    (below, median : above) = splitAt (length ordered `div` 2) ordered
-- | All points of the tree, in the in-order sequence used by its 'F.Foldable'
-- instance.
toList :: KdTree p -> [p]
toList = F.toList
-- | Every subtree of the tree, including each empty leaf, listed in order.
-- Each node appears after the subtrees of its left child and before those of
-- its right child.
subtrees :: KdTree p -> [KdTree p]
subtrees KdEmpty = [KdEmpty]
subtrees node@(KdNode left _ right _) =
  subtrees left ++ node : subtrees right
-- | The point in the tree closest to @probe@ by 'dist2', or 'Nothing' for an
-- empty tree.
--
-- The subtree on the probe's side of the node's splitting plane is searched
-- first. The other subtree is searched only if the plane is no farther from
-- the probe than the best candidate found so far. Points beyond the plane are
-- at least that far away, so skipping them cannot miss a closer point.
-- Among equally close candidates, the near-side result is preferred, then the
-- node's own point, then the far-side result.
nearestNeighbor :: Point p => KdTree p -> p -> Maybe p
nearestNeighbor KdEmpty _ = Nothing
nearestNeighbor (KdNode KdEmpty p KdEmpty _) _ = Just p
nearestNeighbor (KdNode l p r axis) probe
  | xProbe <= xp = Just (findNearest l r)
  | otherwise = Just (findNearest r l)
  where
    xProbe = coord axis probe
    xp = coord axis p
    closest = L.minimumBy (compareDistance probe)
    findNearest near far =
      let nearCandidates = case nearestNeighbor near probe of
            Nothing -> [p]
            Just bestNear -> [bestNear, p]
          best1 = closest nearCandidates
          -- Squared distance from the probe to the splitting plane.
          planeDist2 = (xProbe - xp) ^ (2 :: Int)
          -- Use the best candidate so far as the search radius. This prunes
          -- at least as much as using the node point's distance, and never
          -- discards a strictly closer point.
          candidates
            | planeDist2 <= dist2 probe best1 =
                best1 : maybeToList (nearestNeighbor far probe)
            | otherwise = [best1]
       in closest candidates
-- | Check the splitting invariant at the root only. On the root's axis, every
-- point of the left subtree must have a coordinate no greater than the root
-- point's, and every point of the right subtree one no smaller. An empty tree
-- is valid. 'allSubtreesAreValid' applies this check at every node.
isValid :: Point p => KdTree p -> Bool
isValid KdEmpty = True
isValid (KdNode l p r axis) =
  F.all (\q -> coord axis q <= pivot) l && F.all (\q -> coord axis q >= pivot) r
  where
    pivot = coord axis p
-- | True when 'isValid' holds for every subtree returned by 'subtrees', i.e.
-- the splitting invariant holds at every node.
allSubtreesAreValid :: Point p => KdTree p -> Bool
allSubtreesAreValid tree = and [isValid sub | sub <- subtrees tree]
-- | Up to @k@ points of the tree near @probe@. The function repeatedly takes
-- the current nearest point ('nearestNeighbor'), emits it, and continues on
-- the tree with that point removed ('remove'). Returns @[]@ for an empty tree
-- or when @k <= 0@.
kNearestNeighbors :: (Eq p, Point p) => KdTree p -> Int -> p -> [p]
kNearestNeighbors KdEmpty _ _ = []
kNearestNeighbors _ k _ | k <= 0 = []
kNearestNeighbors tree k probe =
  -- Pattern-match rather than use 'fromJust'; a non-empty tree always yields
  -- a neighbor, but this keeps the function total.
  case nearestNeighbor tree probe of
    Nothing -> []
    Just nearest ->
      nearest : kNearestNeighbors (tree `remove` nearest) (k - 1) probe
-- | Remove one occurrence of a point (matched with '==') from the tree. The
-- removed node's subtree is rebuilt with 'fromListWithDepth' from the node's
-- remaining descendants. If the point is not found, the tree is returned
-- unchanged.
remove :: (Eq p, Point p) => KdTree p -> p -> KdTree p
remove tree pKill = fromMaybe tree (removeFrom tree)
  where
    -- 'Nothing' means pKill does not occur in this subtree.
    removeFrom KdEmpty = Nothing
    removeFrom (KdNode l p r axis)
      -- Passing 'axis' as the depth makes the rebuilt subtree's root split
      -- on the same axis as the node it replaces.
      | p == pKill = Just (fromListWithDepth (toList l ++ toList r) axis)
      | otherwise = case coord axis pKill `compare` coord axis p of
          LT -> inLeft
          GT -> inRight
          -- 'fromListWithDepth' can place points whose coordinate equals
          -- the median's in either subtree, so a tie must search both sides.
          EQ -> case inLeft of
            Nothing -> inRight
            found -> found
      where
        inLeft = (\l' -> KdNode l' p r axis) <$> removeFrom l
        inRight = (\r' -> KdNode l p r' axis) <$> removeFrom r
-- | QuickCheck generator for 'Point3d': each coordinate is drawn in turn with
-- 'arbitrary' for 'Double'.
instance Arbitrary Point3d where
  arbitrary = do
    px <- arbitrary
    py <- arbitrary
    pz <- arbitrary
    pure Point3d {p3x = px, p3y = py, p3z = pz}