```{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Geometry.KDTree where

import           Control.Lens hiding (imap, element, Empty, (:<))
import           Data.BinaryTree
import           Unsafe.Coerce(unsafeCoerce)
import           Data.Ext
import qualified Data.Foldable as F
import           Data.Geometry.Box
import           Data.Geometry.Point
import           Data.Geometry.Properties
import           Data.Geometry.Vector
import qualified Data.List.NonEmpty as NonEmpty
import           Data.Maybe (fromJust)
import           Data.Proxy
import           Data.LSeq (LSeq, pattern (:<|))
import qualified Data.LSeq as LSeq
import           Data.Util
import qualified Data.Vector.Fixed as FV
import           GHC.TypeLits
import           Prelude hiding (replicate)

--------------------------------------------------------------------------------

-- | A coordinate axis of a @d@-dimensional space, 1-based: @Coord 1@ is the
-- first axis and @Coord d@ the last. Values are interpreted modulo @d@, so
-- the 'Enum' instance cycles through the axes (@succ (Coord d)@ is axis 1).
newtype Coord (d :: Nat) = Coord { unCoord :: Int }

-- | Two coordinates are equal when they denote the same axis modulo @d@.
instance KnownNat d => Eq (Coord d) where
  (Coord i) == (Coord j) = (i `mod` d) == (j `mod` d)
    where
      d = fromInteger . natVal $ (Proxy :: Proxy d)

-- | Shows the canonical 1-based axis number, a value in @[1..d]@.
instance KnownNat d => Show (Coord d) where
  -- Shift to 0-based before reducing modulo d, so that @Coord i@ with
  -- @1 <= i <= d@ is shown as @i@ (reducing @i@ directly would show
  -- @Coord 1@ as 2 and @Coord d@ as 1).
  show (Coord i) = show $ 1 + ((i - 1) `mod` d)
    where
      d = fromInteger . natVal $ (Proxy :: Proxy d)

-- | 'fromEnum' is 0-based; 'toEnum' maps any 'Int' back into @[1..d]@
-- modulo @d@. The default 'succ' / 'pred' therefore rotate through the
-- axes, which is how 'build' moves to the next splitting axis.
instance KnownNat d => Enum (Coord d) where
  toEnum i = Coord $ 1 + (i `mod` d)
    where
      d = fromInteger . natVal $ (Proxy :: Proxy d)
  fromEnum = subtract 1 . unCoord

-- | Split information stored at an internal node of the kd-tree.
data Split d r = Split !(Coord d)    -- ^ the axis this node splits on
                       !r            -- ^ the median point's coordinate on that axis
                       !(Box d () r) -- ^ bounding box of all points in this subtree

deriving instance (Show r, Arity d, KnownNat d) => Show (Split d r)
deriving instance (Eq r, Arity d, KnownNat d)   => Eq (Split d r)

-- | A 'Split' without its bounding box: the splitting axis and value, as
-- produced by 'splitOn' before boxes are attached.
type Split' d r =
  SP (Coord d) r

-- | A non-empty kd-tree: internal nodes carry a 'Split', and every leaf
-- holds a single point together with its extra data.
newtype KDTree' d p r = KDT
  { unKDT :: BinLeafTree (Split d r) (Point d r :+ p)
  }

deriving instance (Show p, Show r, Arity d, KnownNat d) => Show (KDTree' d p r)
deriving instance (Eq p, Eq r, Arity d, KnownNat d)     => Eq   (KDTree' d p r)

-- | A possibly empty kd-tree.
data KDTree d p r
  = Empty                 -- ^ no points at all
  | Tree (KDTree' d p r)  -- ^ a tree with at least one point

deriving instance (Show p, Show r, Arity d, KnownNat d) => Show (KDTree d p r)
deriving instance (Eq p, Eq r, Arity d, KnownNat d)     => Eq   (KDTree d p r)

-- | Views a possibly empty tree as an optional non-empty one.
toMaybe :: KDTree d p r -> Maybe (KDTree' d p r)
toMaybe t = case t of
  Empty   -> Nothing
  Tree t' -> Just t'

-- | Builds a kd-tree from a list of points; the empty list yields 'Empty'.
--
-- Expects the input to be a set, i.e. no duplicates. Points are compared
-- by location only (their extra data is ignored), and this precondition is
-- not checked.
--
-- running time: \(O(n \log n)\)
buildKDTree :: (Arity d, 1 <= d, Ord r)
            => [Point d r :+ p] -> KDTree d p r
buildKDTree pts = case NonEmpty.nonEmpty pts of
  Nothing -> Empty
  Just ne -> Tree (buildKDTree' ne)

-- | Builds a non-empty kd-tree, splitting first on axis 1.
--
-- The tree skeleton is built first. A second tree holding the bounding box
-- of every subtree is then computed bottom-up, and the two are zipped so
-- that each internal node's 'Split' carries its box.
buildKDTree' :: (Arity d, 1 <= d, Ord r)
             => NonEmpty.NonEmpty (Point d r :+ p) -> KDTree' d p r
buildKDTree' pts = KDT (zipExactWith attach const skeleton boxes)
  where
    skeleton = build (Coord 1) (toPointSet (LSeq.fromNonEmpty pts))
    -- leaves get the (degenerate) box of their point; a node gets the box
    -- enclosing the boxes of its two children
    boxes    = foldUpData (\l _ r -> boundingBoxList' [l, r])
                          (boundingBox . (^.core))
                          skeleton
    attach (SP c m) b = Split c m b

-- | Removes duplicates by sorting and keeping one element of each run of
-- equal elements; the result is in ascending order.
ordNub :: Ord a => NonEmpty.NonEmpty a -> NonEmpty.NonEmpty a
ordNub xs = NonEmpty.map NonEmpty.head (NonEmpty.group1 (NonEmpty.sort xs))

-- | Makes one copy of the points per axis. The copy at vector index @i@
-- (0-based) is sorted with @compareOn (i + 1)@, i.e. on coordinate @i + 1@,
-- with ties broken by the whole point.
toPointSet :: (Arity d, Ord r)
           => LSeq n (Point d r :+ p) -> PointSet (LSeq n) d p r
toPointSet pts = FV.imap sortOnAxis (FV.replicate pts)
  where
    -- vector indices are 0-based whereas compareOn takes a 1-based axis
    sortOnAxis i = LSeq.unstableSortBy (compareOn (i + 1))

-- | Orders points by their @i@-th coordinate (1-based, read with
-- 'unsafeCoord'). Ties are broken by comparing the whole points, and the
-- extra data is never inspected.
compareOn :: (Ord r, Arity d)
          => Int -> Point d r :+ e -> Point d r :+ e -> Ordering
compareOn i p q = compare (key p) (key q)
  where
    key x = (x^.core.unsafeCoord i, x^.core)

-- | Recursively builds the kd-tree skeleton. The root splits on axis @i@,
-- and each level below uses the next axis ('succ', which wraps around).
build :: (1 <= d, Arity d, Ord r)
      => Coord d
      -> PointSet (LSeq 1) d p r
      -> BinLeafTree (Split' d r) (Point d r :+ p)
build i ps = either Leaf branch (asSingleton ps)
  where
    -- asSingleton returned Right, so ps' is known to hold at least two
    -- points, which is exactly what splitOn requires.
    branch ps' = let (lo, s, hi) = splitOn i ps'
                     i'          = succ i
                 in Node (build i' lo) s (build i' hi)

--------------------------------------------------------------------------------

-- | All points stored in the leaves of a non-empty kd-tree.
--
-- 'NonEmpty.fromList' cannot fail here, assuming (as the patterns in this
-- module suggest) that 'Leaf' and 'Node' are the only constructors: every
-- such tree contains at least one leaf.
reportSubTree :: KDTree' d p r -> NonEmpty.NonEmpty (Point d r :+ p)
reportSubTree (KDT t) = NonEmpty.fromList (F.toList t)

-- | Reports all points of the tree that lie in the query box; an empty
-- tree reports nothing.
--
-- running time: \(O(n^{(d-1)/d} + k)\)
searchKDTree :: (Arity d, Ord r)
             => Box d q r -> KDTree d p r -> [Point d r :+ p]
searchKDTree _  Empty    = []
searchKDTree qr (Tree t) = searchKDTree' qr t

-- | Reports all points of a non-empty kd-tree that lie in the query box.
searchKDTree' :: (Arity d, Ord r)
              => Box d q r -> KDTree' d p r -> [Point d r :+ p]
searchKDTree' qr (KDT root) = go root
  where
    -- a leaf contributes its point exactly when the point lies in the query
    go (Leaf p)
      | (p^.core) `intersects` qr = [p]
      | otherwise                 = []
    -- a subtree whose bounding box lies fully inside the query is reported
    -- wholesale; otherwise only children whose boxes meet the query are
    -- visited
    go t@(Node l (Split _ _ b) r)
      | b `containedIn` qr = F.toList t
      | otherwise          = visit l ++ visit r

    visit sub
      | qr `intersects` boxOf sub = go sub
      | otherwise                 = []

-- | Bounding box of a subtree: the box stored in an internal node's
-- 'Split', or the degenerate box of a leaf's single point.
boxOf :: (Arity d, Ord r) => BinLeafTree (Split d r) (Point d r :+ p) -> Box d () r
boxOf t = case t of
  Leaf p                 -> boundingBox (p^.core)
  Node _ (Split _ _ b) _ -> b

-- | Whether the first box lies inside the second. The check is that both
-- the coordinate-wise minimum and maximum corners of the first box lie in
-- the second box.
containedIn :: (Arity d, Ord r) => Box d q r -> Box d p r -> Bool
containedIn (Box (CWMin lo :+ _) (CWMax hi :+ _)) outer =
  (lo `intersects` outer) && (hi `intersects` outer)

--------------------------------------------------------------------------------

-- | One container of points per axis, indexed 0 to @d-1@. As built by
-- 'toPointSet', every entry holds the same points, with entry @i@ sorted
-- on coordinate @i + 1@.
type PointSet seq d p r =
  Vector d (seq (Point d r :+ p))

-- | Splits a point set of at least two points at the median along axis @c@.
--
-- The result has three parts:
--
-- * the points strictly smaller than the median in the 'compareOn' order
--   for this axis;
-- * the split, i.e. the axis and the median's coordinate on it;
-- * the remaining points, namely the median and everything larger.
--
-- Every per-axis copy is partitioned with the same predicate, so both
-- halves again hold the same points in every entry.
--
-- running time: \(O(n)\)
splitOn :: (Arity d, KnownNat d, Ord r)
        => Coord d
        -> PointSet (LSeq 2) d p r
        -> ( PointSet (LSeq 1) d p r
           , Split' d r
           , PointSet (LSeq 1) d p r)
splitOn c@(Coord i) pts = (lows, SP c (median^.core.unsafeCoord i), highs)
  where
    -- Axis i is 1-based while the vector is 0-based. The entry is always
    -- present because Coord values stay in [1..d].
    sortedOnAxis = fromJust $ pts^?element' (i - 1)
    median       = sortedOnAxis `LSeq.index` (F.length sortedOnAxis `div` 2)

    -- With >= 2 points the median index is >= 1. For a sorted, duplicate-free
    -- input this means at least one point lies strictly below the median,
    -- and the median itself goes to the upper half. Both halves are thus
    -- non-empty, which is what the unchecked LSeq.promise relies on.
    halve = bimap LSeq.promise LSeq.promise
          . LSeq.partition (\p -> compareOn i p median == LT)

    (lows, highs) = bimap vectorFromListUnsafe vectorFromListUnsafe
                  . unzip . F.toList . fmap halve $ pts

-- | Decides whether a point set consists of exactly one point.
--
-- All entries of the 'PointSet' hold the same points (one copy per axis),
-- so only the first entry is inspected.
--
-- * Returns @Left p@ when that entry is the single point @p@.
-- * Otherwise the set has at least two points, and it is returned with its
--   length index strengthened from 'LSeq' 1 to 'LSeq' 2.
asSingleton :: (1 <= d, Arity d)
            => PointSet (LSeq 1) d p r
            -> Either (Point d r :+ p) (PointSet (LSeq 2) d p r)
asSingleton v = case v^.element (C :: C 0) of
    (p :<| s) | null s -> Left p -- exactly one element
    -- The first entry has a head and a non-empty tail, so it (and therefore
    -- every entry) holds >= 2 points. 'LSeq.promise' records this in each
    -- sequence's type. It is unchecked, but unlike an unsafeCoerce of the
    -- whole vector it can only change the length index, nothing else.
    _                  -> Right (fmap LSeq.promise v)
```