{-# LANGUAGE FlexibleContexts #-}
module Statistics.Test.Internal (
    rank
  , rankUnsorted  
  , splitByTags  
  ) where

import Data.Ord
import           Data.Vector.Generic           ((!))
import qualified Data.Vector.Generic         as G
import qualified Data.Vector.Generic.Mutable as M
import Statistics.Function


-- Internal state threaded through the unfold in 'rank'; not exported.
data Rank v a = Rank
    { rankCnt :: {-# UNPACK #-} !Int
      -- How many more times 'rankVal' still has to be emitted
    , rankVal :: {-# UNPACK #-} !Double
      -- Rank value to emit while 'rankCnt' is non-zero
    , rankNum :: {-# UNPACK #-} !Double
      -- Rank that the next not yet processed element will start from
    , rankVec :: v a
      -- Part of the input that has not been processed yet
    }

-- | Calculate rank of every element of sample. In case of ties ranks
--   are averaged. Sample should be already sorted in ascending order.
--
--   Rank is the index of an element in the sample, numeration starts
--   from 1. Every run of consecutive elements that are equivalent to
--   the first element of the run is treated as a group of ties, and
--   each member of the group is assigned the average of the ranks the
--   group occupies.
--
--   The first element of a run always belongs to that run, even if
--   the supplied relation is not reflexive (e.g. @(==)@ applied to a
--   floating point NaN). Such elements simply form groups of size
--   one, so the function terminates for any relation.
--
-- >>> rank (==) (fromList [10,20,30::Int])
-- > fromList [1.0,2.0,3.0]
--
-- >>> rank (==) (fromList [10,10,10,30::Int])
-- > fromList [2.0,2.0,2.0,4.0]
rank :: (G.Vector v a, G.Vector v Double)
     => (a -> a -> Bool)        -- ^ Equivalence relation
     -> v a                     -- ^ Vector to rank
     -> v Double
rank eq vec = G.unfoldr go (Rank 0 (-1) 1 vec)
  where
    -- Nothing is pending: either stop or start a new group of ties.
    go (Rank 0 _ r v)
      | G.null v  = Nothing
      | otherwise =
          case 1 + G.length ties of
            -- Lone element: emit its rank right away.
            1 -> Just (r, Rank 0 0 (r + 1) rest)
            -- Group of n ties occupying ranks r .. r+n-1: switch to
            -- emitting their average n times.
            n -> go Rank { rankCnt = n
                         , rankVal = 0.5 * (r * 2 + fromIntegral (n - 1))
                         , rankNum = r + fromIntegral n
                         , rankVec = rest
                         }
      where
        -- The head is counted unconditionally and only the tail is
        -- scanned, so the group is never empty. Scanning the whole
        -- vector would yield an empty group for a non-reflexive
        -- relation and make 'go' recurse forever on the same state.
        -- These bindings are lazy and only demanded when v is
        -- non-empty.
        (ties, rest) = G.span (eq (G.head v)) (G.tail v)
    -- A group is in progress: emit the shared rank once more.
    go (Rank n val r v) = Just (val, Rank (n - 1) val r v)
{-# INLINE rank #-}

-- | Compute rank of every element of a vector. Unlike 'rank' the
--   sample does not have to be sorted: elements are sorted
--   internally, ranked with 'rank' using @(==)@ as the equivalence
--   relation, and every rank is then stored at the position its
--   element had in the input.
rankUnsorted :: ( Ord a
                , G.Vector v a
                , G.Vector v Int
                , G.Vector v Double
                , G.Vector v (Int, a)
                )
             => v a
             -> v Double
rankUnsorted xs = G.create $ do
    out <- M.new len
    -- The i-th element of the sorted sample came from position
    -- @origin ! i@, so that is where its rank has to be written.
    -- NOTE: backpermute will do wrong thing
    for 0 len $ \i ->
      M.unsafeWrite out (origin ! i) (sortedRanks ! i)
    return out
  where
    len = G.length xs
    -- Elements paired with their original indices, ordered by value
    byValue = sortBy (comparing snd) (indexed xs)
    (origin, ordered) = G.unzip byValue
    -- Ranks for the sorted sample
    sortedRanks = rank (==) ordered
{-# INLINE rankUnsorted #-}


-- | Split a vector of tagged values into two vectors of untagged
--   values. The first component of the result holds the values whose
--   tag is 'True', the second one those whose tag is 'False'. The
--   split is done with 'G.unstablePartition', so the relative order
--   of the values within each part should not be relied upon.
splitByTags :: (G.Vector v a, G.Vector v (Bool,a)) => v (Bool,a) -> (v a, v a)
splitByTags vs =
    let (tagged, untagged) = G.unstablePartition fst vs
    in  (G.map snd tagged, G.map snd untagged)
{-# INLINE splitByTags #-}