{-# options_ghc -Wno-unused-imports #-}
{-# options_ghc -Wno-type-defaults #-}
module Data.VPTree.Build (build
                         -- * Internal
                         , buildVT
                         ) where

import Control.Monad.ST (ST, runST)
import qualified Data.Foldable as F (Foldable(..))
import Data.Foldable (foldlM)
import Data.Maybe (fromMaybe)

-- containers
import qualified Data.Set as S (Set, fromList, difference)
-- import qualified Data.Sequence as SQ (Seq)
-- deepseq
-- import Control.DeepSeq (NFData (rnf))
-- mwc-probability
import qualified System.Random.MWC.Probability as P (Gen, Prob, withSystemRandom, asGenIO, GenIO, create, initialize, sample, samples, normal, bernoulli)
-- primitive
import Control.Monad.Primitive (PrimMonad(..), PrimState)
-- sampling
import Numeric.Sampling (sample)
-- vector
import qualified Data.Vector as V (Vector, map, filter, length, toList, replicate, partition, zipWith, head, tail, fromList, thaw, freeze, (!), foldl)
-- import qualified Data.Vector.Generic as VG (Vector(..))
-- import Data.Vector.Generic.Mutable (MVector)
-- vector-algorithms
import qualified Data.Vector.Algorithms.Merge as V (sort, Comparison)

import Data.VPTree.Internal (VT(..), VPTree(..), withST_)

-- * Construction

-- | Construct a 'VPTree' index from an in-memory dataset.
--
-- The distance function @d@ is expected to be a proper metric, that is:
--
-- * identity of indiscernibles : \( d(x, y) = 0 \leftrightarrow x \equiv y \)
--
-- * symmetry : \( d(x, y) = d(y, x) \)
--
-- * triangle inequality : \( d(x, y) + d(y, z) \geq d(x, z) \)
--
-- Construction walks over the dataset several times, so the whole dataset has to be available in memory as a 'V.Vector'.
--
-- Implementation note : vantage points are chosen by a randomized procedure ('buildVT'); the generator is threaded through 'withST_', so the function exposes a pure interface.
build :: (RealFrac p, Floating d, Ord d, Eq a) =>
         (a -> a -> d) -- ^ distance function
      -> p -- ^ proportion of remaining dataset to sample at each level, \(0 < p <= 1 \)
      -> V.Vector a -- ^ dataset used for constructing the index
      -> VPTree d a
build distf prop xss =
  -- pair the finished tree with the metric it was built with
  withST_ (\gen -> (`VPT` distf) <$> buildVT distf prop xss gen)


-- | Recursively build the tree structure ('VT') for the given distance function.
--
-- Datasets with fewer than 10 elements are stored as a single 'Tip'. Larger ones
-- are split around a vantage point chosen by 'selectVP': points whose distance to
-- the vantage point is strictly below the median distance go to the left subtree,
-- all others to the right one.
buildVT :: (PrimMonad m, RealFrac b, Floating d, Eq a, Ord d) =>
           (a -> a -> d) -- ^ distance function
        -> b -- ^ proportion of remaining dataset to sample at each level
        -> V.Vector a -- ^ dataset
        -> P.Gen (PrimState m) -- ^ PRNG
        -> m (VT d a)
buildVT distf prop xss gen = go xss
  where
    go xs
      | length xs < 10 = pure (Tip xs)
      | otherwise = do
          (vp, rest) <- selectVP distf prop xs gen
          let
            distToVp = (`distf` vp)
            -- median distance to the vantage point
            mu = median (V.map distToVp rest)
            (inner, outer) = V.partition ((< mu) . distToVp) rest
          -- left subtree is built before the right one
          Bin mu vp <$> go inner <*> go outer


-- | Select a vantage point and return it together with the rest of the dataset.
--
-- A candidate set of roughly @prop * length xs@ points is drawn with
-- 'vpRandSplitInit'. For every candidate a fresh sample of @n2@ points is taken
-- from the complement of the candidate set and the spread of its distances to the
-- candidate is computed with 'varianceWrt'; the candidate with the largest spread
-- becomes the vantage point.
--
-- The second component of the result holds every other point handed back by
-- 'vpRandSplitInit': the candidates that were not selected /and/ the complement
-- of the candidate set. Returning only the rejected candidates would drop the
-- complement from the index whenever @prop < 1@.
selectVP :: (PrimMonad m, RealFrac b, Floating d, Ord d) =>
            (a -> a -> d)
         -> b -> V.Vector a -> P.Gen (PrimState m) -> m (a, V.Vector a)
selectVP distf prop xs gen = do
  (pstart, pstail, pscl) <- vpRandSplitInit n xs gen
  let
    -- accumulator : (best spread so far, best candidate so far, rejected candidates)
    pickMu (spreadCurr, pCurr, rejected) p = do
      ds <- sampleId n2 pscl gen -- sample n2 points from the complement
      let spread = varianceWrt distf p (V.fromList ds)
      pure $
        if spread > spreadCurr
          then (spread, p, pCurr : rejected)
          else (spreadCurr, pCurr, p : rejected)
  (_, vp, vrest) <- foldlM pickMu (0, pstart, []) pstail
  -- the complement is only sampled from above; it still belongs to the dataset
  pure (vp, V.fromList (vrest ++ pscl))
  where
    -- size of the candidate set, at least 1
    n = max 1 $ floor (prop * fromIntegral ndata)
    -- size of the sample used to score each candidate, at least 1
    n2 = max 1 $ floor (prop * fromIntegral n)
    ndata = length xs -- size of dataset at current level


-- | Sample the initialization for picking a vantage point.
--
-- The input is first split with 'uniformSplit' into a candidate set @C@ (about
-- @n@ elements) and its complement. One element of @C@ is then singled out as
-- the initial candidate vantage point.
--
-- The result is @(start, C without start, complement of C)@; the three parts
-- never share an element, so downstream folds do not see the start point twice.
--
-- If the random split leaves @C@ empty, the first element of the complement is
-- used as the start point. The function only fails when the input itself is empty.
vpRandSplitInit :: PrimMonad m =>
                   Int
                -> V.Vector a -- ^ must be non-empty
                -> P.Gen (PrimState m)
                -> m (a, [a], [a]) -- (head of C, tail of C, complement of C)
vpRandSplitInit n sset gen = do
  (ps, psc) <- uniformSplit n sset gen
  (pstartv, pstail) <- randomSplit 0.5 1 ps gen -- Pick a random starting point from ps
  pure $ case (pstartv, pstail) of
    -- the coin toss singled out a start point; pstail is already the rest of C
    (c : _, _) -> (c, pstail, psc)
    -- no start point was singled out: keep the previous preference for the
    -- second element of pstail, but remove it from the tail
    ([], t0 : c : ts) -> (c, t0 : ts, psc)
    -- a single candidate: it is the start point and the tail is empty
    ([], [c]) -> (c, [], psc)
    -- empty candidate set: fall back to the complement
    ([], []) -> case psc of
      c : cs -> (c, [], cs)
      [] -> error "vpRandSplitInit : input dataset must have at least 1 element"

-- | Split a dataset in two parts, the first being an approximately uniform sample of at most @n@ elements.
--
-- The split is delegated to 'randomSplit'; its Bernoulli parameter is set to
-- @1 - n / length vv@, i.e. it depends on the requested sample size and on the
-- size of the dataset.
uniformSplit :: (PrimMonad m, Foldable t) =>
                Int -> t a -> P.Gen (PrimState m) -> m ([a], [a])
uniformSplit n vv = randomSplit (1 - sampledFraction) n vv
  where
    -- fraction of the dataset that the sample should cover
    sampledFraction :: Double
    sampledFraction = fromIntegral n / fromIntegral (length vv)

-- | Sample a random split of the dataset in a single pass, by tossing a coin for every element.
--
-- For each element one Bernoulli trial with parameter @p@ is drawn. The element
-- is put in the second list if the trial succeeds or if the first list already
-- holds @n@ elements; otherwise it is put in the first list. The first list
-- therefore never has more than @n@ elements.
--
-- Invariant : the concatenation of the two resulting lists is a permutation of the input.
--
-- NB : elements are consed onto the results, so both lists are in reverse input order.
--
-- NB : when @p <= 0@ and @n@ is not smaller than the input size, every element
-- ends up in the first list and the second list is empty.
randomSplit :: (Foldable t, PrimMonad m) =>
                Double -- ^ Bernoulli parameter
             -> Int  -- ^ Size of sample
             -> t a -- ^ dataset
             -> P.Gen (PrimState m) -- ^ PRNG
             -> m ([a], [a])
randomSplit p n vv = P.sample (dropCount <$> foldlM insf (0 :: Int, [], []) vv)
  where
    dropCount (_, al, ar) = (al, ar)
    -- the size of the first list is carried along instead of being recomputed
    -- with 'length' at every step
    insf (k, al, ar) x = do
      -- the coin is always tossed, even when the first list is already full,
      -- so the PRNG is advanced exactly once per element
      coin <- P.bernoulli p
      pure $
        if k == n || coin
          then (k, al, x : ar)
          else (k + 1, x : al, ar)




-- | Sample _without_ replacement.
--
-- Delegates to 'sample'; when that yields 'Nothing' (documented here as the
-- requested sample size being too large) the whole input is returned as a list.
sampleId :: (PrimMonad m, Foldable t) =>
            Int -- ^ Size of sample
         -> t a
         -> P.Gen (PrimState m)
         -> m [a]
sampleId n xs g = do
  drawn <- sample n xs g
  pure (fromMaybe (F.toList xs) drawn)
{-# INLINE sampleId #-}

-- | Spread of the distances between the dataset and a given query point.
--
-- Computed as the mean squared deviation of the distances from their 'median'
-- (not from their mean).
--
-- NB input vector must have at least 1 element
varianceWrt :: (Floating a, Ord a) =>
               (t -> p -> a) -- ^ distance function
            -> p -- ^ query point
            -> V.Vector t
            -> a
varianceWrt distf p ds = variance dists centres
  where
    -- distance of every data point to the query point
    dists = V.map (\t -> distf t p) ds
    -- the median distance, repeated once per data point
    centres = V.replicate (V.length ds) (median dists)
{-# INLINE varianceWrt #-}

-- | Median of a vector: the element at index @n \`div\` 2@ of the sorted input
-- (the upper of the two middle elements when @n@ is even).
--
-- NB input vector must have at least 1 element; an empty input raises an 'error'.
median :: Ord a => V.Vector a -> a
median xs = case length xs of
  0 -> error "median : input array must have at least 1 element"
  1 -> V.head xs -- nothing to sort
  n -> sortV xs V.! (n `div` 2)
{-# INLINE median #-}

-- | Mean of the element-wise squared differences between two vectors.
--
-- The vectors are paired with 'V.zipWith'; 'varianceWrt' passes the observations
-- as the first argument and a constant vector of reference values as the second.
variance :: (Floating a) => V.Vector a -> V.Vector a -> a
variance xs mus = mean squaredDiffs
  where
    squaredDiffs = V.zipWith (\x mu -> (x - mu) ** 2) xs mus
{-# INLINE variance #-}

-- | Arithmetic mean of a vector: the sum of the elements divided by their number.
--
-- No special case is made for an empty vector; the result is then @0 / 0@.
mean :: (Fractional a) => V.Vector a -> a
mean xs = total / count
  where
    total = sum xs
    count = fromIntegral (length xs)
{-# INLINE mean #-}

-- | Sorted copy of a vector.
--
-- The input is copied into a mutable vector, sorted in place with the merge
-- sort from @vector-algorithms@, and frozen again; the input is left untouched.
sortV :: Ord a => V.Vector a -> V.Vector a
sortV v = runST (V.thaw v >>= \scratch -> V.sort scratch >> V.freeze scratch)
{-# INLINE sortV #-}





-- -- OLD


-- selectVP :: (PrimMonad m, RealFrac b, Ord d, Floating d) =>
--             (a -> a -> d) -- ^ distance function
--          -> b -- ^ proportion of dataset to sample
--          -> V.Vector a -- ^ dataset
--          -> P.Gen (PrimState m)
--          -> m a
-- selectVP distf prop sset gen = do
--   (pstart, pstail, pscl) <- vpRandSplitInit n sset gen
--   let pickMu (spread_curr, p_curr) p = do
--         ds <- sampleId n2 pscl gen -- sample n2 < n points from pscl
--         let
--           spread = varianceWrt distf p (V.fromList ds)
--         if spread > spread_curr
--           then pure (spread, p)
--           else pure (spread_curr, p_curr)
--   snd <$> foldlM pickMu (0, pstart) pstail
--   where
--     n = floor (prop * fromIntegral ndata)
--     n2 = floor (prop * fromIntegral n)
--     ndata = length sset -- size of dataset at current level

-- randomSplit :: (PrimMonad f) =>
--                Int -- ^ Size of sample
--             -> V.Vector a -- ^ dataset
--             -> P.Gen (PrimState f) -- ^ PRNG
--             -> f (V.Vector a, V.Vector a)
-- randomSplit n vv gen = split <$> sampleId n ixs gen
--   where
--     split xs = (vxs, vxsc)
--       where
--         ixss = S.fromList xs
--         ixsc = S.fromList ixs `S.difference` ixss
--         vxs  = pickItems ixss
--         vxsc = pickItems ixsc
--     m = V.length vv
--     ixs = [0 .. m - 1]
--     pickItems = V.fromList . foldl (\acc i -> vv V.! i : acc) []