{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveFunctor #-}
{-# language DeriveGeneric #-}
{-# language LambdaCase #-}
{-# language GeneralizedNewtypeDeriving #-}
-- {-# language MultiParamTypeClasses #-}
{-# LANGUAGE MultiWayIf #-}
{-# options_ghc -Wno-unused-imports #-}
{-# options_ghc -Wno-unused-top-binds #-}

{-|
Random projection trees for approximate nearest neighbor search in high-dimensional vector spaces
-}
module Data.RPTree (
  -- * Construction
  tree, forest
  -- * Query
  , knn
  -- * I/O
  , serialiseRPForest
  , deserialiseRPForest
  -- * Validation
  , recallWith
  -- * Access
  , levels, points, leaves, candidates
  -- * Types
  -- ** RPTree
  , RPTree, RPForest
  -- *
  , SVector, fromListSv, fromVectorSv
  , DVector, fromListDv, fromVectorDv
  -- * Vector space types
  , Inner(..), Scale(..)
    -- ** helpers for implementing Inner instances
    -- *** inner product
  , innerSS, innerSD, innerDD
    -- *** L2 distance
  , metricSSL2, metricSDL2
  -- *** Scale
  , scaleS, scaleD
  -- * Conduit
  , dataSource
  -- * Random generation
  -- ** vector
  , sparse, dense
  , normal2

  -- * Rendering
  , draw
  -- * CSV
  , writeCsv
  -- * Testing
  , randSeed, BenchConfig(..), normalSparse2
  , liftC
  ) where

import Control.Monad (replicateM)

import Control.Monad.IO.Class (MonadIO(..))
import Data.Foldable (Foldable(..), maximumBy, minimumBy)
import Data.Functor.Identity (Identity(..))
import Data.List (partition, sortBy)
import Data.Monoid (Sum(..))
import Data.Ord (comparing)
import GHC.Generics (Generic)
import GHC.Word (Word64)

-- containers
import Data.Sequence (Seq, (|>))
import qualified Data.Map as M (Map, fromList, toList, foldrWithKey, insert, insertWith)
import qualified Data.Set as S (Set, fromList, intersection, insert)
-- deepseq
import Control.DeepSeq (NFData(..))
-- mtl
import Control.Monad.State (MonadState(..), modify)
-- -- psqueues
-- import qualified Data.IntPSQ as PQ (IntPSQ, insert, fromList, findMin, minView)
-- transformers
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT, State, runState, evalState)
import Control.Monad.Trans.Class (MonadTrans(..))
-- vector
import qualified Data.Vector as V (Vector, replicateM, fromList)
import qualified Data.Vector.Generic as VG (Vector(..), unfoldrM, length, replicateM, (!), map, freeze, thaw, take, drop, unzip)
import qualified Data.Vector.Unboxed as VU (Vector, Unbox, fromList)
import qualified Data.Vector.Storable as VS (Vector)
-- vector-algorithms
import qualified Data.Vector.Algorithms.Merge as V (sortBy)

import Data.RPTree.Conduit (tree, forest, dataSource, liftC)
import Data.RPTree.Gen (sparse, dense, normal2, normalSparse2)
import Data.RPTree.Internal (RPTree(..), RPForest, RPT(..), levels, points, leaves, RT(..), Inner(..), Scale(..), scaleS, scaleD, (/.), innerDD, innerSD, innerSS, metricSSL2, metricSDL2, SVector(..), fromListSv, fromVectorSv, DVector(..), fromListDv, fromVectorDv, partitionAtMedian, Margin, getMargin, sortByVG, serialiseRPForest, deserialiseRPForest)
import Data.RPTree.Internal.Testing (BenchConfig(..), randSeed)
import Data.RPTree.Draw (draw, writeCsv)


-- | k nearest neighbors
--
-- The candidate points of every tree in the forest are pooled (see 'candidates'),
-- each one is paired with its distance to the query, the pairs are ranked by
-- distance, and only then the best @k@ are kept. Truncating /after/ ranking is
-- what makes the result the nearest candidates rather than the first @k@ in
-- storage order.
--
-- NOTE: a point that is returned by more than one tree appears more than once in
-- the pool, and therefore possibly in the result; it cannot be deduplicated here
-- because @v2@ carries no 'Eq' / 'Ord' constraint.
knn :: (Ord p, Inner SVector v, VU.Unbox d, Real d) =>
       (v2 -> v d -> p) -- ^ distance function
    -> Int -- ^ k neighbors
    -> RPForest d (V.Vector v2) -- ^ random projection forest
    -> v d -- ^ query point
    -> V.Vector (p, v2) -- ^ ordered in increasing distance order
knn distf k tts q = VG.take k $ sortByVG fst scored
  where
    -- concatenation of the candidate sets of all trees
    pool = foldMap (`candidates` q) tts
    -- every candidate tagged with its distance to the query
    scored = VG.map (\x -> (x `distf` q, x)) pool


-- | average recall-at-k, computed over a set of trees
--
-- The recall of each tree is computed independently with 'recallWith1' and the
-- results are averaged. An empty forest yields @0@ instead of dividing by zero.
recallWith :: (Inner SVector v, VU.Unbox a, Fractional a, Ord a, Ord d, Ord p) =>
              (p -> v a -> d)
           -> RPForest a (V.Vector p)
           -> Int -- ^ k : number of nearest neighbors to consider
           -> v a -- ^ query point
           -> a
recallWith distf tt k q
  | n == 0    = 0 -- no trees : avoid 0 / 0
  | otherwise = sum rs / fromIntegral n
  where
    -- recall of each individual tree
    rs = fmap (\t -> recallWith1 distf t k q) tt
    -- number of trees in the forest
    n = length tt

-- -- | Recall computed wrt the Euclidean distance
-- recallEuclid :: (Inner SVector v, Inner u v, VU.Unbox a, Ord a, Ord (u a), Floating a) =>
--                 RPTree a (V.Vector (u a))
--              -> Int -- ^ k : number of nearest neighbors to consider
--              -> v a  -- ^ query point
--              -> Double
-- recallEuclid = recallWith metricL2

-- | recall-at-k of a single tree
--
-- Ratio between the number of true @k@ nearest neighbors (found by ranking every
-- point stored in the tree by its distance to the query) that also show up among
-- the tree's 'candidates' for the query, and @k@.
--
-- The denominator is always @k@, so a tree that stores fewer than @k@ distinct
-- points cannot reach a recall of 1. A non-positive @k@ yields @0@ instead of
-- dividing by zero.
recallWith1 :: (Fractional a1, Inner SVector v, Ord d, VU.Unbox d,
                Num d, Ord a3, Ord a2) =>
              (a2 -> v d -> a3) -- ^ distance function
           -> RPTree d (V.Vector a2)
           -> Int -- ^ k : number of nearest neighbors to consider
           -> v d -- ^ query point
           -> a1
recallWith1 distf tt k q
  | k <= 0    = 0 -- nothing to recall : avoid division by zero
  | otherwise = fromIntegral (length aintk) / fromIntegral k
  where
    -- all the points stored in the tree
    xs = points tt
    -- every point paired with its distance to the query, closest first
    dists = sortBy (comparing snd) $ toList $ fmap (\x -> (x, x `distf` q)) xs
    kk = S.fromList $ map fst $ take k dists -- first k nn's
    -- what the tree actually returns for this query
    aa = set $ candidates tt q
    -- true neighbors that were retrieved
    aintk = aa `S.intersection` kk

-- | Collect the elements of any 'Foldable' container into a 'S.Set'
--
-- Uses the strict left fold so that the set is built incrementally rather than
-- as a chain of suspended insertions.
set :: (Foldable t, Ord a) => t a -> S.Set a
set = foldl' (flip S.insert) mempty



{-# SCC candidates #-}
-- | Retrieve points nearest to the query
--
-- The tree is descended from the root; at level @i@ the query is projected onto
-- the @i@-th projection vector of the tree and the projection is compared with
-- the threshold stored in the node.
--
-- in case of a narrow margin, collect both branches of the tree
candidates :: (Inner SVector v, VU.Unbox d, Ord d, Num d, Semigroup xs) =>
              RPTree d xs
           -> v d -- ^ query point
           -> xs
candidates (RPTree rvs tt) x = go 0 tt
  where
    -- a leaf : return its payload
    go _     (Tip xs)                     = xs
    -- an internal node : pick one side, or both (the order of the guards matters)
    go ixLev (Bin thr margin ltree rtree)
      | proj < thr && dl > dr = both
      | proj < thr            = goL
      | proj > thr && dl < dr = both
      | otherwise             = goR
      where
        (mglo, mghi) = getMargin margin
        -- projection vector of the current level
        r = rvs VG.! ixLev
        -- projection of the query
        proj = r `inner` x
        i' = succ ixLev
        dl = abs (mglo - proj) -- left margin
        dr = abs (mghi - proj) -- right margin
        goL = go i' ltree
        goR = go i' rtree
        -- left results come first, then the right ones
        both = goL <> goR






-- pqSeq :: Ord a => PQ.IntPSQ a b -> Seq (a, b)
-- pqSeq pqq = go pqq mempty
--   where
--     go pq acc = case PQ.minView pq of
--       Nothing -> acc
--       Just (_, p, v, rest) -> go rest $ acc |> (p, v)


-- newtype Counts a = Counts {
--   unCounts :: M.Map a (Sum Int) } deriving (Eq, Show, Semigroup, Monoid)
-- keepCounts :: Int -- ^ keep entry iff counts are larger than this value
--            -> Counts a
--            -> [(a, Int)]
-- keepCounts thr cs = M.foldrWithKey insf mempty c
--   where
--     insf k v acc
--       | v >= thr = (k, v) : acc
--       | otherwise = acc
--     c = getSum `fmap` unCounts cs
-- counts :: (Foldable t, Ord a) => t a -> Counts a
-- counts = foldl count mempty
-- count :: Ord a => Counts a -> a -> Counts a
-- count (Counts mm) x = Counts $ M.insertWith mappend x (Sum 1) mm


-- forest :: Inner SVector v =>
--           Int -- ^ # of trees
--        -> Int -- ^ maximum tree height
--        -> Double -- ^ nonzero density of sparse projection vectors
--        -> Int -- ^ dimension of projection vectors
--        -> V.Vector (v Double) -- ^ dataset
--        -> Gen [RPTree Double (V.Vector (v Double))]
-- forest nt maxDepth pnz dim xss =
--   replicateM nt (tree maxDepth pnz dim xss)

-- -- | Build a random projection tree
-- --
-- -- Optimization: instead of sampling one projection vector per branch, we sample one per tree level (as suggested in https://www.cs.helsinki.fi/u/ttonteri/pub/bigdata2016.pdf )
-- tree :: (Inner SVector v) =>
--          Int -- ^ maximum tree height
--       -> Double -- ^ nonzero density of sparse projection vectors
--       -> Int -- ^ dimension of projection vectors
--       -> V.Vector (v Double) -- ^ dataset
--       -> Gen (RPTree Double (V.Vector (v Double)))
-- tree maxDepth pnz dim xss = do
--   -- sample all projection vectors
--   rvs <- V.replicateM maxDepth (sparse pnz dim stdNormal)
--   let
--     loop ixLev xs = do
--       if ixLev >= maxDepth || length xs <= 100
--         then
--           pure $ Tip xs
--         else
--         do
--           let
--             r = rvs VG.! ixLev
--             (thr, margin, ll, rr) = partitionAtMedian r xs
--           treel <- loop (ixLev + 1) ll
--           treer <- loop (ixLev + 1) rr
--           pure $ Bin thr margin treel treer
--   rpt <- loop 0 xss
--   pure $ RPTree rvs rpt





-- -- | Partition at median inner product
-- treeRT :: (Monad m, Inner SVector v) =>
--            Int
--         -> Int
--         -> Double
--         -> Int
--         -> V.Vector (v Double)
--         -> GenT m (RT SVector Double (V.Vector (v Double)))
-- treeRT maxDepth minLeaf pnz dim xss = loop 0 xss
--   where
--     loop ixLev xs = do
--       if ixLev >= maxDepth || length xs <= minLeaf
--         then
--           pure $ RTip xs
--         else
--         do
--           r <- sparse pnz dim stdNormal
--           let
--             (_, mrg, ll, rr) = partitionAtMedian r xs
--           treel <- loop (ixLev + 1) ll
--           treer <- loop (ixLev + 1) rr
--           pure $ RBin r mrg treel treer







-- -- | Like 'tree' but here we partition at the median of the inner product values instead
-- tree' :: (Inner SVector v) =>
--          Int
--       -> Double
--       -> Int
--       -> V.Vector (v Double)
--       -> Gen (RPTree Double (V.Vector (v Double)))
-- tree' maxDepth pnz dim xss = do
--   -- sample all projection vectors
--   rvs <- V.replicateM maxDepth (sparse pnz dim stdNormal)
--   let
--     loop ixLev xs =
--       if ixLev >= maxDepth || length xs <= 100
--         then Tip xs
--         else
--           let
--             r = rvs VG.! ixLev
--             (thr, margin, ll, rr) = partitionAtMedian r xs
--             tl = loop (ixLev + 1) ll
--             tr = loop (ixLev + 1) rr
--           in Bin thr margin tl tr
--   let rpt = loop 0 xss
--   pure $ RPTree rvs rpt


-- -- | Partition uniformly at random between inner product extreme values
-- treeRT :: (Monad m, Inner SVector v) =>
--           Int -- ^ max tree depth
--        -> Int -- ^ min leaf size
--        -> Double -- ^ nonzero density
--        -> Int -- ^ embedding dimension
--        -> V.Vector (v Double) -- ^ data
--        -> GenT m (RT SVector Double (V.Vector (v Double)))
-- treeRT maxDepth minLeaf pnz dim xss = loop 0 xss
--   where
--     loop ixLev xs = do
--       if ixLev >= maxDepth || length xs <= minLeaf
--         then
--           pure $ RTip xs
--         else
--         do
--           -- sample projection vector
--           r <- sparse pnz dim stdNormal
--           let
--             -- project the dataset
--             projs = map (\x -> (x, r `inner` x)) xs
--             hi = snd $ maximumBy (comparing snd) projs
--             lo = snd $ minimumBy (comparing snd) projs
--           -- sample a threshold
--           thr <- uniformR lo hi
--           let
--             (ll, rr) = partition (\xp -> snd xp < thr) projs
--           treel <- loop (ixLev + 1) (map fst ll)
--           treer <- loop (ixLev + 1) (map fst rr)
--           pure $ RBin r treel treer


-- -- | Partition wrt a plane _|_ to the segment connecting two points sampled at random
-- --
-- -- (like annoy@@)
-- treeRT2 :: (Monad m, Ord d, Fractional d, Inner v v, VU.Unbox d, Num d) =>
--            Int
--         -> Int
--         -> [v d]
--         -> GenT m (RT v d [v d])
-- treeRT2 maxd minl xss = loop 0 xss
--   where
--     loop ixLev xs = do
--       if ixLev >= maxd || length xs <= minl
--         then
--           pure $ RTip xs
--         else
--         do
--           x12 <- sampleWOR 2 xs
--           let
--             (x1:x2:_) = x12
--             r = x1 ^-^ x2
--             (ll, rr) = partition (\x -> (r `inner` (x ^-^ x1) < 0)) xs
--           treel <- loop (ixLev + 1) ll
--           treer <- loop (ixLev + 1) rr
--           pure $ RBin r treel treer










-- ulid :: MonadIO m => a -> m (ULID a)
-- ulid x = ULID <$> pure x <*> liftIO UU.getULID
-- data ULID a = ULID { uData :: a , uULID :: UU.ULID } deriving (Eq, Show)
-- instance (Eq a) => Ord (ULID a) where
--   ULID _ u1 <= ULID _ u2 = u1 <= u2