{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveFunctor #-}
{-# language DeriveGeneric #-}
{-# language LambdaCase #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiWayIf #-}
{-# options_ghc -Wno-unused-imports #-}
{-# options_ghc -Wno-unused-top-binds #-}
module Data.RPTree (
tree, forest
, knn
, serialiseRPForest
, deserialiseRPForest
, recallWith
, levels, points, leaves, candidates
, RPTree, RPForest
, SVector, fromListSv, fromVectorSv
, DVector, fromListDv, fromVectorDv
, Inner(..), Scale(..)
, innerSS, innerSD, innerDD
, metricSSL2, metricSDL2
, scaleS, scaleD
, dataSource
, sparse, dense
, normal2
, draw
, writeCsv
, randSeed, BenchConfig(..), normalSparse2
, liftC
) where
import Control.Monad (replicateM)
import Control.Monad.IO.Class (MonadIO(..))
import Data.Foldable (Foldable(..), maximumBy, minimumBy)
import Data.Functor.Identity (Identity(..))
import Data.List (partition, sortBy)
import Data.Monoid (Sum(..))
import Data.Ord (comparing)
import GHC.Generics (Generic)
import GHC.Word (Word64)
import Data.Sequence (Seq, (|>))
import qualified Data.Map as M (Map, fromList, toList, foldrWithKey, insert, insertWith)
import qualified Data.Set as S (Set, fromList, intersection, insert)
import Control.DeepSeq (NFData(..))
import Control.Monad.State (MonadState(..), modify)
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT, State, runState, evalState)
import Control.Monad.Trans.Class (MonadTrans(..))
import qualified Data.Vector as V (Vector, replicateM, fromList)
import qualified Data.Vector.Generic as VG (Vector(..), unfoldrM, length, replicateM, (!), map, freeze, thaw, take, drop, unzip)
import qualified Data.Vector.Unboxed as VU (Vector, Unbox, fromList)
import qualified Data.Vector.Storable as VS (Vector)
import qualified Data.Vector.Algorithms.Merge as V (sortBy)
import Data.RPTree.Conduit (tree, forest, dataSource, liftC)
import Data.RPTree.Gen (sparse, dense, normal2, normalSparse2)
import Data.RPTree.Internal (RPTree(..), RPForest, RPT(..), levels, points, leaves, RT(..), Inner(..), Scale(..), scaleS, scaleD, (/.), innerDD, innerSD, innerSS, metricSSL2, metricSDL2, SVector(..), fromListSv, fromVectorSv, DVector(..), fromListDv, fromVectorDv, partitionAtMedian, Margin, getMargin, sortByVG, serialiseRPForest, deserialiseRPForest)
import Data.RPTree.Internal.Testing (BenchConfig(..), randSeed)
import Data.RPTree.Draw (draw, writeCsv)
-- | Approximate k-nearest-neighbour query against a random-projection forest.
--
-- The candidate points of every tree in the forest are gathered for the
-- query (see 'candidates'), each one is paired with its distance to the
-- query, the pairs are ordered by distance with 'sortByVG', and at most @k@
-- of them are returned.
--
-- Truncation to @k@ happens /after/ ranking, so the result holds the best
-- @k@ candidates rather than whichever @k@ happened to come first.
--
-- NOTE(review): 'sortByVG' is assumed to sort in ascending key order —
-- confirm against "Data.RPTree.Internal".
--
-- A point returned as a candidate by more than one tree appears more than
-- once in the result: there is no 'Eq' constraint on @v2@ to deduplicate with.
knn :: (Ord p, Inner SVector v, VU.Unbox d, Real d) =>
       (v2 -> v d -> p) -- ^ distance between a stored point and the query
    -> Int -- ^ maximum number of neighbours to return
    -> RPForest d (V.Vector v2) -- ^ forest to search
    -> v d -- ^ query point
    -> V.Vector (p, v2) -- ^ (distance, point) pairs, ordered by distance
knn distf k tts q = VG.take k (sortByVG fst scored)
  where
    -- every candidate from every tree, concatenated
    pool = foldMap (`candidates` q) tts
    -- candidates tagged with their distance to the query
    scored = VG.map (\x -> (x `distf` q, x)) pool
-- | Mean recall-at-@k@ over all trees of a forest.
--
-- 'recallWith1' is evaluated on every tree with the same distance function,
-- @k@ and query point, and the results are averaged.
--
-- For an empty forest the average is @0 / 0@ in the chosen 'Fractional'
-- type; no special case is made for it.
recallWith :: (Inner SVector v, VU.Unbox a, Fractional a, Ord a, Ord d, Ord p) =>
              (p -> v a -> d) -- ^ distance between a stored point and the query
           -> RPForest a (V.Vector p) -- ^ forest to evaluate
           -> Int -- ^ @k@ : number of nearest neighbours to consider
           -> v a -- ^ query point
           -> a
recallWith distf tt k q = sum perTree / fromIntegral (length tt)
  where
    -- recall of each individual tree
    perTree = fmap (\t -> recallWith1 distf t k q) tt
-- | Recall-at-@k@ of a single tree for one query point.
--
-- The reference set is the @k@ points with the smallest distance to the
-- query among /all/ points stored in the tree (see 'points'). The result is
-- the number of those reference points that also occur among the tree's
-- 'candidates' for the query, divided by @k@.
--
-- The denominator is always @k@, even when the tree stores fewer than @k@
-- distinct points; @k == 0@ yields a division by zero in the result type.
recallWith1 :: (Fractional a1, Inner SVector v, Ord d, VU.Unbox d,
                Num d, Ord a3, Ord a2) =>
               (a2 -> v d -> a3) -- ^ distance between a stored point and the query
            -> RPTree d (V.Vector a2) -- ^ tree to evaluate
            -> Int -- ^ @k@ : number of nearest neighbours to consider
            -> v d -- ^ query point
            -> a1
recallWith1 distf tt k q = fromIntegral (length hits) / fromIntegral k
  where
    -- every stored point paired with its distance, closest first
    ranked = sortBy (comparing snd) [ (x, x `distf` q) | x <- toList (points tt) ]
    -- the k closest points found by exhaustive search
    nearest = S.fromList (map fst (take k ranked))
    -- the points the tree actually proposes for this query
    proposed = set (candidates tt q)
    -- proposed points that belong to the reference set
    hits = proposed `S.intersection` nearest
-- | Collect the elements of any 'Foldable' container into a 'S.Set'.
--
-- Elements are inserted one at a time with a strict left fold, so no chain
-- of suspended insertions builds up on large inputs.
set :: (Foldable t, Ord a) => t a -> S.Set a
set = foldl' (flip S.insert) mempty
{-# SCC candidates #-}
-- | Retrieve the leaf payloads of a tree that are relevant to a query point.
--
-- Starting at the root with level index 0, the query is projected (with
-- 'inner') onto the projection vector stored for the current level and the
-- projection is compared against the node's threshold:
--
-- * below the threshold: descend into the left subtree;
-- * otherwise (including a projection equal to the threshold): descend into
--   the right subtree;
-- * if the projection is strictly on one side of the threshold but closer to
--   the margin bound of the /other/ side, both subtrees are explored and
--   their results combined with '<>' (left first).
--
-- The payload of each 'Tip' that is reached is returned unchanged.
--
-- NOTE(review): the projection vectors are indexed by tree level with
-- 'VG.!'; this presumably relies on the tree never being deeper than the
-- number of stored vectors — confirm against the tree construction code.
candidates :: (Inner SVector v, VU.Unbox d, Ord d, Num d, Semigroup xs) =>
              RPTree d xs -- ^ tree to descend
           -> v d -- ^ query point
           -> xs
candidates (RPTree rvs tt) x = go 0 tt
  where
    go _ (Tip xs) = xs
    go ixLev (Bin thr margin ltree rtree)
      | proj < thr && dl > dr = lefts <> rights -- left side, but nearer the high margin
      | proj < thr            = lefts
      | proj > thr && dl < dr = lefts <> rights -- right side, but nearer the low margin
      | otherwise             = rights
      where
        (mglo, mghi) = getMargin margin
        -- projection of the query onto this level's vector
        proj = (rvs VG.! ixLev) `inner` x
        -- absolute distance of the projection from each margin bound
        dl = abs (mglo - proj)
        dr = abs (mghi - proj)
        i' = succ ixLev
        lefts = go i' ltree
        rights = go i' rtree