module Data.KdMap.Static
(
PointAsListFn
, SquaredDistanceFn
, KdMap
, empty
, emptyWithDist
, singleton
, singletonWithDist
, build
, buildWithDist
, insertUnbalanced
, batchInsertUnbalanced
, nearest
, inRadius
, kNearest
, inRange
, assocs
, keys
, elems
, null
, size
, foldrWithKey
, defaultSqrDist
, isValid
) where
import Control.DeepSeq
import Control.DeepSeq.Generics (genericRnf)
import GHC.Generics
import Control.Applicative hiding (empty)
import Data.Foldable
import Prelude hiding (null)
import qualified Data.List as L
import Data.Maybe
import Data.Ord
import qualified Data.PQueue.Prio.Max as Q
import Data.Traversable
-- | A node of the k-d tree, or 'Empty'.
--
-- A non-empty node stores one key/value pair ('_treePoint'). It also stores
-- that key's coordinate on the node's splitting axis ('_axisValue'), plus two
-- subtrees. Every key in '_treeLeft' has a coordinate on this axis that is
-- less than or equal to '_axisValue', and every key in '_treeRight' has a
-- strictly greater one; 'isTreeNodeValid' checks this invariant. The root
-- splits on the first coordinate, and each level down moves to the next
-- coordinate, wrapping around.
data TreeNode a p v
  = TreeNode
      { _treeLeft  :: TreeNode a p v
      , _treePoint :: (p, v)
      , _axisValue :: a
      , _treeRight :: TreeNode a p v
      }
  | Empty
  deriving (Generic, Show, Read)

instance (NFData a, NFData p, NFData v) => NFData (TreeNode a p v) where
  rnf = genericRnf
-- | Apply a function to every value stored in a tree. Keys, axis values and
-- the shape of the tree are left untouched.
mapTreeNode :: (v1 -> v2) -> TreeNode a p v1 -> TreeNode a p v2
mapTreeNode f = go
  where
    go Empty = Empty
    go (TreeNode lhs (key, val) axisVal rhs) =
      TreeNode (go lhs) (key, f val) axisVal (go rhs)
-- | Converts a point into the list of its coordinates, one per axis.
-- The list must not be empty: the tree operations 'cycle' it, and 'cycle'
-- fails on an empty list.
type PointAsListFn coord point = point -> [coord]
-- | Computes the squared distance between two points.
type SquaredDistanceFn coord point = point -> point -> coord
-- | A k-d tree mapping keys of type @p@ to values of type @v@. Keys have
-- coordinates of type @a@.
--
-- The map carries the function that turns a key into its coordinate list,
-- the squared-distance function used by the proximity queries, the root of
-- the tree, and the number of stored entries.
data KdMap a p v = KdMap
  { _pointAsList :: PointAsListFn a p
  , _distSqr     :: SquaredDistanceFn a p
  , _rootNode    :: TreeNode a p v
  , _size        :: Int
  } deriving Generic

instance (NFData a, NFData p, NFData v) => NFData (KdMap a p v) where
  rnf = genericRnf
-- | Shows only the tree. The coordinate and distance functions cannot be
-- shown, and the stored size is omitted.
instance (Show a, Show p, Show v) => Show (KdMap a p v) where
  show kdm = "KdMap " ++ show (_rootNode kdm)
-- | Maps over the stored values. Keys, tree shape, size and the
-- coordinate/distance functions are unchanged.
instance Functor (KdMap a p) where
  fmap f (KdMap p2l d2 root n) = KdMap p2l d2 (mapTreeNode f root) n
-- | Right fold over the key/value pairs of a tree. Pairs are combined in
-- in-order sequence: left subtree, then the node, then the right subtree.
foldrTreeNode :: ((p, v) -> b -> b) -> b -> TreeNode a p v -> b
foldrTreeNode step = go
  where
    go acc Empty = acc
    go acc (TreeNode lhs entry _ rhs) = go (step entry (go acc rhs)) lhs
-- | Right fold over all key/value pairs of the map, in the tree's in-order
-- sequence.
foldrWithKey :: ((p, v) -> b -> b) -> b -> KdMap a p v -> b
foldrWithKey f z kdm = foldrTreeNode f z (_rootNode kdm)
-- | Folds over the stored values only, ignoring the keys.
instance Foldable (KdMap a p) where
  foldr f = foldrWithKey (\entry acc -> f (snd entry) acc)
