{-# LANGUAGE DeriveGeneric #-}

module Data.KdMap.Static
       ( -- * Usage

         -- $usage

         -- * Reference

         -- ** Types
         PointAsListFn
       , SquaredDistanceFn
       , KdMap
         -- ** /k/-d map construction
       , empty
       , emptyWithDist
       , singleton
       , singletonWithDist
       , build
       , buildWithDist
       , insertUnbalanced
       , batchInsertUnbalanced
         -- ** Query
       , nearest
       , inRadius
       , kNearest
       , inRange
       , assocs
       , keys
       , elems
       , null
       , size
         -- ** Folds
       , foldrWithKey
         -- ** Utilities
       , defaultSqrDist
         -- ** Internal (for testing)
       , isValid
       ) where

import Control.DeepSeq
import Control.DeepSeq.Generics (genericRnf)
import GHC.Generics

import Control.Applicative hiding (empty)
import Data.Foldable
import Prelude hiding (null)
import qualified Data.List as L
import Data.Maybe
import Data.Ord
import qualified Data.PQueue.Prio.Max as Q
import Data.Traversable

-- $usage
--
-- The 'KdMap' is a variant of 'Data.KdTree.Static.KdTree' where each point in
-- the tree is associated with some data. When talking about 'KdMap's,
-- we'll refer to the points and their associated data as the /points/
-- and /values/ of the 'KdMap', respectively. It might help to think
-- of 'Data.KdTree.Static.KdTree' and 'KdMap' as being analogous to
-- 'Set' and 'Map'.
--
-- Suppose you wanted to perform point queries on a set of 3D points,
-- where each point is associated with a 'String'. Here's how to build
-- a 'KdMap' of the data and perform a nearest neighbor query (if this
-- doesn't make sense, start with the documentation for
-- 'Data.KdTree.Static.KdTree'):
--
-- @
-- >>> let points = [(Point3d 0.0 0.0 0.0), (Point3d 1.0 1.0 1.0)]
--
-- >>> let valueStrings = [\"First\", \"Second\"]
--
-- >>> let pointValuePairs = 'zip' points valueStrings
--
-- >>> let kdm = 'build' point3dAsList pointValuePairs
--
-- >>> 'nearest' kdm (Point3d 0.1 0.1 0.1)
-- (Point3d {x = 0.0, y = 0.0, z = 0.0},\"First\")
-- @

-- | A node of the internal binary tree. A non-empty node stores one
-- point-value pair, the point's value along the axis this node splits
-- on, and two subtrees: points whose axis value is less than or equal
-- to '_axisValue' live in '_treeLeft', the others in '_treeRight'.
data TreeNode a p v
  = TreeNode
      { _treeLeft  :: TreeNode a p v
      , _treePoint :: (p, v)
      , _axisValue :: a
      , _treeRight :: TreeNode a p v
      }
  | Empty
  deriving (Generic, Show, Read)

instance (NFData a, NFData p, NFData v) => NFData (TreeNode a p v) where
  rnf = genericRnf

-- | Applies a function to every value in a tree; points, axis values and
-- the shape of the tree are left untouched.
mapTreeNode :: (v1 -> v2) -> TreeNode a p v1 -> TreeNode a p v2
mapTreeNode f = go
  where
    go Empty = Empty
    go (TreeNode l (k, v) splitVal r) = TreeNode (go l) (k, f v) splitVal (go r)

-- | A function turning a point of type @p@ into the list of its axis
-- values (coordinates) of type @a@, one entry per axis.
type PointAsListFn a p = p -> [a]

-- | A function computing the squared distance, of type @a@, between two
-- points of type @p@.
type SquaredDistanceFn a p = p -> p -> a

-- | A /k/-d tree holding points of type @p@, whose axis values have
-- type @a@, where every point carries an associated value of type @v@.
data KdMap a p v = KdMap
  { _pointAsList :: PointAsListFn a p
    -- ^ extracts the axis values of a point
  , _distSqr     :: SquaredDistanceFn a p
    -- ^ squared distance used by the nearest-neighbor and radius queries
  , _rootNode    :: TreeNode a p v
    -- ^ root of the underlying tree
  , _size        :: Int
    -- ^ number of stored point-value pairs
  } deriving Generic

instance (NFData a, NFData p, NFData v) => NFData (KdMap a p v) where
  rnf = genericRnf

-- Only the tree is shown; the stored functions have no 'Show' instance.
instance (Show a, Show p, Show v) => Show (KdMap a p v) where
  show kdm = "KdMap " ++ show (_rootNode kdm)

-- Maps over the stored values only; points and tree shape are kept.
instance Functor (KdMap a p) where
  fmap f (KdMap p2l d2 root count) = KdMap p2l d2 (mapTreeNode f root) count

-- | Right fold over the point-value pairs of a tree, visiting them
-- in-order (left subtree, node, right subtree).
foldrTreeNode :: ((p, v) -> b -> b) -> b -> TreeNode a p v -> b
foldrTreeNode f = go
  where
    go acc Empty = acc
    go acc (TreeNode l kv _ r) = go (f kv (go acc r)) l

-- | Right fold over every point-value pair stored in the 'KdMap'.
foldrWithKey :: ((p, v) -> b -> b) -> b -> KdMap a p v -> b
foldrWithKey f z kdm = foldrTreeNode f z (_rootNode kdm)

-- Folds over the values only, ignoring the points they are attached to.
instance Foldable (KdMap a p) where
  foldr f z kdm = foldrWithKey (\kv acc -> f (snd kv) acc) z kdm

-- | Runs an applicative action on every value of a tree, in-order (left
-- subtree, node, right subtree), and rebuilds a tree of the same shape
-- from the results.
traverseTreeNode :: Applicative f => (b -> f c) -> TreeNode a p b -> f (TreeNode a p c)
traverseTreeNode f = go
  where
    go Empty = pure Empty
    go (TreeNode l (p, v) axisValue r) = rebuild <$> go l <*> f v <*> go r
      where
        -- The pair is rebuilt by hand instead of using traverse f (p, v)
        -- because base-4.6.* has no Traversable instance for tuples.
        rebuild l' v' r' = TreeNode l' (p, v') axisValue r'

-- Traverses the stored values; the functions and size are carried over.
instance Traversable (KdMap a p) where
  traverse f (KdMap p2l d2 root count) = rebuild <$> traverseTreeNode f root
    where rebuild root' = KdMap p2l d2 root' count

-- | Creates a 'KdMap' with no points, measuring distances with
-- 'defaultSqrDist'.
empty :: Real a => PointAsListFn a p -> KdMap a p v
empty p2l = emptyWithDist p2l sqrDist
  where sqrDist = defaultSqrDist p2l

-- | Creates a 'KdMap' with no points that measures distances with the
-- supplied squared distance function.
emptyWithDist :: Real a => PointAsListFn a p
                        -> SquaredDistanceFn a p
                        -> KdMap a p v
emptyWithDist p2l d2 =
  KdMap { _pointAsList = p2l
        , _distSqr     = d2
        , _rootNode    = Empty
        , _size        = 0
        }

-- | 'True' exactly when the 'KdMap' stores no point-value pairs.
null :: KdMap a p v -> Bool
null (KdMap _ _ _ count) = count == 0

-- | Creates a 'KdMap' holding exactly one point-value pair, measuring
-- distances with the supplied squared distance function. The root
-- splits on the point's first axis value.
singletonWithDist :: Real a => PointAsListFn a p
                            -> SquaredDistanceFn a p
                            -> (p, v)
                            -> KdMap a p v
singletonWithDist p2l d2 pv@(p, _) =
  KdMap p2l d2 (TreeNode Empty pv (head (p2l p)) Empty) 1

-- | Creates a 'KdMap' holding exactly one point-value pair, measuring
-- distances with 'defaultSqrDist'.
singleton :: Real a => PointAsListFn a p -> (p, v) -> KdMap a p v
singleton p2l = singletonWithDist p2l (defaultSqrDist p2l)

-- | Returns the element that would sit at the given zero-based index if
-- the list were sorted with the comparison function. Errors on an empty
-- list (or an out-of-range index, which eventually empties the list).
quickselect :: (b -> b -> Ordering) -> Int -> [b] -> b
quickselect cmp = select
  where
    select _ [] = error "quickselect must be called on a non-empty list."
    select i (pivot : rest) =
      case compare i numSmaller of
        LT -> select i smaller
        GT -> select (i - numSmaller - 1) notSmaller
        EQ -> pivot
      where
        (smaller, notSmaller) = L.partition (\y -> cmp y pivot == LT) rest
        numSmaller = length smaller

-- | Builds a 'KdMap' from a list of pairs of points (of type p) and
-- values (of type v), using a user-specified squared distance
-- function.
--
-- Average time complexity: /O(n * log(n))/ for /n/ data points.
--
-- Worst case time complexity: /O(n^2)/ for /n/ data points.
--
-- Worst case space complexity: /O(n)/ for /n/ data points.
buildWithDist :: Real a => PointAsListFn a p
                        -> SquaredDistanceFn a p
                        -> [(p, v)]
                        -> KdMap a p v
buildWithDist p2l d2 [] = emptyWithDist p2l d2
buildWithDist pointAsList distSqr dataPoints =
  -- Pair every point-value pair with an infinite cycle of its point's
  -- axis values. Each tree level consumes one element of the cycle, so
  -- the splitting axis rotates with depth.
  let axisValsPointsPairs = zip (map (cycle . pointAsList . fst) dataPoints) dataPoints
  in  KdMap { _pointAsList = pointAsList
            , _distSqr     = distSqr
            , _rootNode    = buildTreeInternal axisValsPointsPairs
            , _size        = length dataPoints
            }
  -- buildTreeInternal splits the pairs around the median of their
  -- current axis value, then recurses on both halves with the remaining
  -- (rotated) axis values.
  where buildTreeInternal [] = Empty
        buildTreeInternal ps =
          let n = length ps
              -- Current axis value of the element at sorted index
              -- n `div` 2, ordering by each pair's current axis value.
              (medianAxisVal : _, _) =
                quickselect (comparing (head . fst)) (n `div` 2) ps
              -- Routes one pair: a smaller axis value goes left, a larger
              -- one goes right. Among pairs whose axis value is neither
              -- smaller nor larger than the median, the first one reached
              -- by the right fold below (the last in list order) becomes
              -- the node's point and the rest go left. This keeps the
              -- invariant left <= median < right used by insertUnbalanced
              -- and isTreeNodeValid. Each routed pair keeps only the tail
              -- of its axis-value cycle.
              f ([], _) _ = error "buildKdMap.f: no empty lists allowed!"
              f (v : vt, p) (lt, maybeMedian, gt)
                | v < medianAxisVal = ((vt, p) : lt, maybeMedian, gt)
                | v > medianAxisVal = (lt, maybeMedian, (vt, p) : gt)
                | otherwise =
                    case maybeMedian of
                      Nothing -> (lt, Just p, gt)
                      Just _ -> ((vt, p) : lt, maybeMedian, gt)
              -- The median element itself is in ps, so maybeMedianPt is
              -- always Just here and fromJust below is safe.
              (leftPoints, maybeMedianPt, rightPoints) = L.foldr f ([], Nothing, []) ps
          in  TreeNode
              { _treeLeft  = buildTreeInternal leftPoints
              , _treePoint = fromJust maybeMedianPt
              , _axisValue = medianAxisVal
              , _treeRight = buildTreeInternal rightPoints
              }

-- | The default squared distance: the sum of the squared differences of
-- corresponding axis values, as produced by the given 'PointAsListFn'.
-- Extra axis values on the longer side (if any) are ignored.
defaultSqrDist :: Num a => PointAsListFn a p -> SquaredDistanceFn a p
defaultSqrDist pointAsList k1 k2 =
  L.sum [ d * d | d <- zipWith (-) (pointAsList k1) (pointAsList k2) ]

-- | Builds a 'KdMap' from a list of pairs of points (of type p) and
-- values (of type v), using the default squared distance function
-- 'defaultSqrDist'.
--
-- Equivalent to @'buildWithDist' pointAsList ('defaultSqrDist' pointAsList)@.
--
-- Average time complexity: /O(n * log(n))/ for /n/ data points.
--
-- Worst case time complexity: /O(n^2)/ for /n/ data points.
--
-- Worst case space complexity: /O(n)/ for /n/ data points.
build :: Real a => PointAsListFn a p -> [(p, v)] -> KdMap a p v
build pointAsList =
  buildWithDist pointAsList $ defaultSqrDist pointAsList

-- | Adds one point-value pair to a 'KdMap' by walking down to an empty
-- leaf, without rebalancing. Repeated insertions can therefore skew the
-- tree and slow down point queries. For workloads with many insertions
-- into an existing /k/-d map, see
-- @Data.KdMap.Dynamic.@'Data.KdMap.Dynamic.KdMap'.
--
-- Average complexity: /O(log(n))/ for /n/ data points.
--
-- Worst case time complexity: /O(n)/ for /n/ data points.
insertUnbalanced :: Real a => KdMap a p v -> p -> v -> KdMap a p v
insertUnbalanced kdm@(KdMap pointAsList _ rootNode n) p' v' =
  kdm { _rootNode = descend rootNode (cycle (pointAsList p')), _size = n + 1 }
  where
    descend _ [] = error "insertUnbalanced.go: no empty lists allowed!"
    descend Empty (axisVal : _) = TreeNode Empty (p', v') axisVal Empty
    descend node@(TreeNode l _ splitVal r) (axisVal : rest) =
      -- Ties go left, matching the left <= split < right invariant.
      if axisVal <= splitVal
        then node { _treeLeft = descend l rest }
        else node { _treeRight = descend r rest }

-- | Adds each point-value pair of the list, in order, with
-- 'insertUnbalanced'. Like that function, it never rebalances, so the
-- tree can become skewed and point queries slow.
--
-- Average complexity: /O(n * log(n))/ for /n/ data points.
--
-- Worst case time complexity: /O(n^2)/ for /n/ data points.
batchInsertUnbalanced :: Real a => KdMap a p v -> [(p, v)] -> KdMap a p v
batchInsertUnbalanced kdm pairs = foldl' step kdm pairs
  where step acc (p, v) = insertUnbalanced acc p v

-- | All point-value pairs of a tree, listed in-order.
assocsInternal :: TreeNode a p v -> [(p, v)]
assocsInternal = foldrTreeNode (:) []

-- | Lists every point-value pair stored in the 'KdMap'.
--
-- Time complexity: /O(n)/ for /n/ data points.
assocs :: KdMap a p v -> [(p, v)]
assocs = assocsInternal . _rootNode

-- | Lists every point stored in the 'KdMap', in the same order as
-- 'assocs'.
--
-- Time complexity: /O(n)/ for /n/ data points.
keys :: KdMap a p v -> [p]
keys = foldrWithKey ((:) . fst) []

-- | Lists every value stored in the 'KdMap', in the same order as
-- 'assocs'.
--
-- Time complexity: /O(n)/ for /n/ data points.
elems :: KdMap a p v -> [v]
elems = foldrWithKey ((:) . snd) []

-- | Finds the point-value pair whose point is closest to the query
-- point, according to the map's squared distance function.
--
-- Average time complexity: /O(log(n))/ for /n/ data points.
--
-- Worst case time complexity: /O(n)/ for /n/ data points.
--
-- Throws error if called on an empty 'KdMap'.
nearest :: Real a => KdMap a p v -> p -> (p, v)
nearest (KdMap _ _ Empty _) _ =
  error "Attempted to call nearest on an empty KdMap."
nearest (KdMap pointAsList distSqr tree@(TreeNode _ rootPair _ _) _) query =
  -- Seeding the search with the root as a concrete candidate avoids
  -- threading a Maybe through the recursion.
  fst $ search (rootPair, distSqr (fst rootPair) query)
               (cycle $ pointAsList query)
               tree
  where
    -- Keeps the candidate with the strictly smaller distance; on a tie
    -- the second (previously best) candidate wins.
    closer c1@(_, d1) c2@(_, d2)
      | d1 < d2   = c1
      | otherwise = c2

    search _ [] _ = error "nearest.go: no empty lists allowed!"
    search best _ Empty = best
    search best (qAxis : qRest) (TreeNode l (k, v) splitVal r)
      | qAxis <= splitVal = visit l r
      | otherwise         = visit r l
      where
        bestHere = closer ((k, v), distSqr query k) best
        -- Search the query's side first; the far side can only help if
        -- the splitting plane is closer than the best distance so far.
        visit near far
          | (qAxis - splitVal) ^ (2 :: Int) < snd bestNear =
              search bestNear qRest far
          | otherwise = bestNear
          where
            bestNear = search bestHere qRest near

-- | Given a 'KdMap', a query point, and a radius, returns all
-- point-value pairs in the 'KdMap' with points within the given
-- radius of the query point. Points at a distance of exactly the
-- radius are included.
--
-- Points are not returned in any particular order.
--
-- Worst case time complexity: /O(n)/ for /n/ data points and a radius
-- that spans all points in the structure.
inRadius :: Real a => KdMap a p v
                   -> a -- ^ radius
                   -> p -- ^ query point
                   -> [(p, v)] -- ^ list of point-value pairs with
                               -- points within given radius of query
inRadius (KdMap pointAsList distSqr t _) radius query =
  go (cycle $ pointAsList query) t []
  where
    go [] _ _ = error "inRadius.go: no empty lists allowed!"
    go _ Empty acc = acc
    go (queryAxisValue : qvs) (TreeNode left (k, v) nodeAxisVal right) acc =
      let onTheLeft = queryAxisValue <= nodeAxisVal
          accAfterOnside = if   onTheLeft
                           then go qvs left acc
                           else go qvs right acc
          -- The offside test must be inclusive (<=) to agree with the
          -- inclusive distance test below. When the query lies right of
          -- the split, the left subtree may hold points whose axis value
          -- equals nodeAxisVal, i.e. exactly 'radius' away along this
          -- axis. A strict test would skip them.
          accAfterOffside = if   abs (queryAxisValue - nodeAxisVal) <= radius
                            then if   onTheLeft
                                 then go qvs right accAfterOnside
                                 else go qvs left accAfterOnside
                            else accAfterOnside
          accAfterCurrent = if distSqr k query <= radius * radius
                            then (k, v) : accAfterOffside
                            else accAfterOffside
      in  accAfterCurrent

-- | Given a 'KdMap', a query point, and a number @k@, returns the @k@
-- point-value pairs with the nearest points to the query.
--
-- Neighbors are returned in order of increasing distance from the
-- query point. If the map holds fewer than @k@ points, all of them are
-- returned. If @k <= 0@, the result is the empty list.
--
-- Average time complexity: /O(log(k) * log(n))/ for /k/ nearest
-- neighbors on a structure with /n/ data points.
--
-- Worst case time complexity: /O(n * log(k))/ for /k/ nearest
-- neighbors on a structure with /n/ data points.
kNearest :: Real a => KdMap a p v -> Int -> p -> [(p, v)]
kNearest (KdMap pointAsList distSqr t _) numNeighbors query
  -- Without this guard, a non-positive k reaches 'Q.findMax' on an
  -- empty queue inside 'insertBounded' and crashes.
  | numNeighbors <= 0 = []
  | otherwise =
      -- The max-queue lists candidates farthest first; reverse so the
      -- nearest comes first.
      reverse $ map snd $ Q.toList $ go (cycle $ pointAsList query) Q.empty t
  where
    -- go :: [a] -> Q.MaxPQueue a (p, v) -> TreeNode a p v -> Q.MaxPQueue a (p, v)
    --
    -- The queue keeps at most 'numNeighbors' candidates keyed by squared
    -- distance, so its maximum is the worst candidate kept so far.
    go [] _ _ = error "kNearest.go: no empty lists allowed!"
    go _ q Empty = q
    go (queryAxisValue : qvs) q (TreeNode left (k, v) nodeAxisVal right) =
      let insertBounded queue dist x
            | Q.size queue < numNeighbors = Q.insert dist x queue
            | otherwise = if dist < fst (Q.findMax queue)
                          then Q.insert dist x $ Q.deleteMax queue
                          else queue
          q' = insertBounded q (distSqr k query) (k, v)
          -- Search the query's side first. The far side is only visited
          -- while the queue is not full, or when the splitting plane is
          -- closer than the current worst candidate.
          kNear queue onsideSubtree offsideSubtree =
            let queue' = go qvs queue onsideSubtree
                checkOffsideTree =
                  Q.size queue' < numNeighbors ||
                  (queryAxisValue - nodeAxisVal)^(2 :: Int) < fst (Q.findMax queue')
            in  if checkOffsideTree
                then go qvs queue' offsideSubtree
                else queue'
      in  if queryAxisValue <= nodeAxisVal
          then kNear q' left right
          else kNear q' right left

-- | Finds all point-value pairs in a 'KdMap' with points within a
-- given range, where the range is specified as a set of lower and
-- upper bounds. Both bounds are inclusive on every axis.
--
-- Points are not returned in any particular order.
--
-- Worst case time complexity: /O(n)/ for n data points and a range
-- that spans all the points.
--
-- TODO: Maybe use known bounds on entire tree structure to be able to
-- automatically count whole portions of tree as being within given
-- range.
inRange :: Real a => KdMap a p v
                  -> p -- ^ lower bounds of range
                  -> p -- ^ upper bounds of range
                  -> [(p, v)] -- ^ point-value pairs within given
                              -- range
inRange (KdMap pointAsList _ t _) lowers uppers =
  go (cycle lowerVals `zip` cycle upperVals) t []
  where
    -- Every visited node needs the bounds' axis values. Compute them once
    -- here instead of re-deriving them at each node.
    lowerVals = pointAsList lowers
    upperVals = pointAsList uppers

    valInRange l x u = l <= x && x <= u

    go [] _ _ = error "inRange.go: no empty lists allowed!"
    go _ Empty acc = acc
    go ((lower, upper) : nextBounds) (TreeNode left p nodeAxisVal right) acc =
      let accAfterLeft = if lower <= nodeAxisVal
                         then go nextBounds left acc
                         else acc
          -- The right subtree only holds axis values > nodeAxisVal, so it
          -- can contain matches only when the upper bound exceeds the split.
          accAfterRight = if upper > nodeAxisVal
                          then go nextBounds right accAfterLeft
                          else accAfterLeft
          currentInRange =
            L.and $ zipWith3 valInRange lowerVals (pointAsList $ fst p) upperVals
          accAfterCurrent = if currentInRange
                            then p : accAfterRight
                            else accAfterRight
      in  accAfterCurrent

-- | Number of point-value pairs held by the 'KdMap', read from the
-- stored count.
--
-- Time complexity: /O(1)/
size :: KdMap a p v -> Int
size = _size

-- | Checks the /k/-d ordering invariant of a subtree whose root splits
-- on the given axis index. Every point in the left subtree must have an
-- axis value no greater than the node's, and every point in the right
-- subtree a strictly greater one. Both subtrees must in turn be valid
-- for the next axis, wrapping around after the last axis.
isTreeNodeValid :: Real a => PointAsListFn a p -> Int -> TreeNode a p v -> Bool
isTreeNodeValid _ _ Empty = True
isTreeNodeValid pointAsList axis (TreeNode l (k, _) nodeAxisVal r) =
  L.all (<= nodeAxisVal) (axisValuesOf l)
    && L.all (> nodeAxisVal) (axisValuesOf r)
    && isTreeNodeValid pointAsList nextAxis l
    && isTreeNodeValid pointAsList nextAxis r
  where
    axisValuesOf subtree =
      [ pointAsList (fst kv) !! axis | kv <- assocsInternal subtree ]
    nextAxis = (axis + 1) `mod` length (pointAsList k)

-- | 'True' when the whole tree satisfies the /k/-d ordering invariant,
-- starting with axis 0 at the root. Intended for internal testing.
isValid :: Real a => KdMap a p v -> Bool
isValid kdm = isTreeNodeValid (_pointAsList kdm) 0 (_rootNode kdm)