{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveFunctor #-}
{-# language LambdaCase #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# options_ghc -Wno-unused-imports #-}
module Data.RPTree.Gen where

import Control.Monad (replicateM, foldM)

-- containers
import qualified Data.IntMap as IM (IntMap, insert, toList)
-- mtl
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.State (MonadState(..), modify)
-- splitmix-distributions
import System.Random.SplitMix.Distributions (Gen, GenT, stdUniform, bernoulli, exponential, normal, discrete, categorical)
-- transformers
import Control.Monad.Trans.State (StateT(..), runStateT, evalStateT, State, runState, evalState)
-- vector


import qualified Data.Vector.Generic as VG (Vector(..), unfoldrM, length, replicateM, (!))
import qualified Data.Vector.Unboxed as VU (Vector, Unbox, fromList)


import Data.RPTree.Internal (RPTree(..), RPT(..), SVector(..), fromListSv, DVector(..))


-- | Sample without replacement with a single pass over the data
--
-- Implements Algorithm L for reservoir sampling:
--
-- Li, Kim-Hung (4 December 1994). "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))". ACM Transactions on Mathematical Software. 20 (4): 481–493. doi:10.1145/198429.198435
--
-- The result holds at most @k@ elements: exactly @k@ when the input has at
-- least @k@ elements, otherwise the whole input. Elements are returned in
-- ascending order of their reservoir slot. A non-positive sample size yields
-- the empty list.
sampleWOR :: (Monad m, Foldable t) =>
             Int -- ^ sample size
          -> t a
          -> GenT m [a]
sampleWOR k xs
  | k <= 0 = pure []
  | otherwise = do
      (_, res) <- runStateT (foldM insf 0 xs) z
      pure $ map snd $ IM.toList (rsReservoir res)
  where
    -- the sampler starts with an empty, partially filled reservoir
    z = RSPartial mempty
    -- process the element @x@ found at (0-based) position @i@ of the input
    -- and return the position of the next element
    insf i x = do
      st <- get
      case st of
        -- fill phase: the first k elements go straight into slots 0 .. k-1
        RSPartial acc
          | i + 1 >= k -> do
              -- slot k-1 was just filled: initialise W and draw the first skip
              w <- lift $ genW k
              s <- lift $ genS w
              put (RSFull acc' (i + s + 1) w)
              pure (succ i)
          | otherwise -> do
              put (RSPartial acc')
              pure (succ i)
          where
            acc' = IM.insert i x acc
        -- sampling phase: only the element at the lookahead index is used
        RSFull acc ila0 w0
          | i == ila0 -> do
              w <- lift $ genW k
              -- W must be updated before the next skip is drawn from it
              let w' = w0 * w
              s <- lift $ genS w'
              acc' <- lift $ replaceInBuffer k acc x
              put (RSFull acc' (i + s + 1) w')
              pure (succ i)
          | otherwise -> pure (succ i)

-- | State threaded through the fold performed by 'sampleWOR'.
data ResS a
  = -- | Reservoir that is still being filled; 'sampleWOR' starts from an
    -- empty one.
    RSPartial
      { rsReservoir :: IM.IntMap a -- ^ reservoir
      }
  | -- | Filled reservoir together with the bookkeeping needed to decide which
    -- input element is considered next.
    RSFull
      { rsReservoir :: IM.IntMap a -- ^ reservoir
      , rsfLookAh :: !Int -- ^ lookahead index
      , rsfW :: !Double -- ^ W
      }
  deriving (Eq, Show)

-- | Draw @exp (log u / k)@ for a fresh 'stdUniform' sample @u@.
--
-- Used by 'sampleWOR' to initialise and to update the @W@ parameter.
genW :: (Monad m) => Int -> GenT m Double
genW k = toW <$> stdUniform
  where
    toW u = exp (log u / fromIntegral k)

-- | Draw @floor (log u / log (1 - w))@ for a fresh 'stdUniform' sample @u@.
--
-- Used by 'sampleWOR' as the number of input elements to skip.
-- NOTE(review): the quotient is only a finite non-negative number when
-- @0 < w < 1@ and @u > 0@; callers are assumed to respect this.
genS :: (Monad m) => Double -> GenT m Int
genS w = toSkip <$> stdUniform
  where
    toSkip u = floor (log u / log (1 - w))

-- | Store a value at a random key of the buffer.
--
-- The key is @floor (k * u)@ for a fresh 'stdUniform' sample @u@; the value
-- previously stored under that key, if any, is overwritten.
replaceInBuffer :: (Monad m) =>
                   Int -- ^ @k@, scales the uniform sample into a key
                -> IM.IntMap a -- ^ buffer
                -> a -- ^ value to store
                -> GenT m (IM.IntMap a)
replaceInBuffer k imm y = overwrite <$> stdUniform
  where
    overwrite u = IM.insert (floor (fromIntegral k * u)) y imm







-- mixtures

-- | Sample from a mixture: pick one of the generators with 'categorical'
-- applied to the first components of the pairs, then run it.
--
-- When 'categorical' returns 'Nothing' the first generator is used. The
-- generator is looked up with '!!', so an empty input list fails at runtime.
mixtureN :: Monad m => [(Double, GenT m b)] -> GenT m b
mixtureN pgs = categorical weights >>= pick
  where
    (weights, gens) = unzip pgs
    pick Nothing = gens !! 0
    pick (Just ix) = gens !! ix


-- | Sparse random vector whose nonzero components all come from one of two
-- generators, @normal 0 0.5@ or @normal 2 0.5@, selected once per vector by
-- @bernoulli 0.5@.
--
-- The first argument is the nonzero density and the second the vector
-- dimension, both passed on to 'sparse'.
normalSparse2 :: Monad m => Double -> Int -> GenT m (SVector Double)
normalSparse2 pnz d = bernoulli 0.5 >>= sparse pnz d . component
  where
    component isFirst
      | isFirst = normal 0 0.5
      | otherwise = normal 2 0.5

-- | Dense random vector of the given dimension whose components all come from
-- one of two generators, @normal 0 0.5@ or @normal 2 0.5@, selected once per
-- vector by @bernoulli 0.5@.
normalDense2 :: Monad m => Int -> GenT m (DVector Double)
normalDense2 d = do
  useFirst <- bernoulli 0.5
  dense d (component useFirst)
  where
    component True = normal 0 0.5
    component False = normal 2 0.5

-- | Two-dimensional dense random vector whose components both come from
-- @normal 0 0.5@ or both from @normal 2 0.5@, selected by @bernoulli 0.5@.
normal2 :: (Monad m) => GenT m (DVector Double)
normal2 = do
  useFirst <- bernoulli 0.5
  let mu
        | useFirst = 0
        | otherwise = 2
  dense 2 (normal mu 0.5)


-- | Generate a sparse random vector with a given nonzero density and
-- components sampled from the supplied random generator.
--
-- The index/value pairs are produced by 'sparseVG' and wrapped in 'SV'
-- together with the vector dimension.
sparse :: (Monad m, VU.Unbox a) =>
          Double -- ^ nonzero density
       -> Int -- ^ vector dimension
       -> GenT m a -- ^ random generator of vector components
       -> GenT m (SVector a)
sparse p sz = fmap (SV sz) . sparseVG p sz

-- | Generate a dense random vector with components sampled from the supplied
-- random generator.
--
-- The components are produced by 'denseVG' and wrapped in 'DV'.
dense :: (Monad m, VG.Vector VU.Vector a) =>
         Int -- ^ vector dimension
      -> GenT m a -- ^ random generator of vector components
      -> GenT m (DVector a)
dense sz = fmap DV . denseVG sz



-- | Sample a dense random vector.
--
-- Runs the component generator once per position, for positions
-- @0 .. sz - 1@ in order; a non-positive dimension gives an empty vector.
denseVG :: (VG.Vector v a, Monad m) =>
           Int -- ^ vector dimension
        -> m a -- ^ generator of a single component
        -> m (v a)
denseVG sz rand = VG.unfoldrM step 0
  where
    -- the unfold seed is the position of the next component
    step ix =
      if ix < sz
        then (\x -> Just (x, succ ix)) <$> rand
        else pure Nothing

-- | Sample a sparse random vector.
--
-- Walks the positions @0 .. sz - 1@ in order; each position is kept when
-- @bernoulli p@ returns 'True', in which case it is paired with a freshly
-- sampled component, and skipped otherwise. The result contains the kept
-- @(position, component)@ pairs.
sparseVG :: (Monad m, VG.Vector v (Int, a)) =>
            Double -- ^ nonzero density
         -> Int  -- ^ vector dimension
         -> GenT m a -- ^ generator of a single component
         -> GenT m (v (Int, a))
sparseVG p sz rand = VG.unfoldrM next 0
  where
    -- the unfold seed is the next position to examine
    next ix
      | ix < sz = bernoulli p >>= emitOrSkip ix
      | otherwise = pure Nothing
    -- keep the position and sample its component, or move on to the next one
    emitOrSkip ix True = do
      v <- rand
      pure (Just ((ix, v), succ ix))
    emitOrSkip ix False = next (succ ix)