-- | Run an applicative action on every value of a tree and rebuild a tree of
-- the same shape from the results. Keys and axis values are copied unchanged.
-- Effects are sequenced as left subtree, node value, right subtree.
traverseTreeNode :: Applicative f => (b -> f c) -> TreeNode a p b -> f (TreeNode a p c)
traverseTreeNode f = go
  where
    go Empty = pure Empty
    go (TreeNode lhs (key, val) axisVal rhs) =
      rebuild <$> go lhs <*> f val <*> go rhs
      where
        rebuild lhs' val' rhs' = TreeNode lhs' (key, val') axisVal rhs'
-- | Traverses the stored values. The coordinate function, distance function
-- and size are carried over unchanged.
instance Traversable (KdMap a p) where
  traverse f (KdMap p2l d2 root n) =
    fmap (\root' -> KdMap p2l d2 root' n) (traverseTreeNode f root)
-- | An empty map. Its distance function is 'defaultSqrDist' (the sum of
-- squared coordinate differences) built from the given coordinate function.
empty :: Real a => PointAsListFn a p -> KdMap a p v
empty pointAsList = emptyWithDist pointAsList $ defaultSqrDist pointAsList
-- | An empty map with an explicit coordinate function and squared-distance
-- function.
emptyWithDist :: Real a => PointAsListFn a p
                        -> SquaredDistanceFn a p
                        -> KdMap a p v
emptyWithDist p2l d2 =
  KdMap { _pointAsList = p2l
        , _distSqr     = d2
        , _rootNode    = Empty
        , _size        = 0
        }
-- | 'True' exactly when the map holds no entries.
--
-- NOTE(review): with base >= 4.8 the unqualified @import Data.Foldable@ also
-- brings Foldable's @null@ into scope. Unqualified references to @null@
-- (including the export list) may therefore be ambiguous; consider
-- @import Data.Foldable hiding (null)@. Verify against the base version in use.
null :: KdMap a p v -> Bool
null (KdMap _ _ _ n) = n == 0
-- | A map holding exactly one entry, with an explicit coordinate function and
-- squared-distance function. The node's axis value is the key's first
-- coordinate, because the root splits on the first axis.
singletonWithDist :: Real a => PointAsListFn a p
                            -> SquaredDistanceFn a p
                            -> (p, v)
                            -> KdMap a p v
singletonWithDist p2l d2 entry@(p, _) =
  KdMap { _pointAsList = p2l
        , _distSqr     = d2
        , _rootNode    = TreeNode Empty entry firstCoord Empty
        , _size        = 1
        }
  where
    firstCoord = head (p2l p)
-- | A map holding exactly one entry. It uses 'defaultSqrDist' over the given
-- coordinate function as its distance function.
singleton :: Real a => PointAsListFn a p -> (p, v) -> KdMap a p v
singleton p2l = singletonWithDist p2l (defaultSqrDist p2l)
-- | @quickselect cmp k xs@ returns the element that would sit at 0-based
-- index @k@ if @xs@ were sorted with @cmp@.
--
-- The first element of each sublist is used as the pivot, so the worst case
-- (e.g. already-sorted input) is quadratic.
--
-- Calls 'error' when @xs@ is empty or when @k@ is outside @[0, length xs)@.
quickselect :: (b -> b -> Ordering) -> Int -> [b] -> b
quickselect cmp = go
  where
    go _ [] = error "quickselect must be called on a non-empty list."
    go k (x : xs)
      | k < l     = go k ys
      -- Skip the l smaller elements and the pivot itself.
      | k > l     = go (k - l - 1) zs
      | otherwise = x
      where
        -- ys: elements strictly less than the pivot x; zs: all the others.
        (ys, zs) = L.partition ((== LT) . (`cmp` x)) xs
        l = length ys
-- | Build a balanced map from a list of key/value pairs, using the given
-- coordinate and squared-distance functions.
--
-- At each level the entries are split at the median coordinate on that
-- level's axis. The root uses the first axis; each level down advances one
-- axis, wrapping around.
--
-- * One entry whose coordinate equals the median is stored in the node.
-- * Entries with a coordinate less than or equal to the median go to the
--   left subtree.
-- * Entries with a strictly greater coordinate go to the right subtree.
buildWithDist :: Real a => PointAsListFn a p
                        -> SquaredDistanceFn a p
                        -> [(p, v)]
                        -> KdMap a p v
buildWithDist p2l d2 [] = emptyWithDist p2l d2
buildWithDist pointAsList distSqr dataPoints =
  KdMap { _pointAsList = pointAsList
        , _distSqr     = distSqr
        , _rootNode    = buildTreeInternal tagged
        , _size        = length dataPoints
        }
  where
    -- Tag each entry with the infinite cycle of its coordinates. The head of
    -- the cycle is the coordinate on the current level's axis; dropping it
    -- moves to the next axis for the level below.
    tagged = [ (cycle (pointAsList (fst entry)), entry) | entry <- dataPoints ]

    buildTreeInternal [] = Empty
    buildTreeInternal ps =
      TreeNode { _treeLeft  = buildTreeInternal below
               , _treePoint = fromJust medianPt
               , _axisValue = medianAxisVal
               , _treeRight = buildTreeInternal above
               }
      where
        (medianAxisVal : _, _) =
          quickselect (comparing (head . fst)) (length ps `div` 2) ps
        -- One entry whose coordinate equals the median becomes the node;
        -- any further ties are sent to the left subtree.
        (below, medianPt, above) = L.foldr place ([], Nothing, []) ps
        place ([], _) _ = error "buildKdMap.f: no empty lists allowed!"
        place (c : cs, entry) (lo, found, hi)
          | c < medianAxisVal = ((cs, entry) : lo, found, hi)
          | c > medianAxisVal = (lo, found, (cs, entry) : hi)
          | isNothing found   = (lo, Just entry, hi)
          | otherwise         = ((cs, entry) : lo, found, hi)
-- | Squared Euclidean distance: the sum of the squared differences between
-- corresponding coordinates of the two points.
--
-- If the coordinate lists differ in length, the extra coordinates of the
-- longer list are ignored, because the lists are zipped.
defaultSqrDist :: Num a => PointAsListFn a p -> SquaredDistanceFn a p
defaultSqrDist pointAsList k1 k2 =
  L.sum [ (x - y) ^ (2 :: Int) | (x, y) <- zip (pointAsList k1) (pointAsList k2) ]
-- | Build a balanced map from key/value pairs. It uses 'defaultSqrDist' over
-- the given coordinate function as its distance function.
build :: Real a => PointAsListFn a p -> [(p, v)] -> KdMap a p v
build pointAsList dataPoints =
  buildWithDist pointAsList (defaultSqrDist pointAsList) dataPoints
-- | Insert one key/value pair by descending to a leaf, without rebalancing.
--
-- At each node the new key goes left when its coordinate on that node's axis
-- is less than or equal to the node's axis value, and right otherwise. The
-- new leaf stores the key's coordinate for the axis of its level. Repeated
-- insertions can therefore produce an unbalanced tree.
insertUnbalanced :: Real a => KdMap a p v -> p -> v -> KdMap a p v
insertUnbalanced kdm@(KdMap pointAsList _ rootNode n) p' v' =
  kdm { _rootNode = descend (cycle (pointAsList p')) rootNode
      , _size     = n + 1
      }
  where
    descend [] _ = error "insertUnbalanced.go: no empty lists allowed!"
    descend (coord : _) Empty = TreeNode Empty (p', v') coord Empty
    descend (coord : rest) node@(TreeNode lhs _ split rhs) =
      if coord <= split
        then node { _treeLeft = descend rest lhs }
        else node { _treeRight = descend rest rhs }
-- | Insert several key/value pairs, left to right, with 'insertUnbalanced'.
batchInsertUnbalanced :: Real a => KdMap a p v -> [(p, v)] -> KdMap a p v
batchInsertUnbalanced kdm0 entries = foldl' step kdm0 entries
  where
    step kdm (p, v) = insertUnbalanced kdm p v
-- | All key/value pairs of a tree, listed in-order: left subtree, then the
-- node, then the right subtree.
assocsInternal :: TreeNode a p v -> [(p, v)]
assocsInternal tree = collect tree []
  where
    collect Empty acc = acc
    collect (TreeNode lhs entry _ rhs) acc = collect lhs (entry : collect rhs acc)
-- | All key/value pairs of the map, in the tree's in-order sequence.
assocs :: KdMap a p v -> [(p, v)]
assocs = assocsInternal . _rootNode
-- | All keys of the map, in the same order as 'assocs'.
keys :: KdMap a p v -> [p]
keys kdm = map fst (assocs kdm)
-- | All values of the map, in the same order as 'assocs'.
elems :: KdMap a p v -> [v]
elems kdm = map snd (assocs kdm)
-- | The key/value pair whose key is closest to the query point, according to
-- the map's squared-distance function. When several keys are equally close,
-- which one is returned depends on the traversal order.
--
-- Pruning the far side of a split is only sound if the distance function is
-- never smaller than the squared difference along a single axis, as
-- 'defaultSqrDist' guarantees.
--
-- Calls 'error' on an empty map.
nearest :: Real a => KdMap a p v -> p -> (p, v)
nearest (KdMap _ _ Empty _) _ =
  error "Attempted to call nearest on an empty KdMap."
nearest (KdMap pointAsList distSqr t@(TreeNode _ root _ _) _) query =
  fst $ go (root, distSqr (fst root) query) (cycle $ pointAsList query) t
  where
    go _ [] _ = error "nearest.go: no empty lists allowed!"
    go bestSoFar _ Empty = bestSoFar
    go bestSoFar
       (queryAxisValue : qvs)
       (TreeNode left (nodeK, nodeV) nodeAxisVal right) =
      let better x1@(_, dist1) x2@(_, dist2) = if dist1 < dist2
                                                 then x1
                                                 else x2
          currDist = distSqr query nodeK
          bestAfterNode = better ((nodeK, nodeV), currDist) bestSoFar
          nearestInTree onsideSubtree offsideSubtree =
            let bestAfterOnside = go bestAfterNode qvs onsideSubtree
                -- The far side can only hold a closer key if the splitting
                -- plane is nearer than the best candidate found so far.
                checkOffsideSubtree =
                  (queryAxisValue - nodeAxisVal) ^ (2 :: Int) < snd bestAfterOnside
            in if checkOffsideSubtree
                 then go bestAfterOnside qvs offsideSubtree
                 else bestAfterOnside
      in if queryAxisValue <= nodeAxisVal
           then nearestInTree left right
           else nearestInTree right left
-- | All key/value pairs whose key lies within @radius@ of the query point,
-- i.e. whose squared distance is at most @radius * radius@. The result is in
-- no particular order. @radius@ is presumably meant to be non-negative.
--
-- The far side of a split is explored whenever the splitting plane is at most
-- @radius@ away (inclusive). A key on the split with squared distance exactly
-- @radius * radius@ is then still found. As in 'nearest', this pruning assumes
-- the distance is never smaller than the squared difference along one axis.
inRadius :: Real a => KdMap a p v
                   -> a
                   -> p
                   -> [(p, v)]
inRadius (KdMap pointAsList distSqr t _) radius query =
  go (cycle $ pointAsList query) t []
  where
    go [] _ _ = error "inRadius.go: no empty lists allowed!"
    go _ Empty acc = acc
    go (queryAxisValue : qvs) (TreeNode left (k, v) nodeAxisVal right) acc =
      let onTheLeft = queryAxisValue <= nodeAxisVal
          (onside, offside) = if onTheLeft then (left, right) else (right, left)
          accAfterOnside = go qvs onside acc
          accAfterOffside =
            if abs (queryAxisValue - nodeAxisVal) <= radius
              then go qvs offside accAfterOnside
              else accAfterOnside
          accAfterCurrent = if distSqr k query <= radius * radius
                              then (k, v) : accAfterOffside
                              else accAfterOffside
      in accAfterCurrent
-- | Up to @numNeighbors@ key/value pairs whose keys are closest to the query
-- point under the map's squared-distance function. If the map holds fewer
-- entries, all of them are returned. A non-positive @numNeighbors@ yields @[]@;
-- otherwise the bounded queue below would call 'Q.findMax' on an empty queue.
--
-- The candidates are kept in a max-priority queue keyed by distance, and the
-- queue's list is reversed at the end. The result is therefore nearest-first,
-- provided 'Q.toList' yields keys in descending order; verify against the
-- pqueue documentation.
kNearest :: Real a => KdMap a p v -> Int -> p -> [(p, v)]
kNearest (KdMap pointAsList distSqr t _) numNeighbors query
  | numNeighbors <= 0 = []
  | otherwise =
      reverse $ map snd $ Q.toList $ go (cycle $ pointAsList query) Q.empty t
  where
    go [] _ _ = error "kNearest.go: no empty lists allowed!"
    go _ q Empty = q
    go (queryAxisValue : qvs) q (TreeNode left (k, v) nodeAxisVal right) =
      -- insertBounded keeps at most numNeighbors candidates, evicting the
      -- farthest one when a strictly closer candidate arrives.
      let insertBounded queue dist x
            | Q.size queue < numNeighbors = Q.insert dist x queue
            | otherwise = if dist < fst (Q.findMax queue)
                            then Q.insert dist x $ Q.deleteMax queue
                            else queue
          q' = insertBounded q (distSqr k query) (k, v)
          kNear queue onsideSubtree offsideSubtree =
            let queue' = go qvs queue onsideSubtree
                -- Search the far side while the queue is not full, or while
                -- the splitting plane is closer than the current farthest
                -- candidate.
                checkOffsideTree =
                  Q.size queue' < numNeighbors ||
                  (queryAxisValue - nodeAxisVal) ^ (2 :: Int) < fst (Q.findMax queue')
            in if checkOffsideTree
                 then go qvs queue' offsideSubtree
                 else queue'
      in if queryAxisValue <= nodeAxisVal
           then kNear q' left right
           else kNear q' right left
-- | All key/value pairs whose key lies inside the axis-aligned box spanned by
-- @lowers@ and @uppers@. Both bounds are inclusive on every coordinate. The
-- result is in no particular order.
--
-- A subtree is skipped only when the box lies entirely on the other side of
-- its node's split: left keys are <= the axis value, right keys are >.
inRange :: Real a => KdMap a p v
                  -> p
                  -> p
                  -> [(p, v)]
inRange (KdMap pointAsList _ t _) lowers uppers =
  go (zip (cycle (pointAsList lowers)) (cycle (pointAsList uppers))) t []
  where
    go [] _ _ = error "inRange.go: no empty lists allowed!"
    go _ Empty acc = acc
    go ((lower, upper) : nextBounds) (TreeNode left p nodeAxisVal right) acc
      | currentInRange = p : accAfterRight
      | otherwise      = accAfterRight
      where
        accAfterLeft
          | lower <= nodeAxisVal = go nextBounds left acc
          | otherwise            = acc
        accAfterRight
          | upper > nodeAxisVal = go nextBounds right accAfterLeft
          | otherwise           = accAfterLeft
        between lo x hi = lo <= x && x <= hi
        currentInRange =
          L.and $ zipWith3 between
            (pointAsList lowers) (pointAsList $ fst p) (pointAsList uppers)
-- | Number of entries stored in the map.
size :: KdMap a p v -> Int
size = _size
-- | Check the k-d tree ordering invariant for a tree whose root splits on
-- coordinate index @axis@. At every node, all left-subtree keys must have a
-- coordinate on that axis <= the node's axis value, and all right-subtree keys
-- a strictly greater one. The axis index advances by one per level, wrapping
-- at the key's number of coordinates.
--
-- It does not check that a node's axis value equals its own key's coordinate.
isTreeNodeValid :: Real a => PointAsListFn a p -> Int -> TreeNode a p v -> Bool
isTreeNodeValid _ _ Empty = True
isTreeNodeValid pointAsList axis (TreeNode l (k, _) nodeAxisVal r) =
  L.all (<= nodeAxisVal) (coordsOf l)
    && L.all (> nodeAxisVal) (coordsOf r)
    && isTreeNodeValid pointAsList nextAxis l
    && isTreeNodeValid pointAsList nextAxis r
  where
    coordsOf subtree =
      [ pointAsList (fst entry) !! axis | entry <- assocsInternal subtree ]
    nextAxis = (axis + 1) `mod` length (pointAsList k)
-- | Check the k-d tree ordering invariant of the whole map, starting with the
-- first coordinate at the root. The stored size is not checked.
isValid :: Real a => KdMap a p v -> Bool
isValid kdm = isTreeNodeValid (_pointAsList kdm) 0 (_rootNode kdm)