```-- |
-- Copyright  : (c) 2012 Daniël de Kok
--
-- Maintainer : Daniël de Kok <me@danieldk.eu>
-- Stability  : experimental
--
-- This module provides functionality to perform approximate randomization
-- tests (Noreen, 1989).

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE FlexibleContexts #-}

module Statistics.Test.ApproxRand (
-- * Description
-- \$description

-- * Examples
-- \$examples

-- * Data types
TestOptions(..),
TestResult(..),
Significance(..),
RandWithError,

-- * Approximate randomization tests
approxRandTest,
approxRandStats,

approxRandPairTest,
approxRandPairStats,

-- * Test statistics
TestStatistic,
differenceMean,
meanDifference,
varianceRatio
) where

import           Prelude hiding ((++))
import           Data.Vector.Generic ((++))
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as GM
import           Data.Word (Word)
import           Statistics.Sample (variance)
import           Statistics.Test.Types (TestType(..))
import           Statistics.Types
import           System.Random.Mersenne.Pure64 (PureMT, randomInt, randomWord)

-- \$description
--
-- Approximate randomization tests rely on a simple premise: given a test
-- statistic, if the null-hypothesis (the samples do not differ) is true,
-- we can randomly swap values between samples without an (extreme) impact
-- on the test statistic. Otherwise, the null-hypothesis must be rejected.
--
-- The test works by generating a given number of sample shuffles and computing
-- the test statistic for each shuffle. If /r/ is the number of shuffled
-- samples where the test statistic is at least as high as the test statistic
-- applied on the original samples; and /N/ the number of shuffles, then
-- the null-hypothesis is rejected iff /(r + 1):(N + 1) < p-value/ (for
-- one-sided tests).
--
-- Two kinds of test are supported:
--
-- * /Paired sample/ ('approxRandPairTest'): values from samples are shuffled
--   pair-wise. This requires the samples to have an equal length.
--
-- * /Unpaired sample/ ('approxRandTest'): values from samples are shuffled
--   among both samples. Consequently the i-th element of one sample does not
--   bear a relationship with the i-th element of the other sample. The
--   shuffled samples retain the sizes of the original samples.
--
-- Both tests can be performed as a one-tailed or two-tailed test.

-- \$examples
-- Both unpaired and paired sample tests use the 'Rand' monad to obtain
-- random numbers (the paired variants run in 'RandWithError', which layers
-- error reporting on top of 'Rand'). We can obtain a pseudo-random number
-- generator that is seeded using the system clock with the
-- 'System.Random.Mersenne.Pure64.newPureMT' function (please refer to
-- the documentation of "System.Random.Mersenne.Pure64" for more
-- information):
--
-- > prng <- newPureMT
--
-- Suppose that we have the samples @s1@ and @s2@. We could now perform
-- a two-tailed randomization test with 10,000 shuffles and the mean
-- difference as the test statistic, by running 'approxRandTest' in the 'Rand'
-- monad (at the /p = 0.01/ level):
--
-- > evalRandom (approxRandTest (TestOptions TwoTailed meanDifference 10000 0.01) s1 s2) prng
--
-- It is also possible to obtain the test statistics of the shuffled samples
-- directly (e.g. to inspect the distribution of test statistics) using the
-- 'approxRandStats'/'approxRandPairStats' functions:
--
-- > evalRandom (approxRandStats meanDifference 10000 s1 s2) prng

-- | Computations that draw random numbers and may fail. A failure carries
-- a 'String' error message describing what went wrong.
type RandWithError result = ErrorT String Rand result

-- | Options controlling an approximate randomization test.
data TestOptions = TestOptions
  { toTestType      :: TestType      -- ^ Type of test ('OneTailed' or 'TwoTailed')
  , toTestStatistic :: TestStatistic -- ^ Statistic used to compare the two samples
  , toIterations    :: Int           -- ^ Number of shuffled samples to create
  , toPValue        :: Double        -- ^ The p-value at which to test (e.g. 0.05)
  }

-- |
-- Outcome of a randomization test: the verdict, the statistic of the
-- original samples, and the statistics of all shuffled samples.
data TestResult = TestResult
  { trSignificance    :: Significance -- ^ Whether the result is significant, with its p-value
  , trStat            :: Double       -- ^ Test statistic of the original samples
  , trRandomizedStats :: Sample       -- ^ Test statistics of the randomized samples
  } deriving (Eq, Ord, Show)

-- |
-- Verdict of a randomization test. Both constructors carry the p-value
-- that was computed for the test.
data Significance
  = Significant    Double -- ^ The null hypothesis should be rejected
  | NotSignificant Double -- ^ Data is compatible with the null hypothesis
  deriving (Eq, Ord, Show)

-- |
-- Apply a pair-wise approximate randomization test.
--
-- For every shuffle, the values at each index are swapped between the two
-- samples with a probability of 0.5. Because values are exchanged index by
-- index, both samples must have the same length; otherwise the computation
-- fails with the error raised by 'approxRandPairStats'.
approxRandPairTest ::
     TestOptions              -- ^ Options for the test
  -> Sample                   -- ^ First sample
  -> Sample                   -- ^ Second sample
  -> RandWithError TestResult -- ^ The test result
approxRandPairTest (TestOptions testType stat n pTest) s1 s2 = do
  shuffledStats <- approxRandPairStats stat n s1 s2
  return $ TestResult (judge shuffledStats) observed shuffledStats
  where
    -- Statistic of the unshuffled samples.
    observed = stat s1 s2
    -- Count how many shuffles are at least as extreme, then decide.
    judge = significance testType pTest n . countExtremes observed

-- |
-- Apply an approximate randomization test.
--
-- The values of both samples are pooled and randomly redistributed over two
-- samples of the original sizes. The test statistic of the original samples
-- is compared against the statistics of the shuffled samples to decide
-- whether the observed difference is extreme.
approxRandTest ::
     TestOptions     -- ^ Options for the test
  -> Sample          -- ^ First sample
  -> Sample          -- ^ Second sample
  -> Rand TestResult -- ^ The test result
approxRandTest (TestOptions testType stat n pTest) s1 s2 = do
  shuffledStats <- approxRandStats stat n s1 s2
  return $ TestResult (judge shuffledStats) observed shuffledStats
  where
    -- Statistic of the unshuffled samples.
    observed = stat s1 s2
    -- Count how many shuffles are at least as extreme, then decide.
    judge = significance testType pTest n . countExtremes observed

-- | Determine the significance from the left/right extreme counts.
--
-- A one-tailed test uses the right-tail count and the full significance
-- level. A two-tailed test uses the smaller of the two tail counts and
-- compares its p-value against half the significance level.
significance ::
     TestType     -- ^ Type of test ('OneTailed' or 'TwoTailed')
  -> Double       -- ^ The p-value at which to test (e.g. 0.05)
  -> Int          -- ^ Number of sample shuffles
  -> (Int, Int)   -- ^ Extreme statistic counts (left tail, right tail)
  -> Significance -- ^ The test result
significance testType pTest n counts =
  significant threshold (pValue n extremes)
  where
    (threshold, extremes) = case testType of
      TwoTailed -> (pTest / 2, uncurry min counts)
      OneTailed -> (pTest, snd counts)

-- | Classify a p-value as a 'Significance'.
--
-- The result is 'Significant' when the p-value is strictly below the
-- significance level, and 'NotSignificant' otherwise. Either constructor
-- carries the p-value.
significant ::
     Double       -- ^ The p-value at which to test (significance level)
  -> Double       -- ^ The p-value
  -> Significance -- ^ The test result
significant pTail p
  | p < pTail = Significant p
  | otherwise = NotSignificant p

-- | Calculate the p-value of a randomization test as /(r + 1) : (n + 1)/.
--
-- Here /n/ is the number of shuffles and /r/ the number of shuffles whose
-- statistic was at least as extreme as the original one. Because of the
-- added ones, the estimate is always positive for non-negative counts.
pValue ::
     Int    -- ^ Number of shuffles
  -> Int    -- ^ Number of extreme outcomes
  -> Double -- ^ The p-value
pValue shuffles extremes = plusOne extremes / plusOne shuffles
  where
    plusOne k = fromIntegral k + 1

-- |
-- Count extreme test statistic values.
--
-- If the test statistic value of the original sample is in the right tail,
-- we want to count values equal to or larger than that value. If it is in
-- the left tail, we want to count values smaller than or equal to that
-- value. Since the tail is not known yet, both counts are computed.
--
-- Note: the tail could be determined by (1) averaging the test statistics of
-- the randomized samples, or (2) taking the smaller of the two counts.
countExtremes ::
     Double     -- ^ Test statistic value of the original samples
  -> Sample     -- ^ Test statistic values of the randomized samples.
  -> (Int, Int) -- ^ Count of left- and right-tail extremes.
countExtremes tOrig = VG.foldl' tally (0, 0)
  where
    -- Both counts are forced before the new pair is built, so the strict
    -- fold does not accumulate thunks inside the tuple.
    tally (left, right) tPerm =
      let left'  = bumpIf (tPerm <= tOrig) left
          right' = bumpIf (tPerm >= tOrig) right
      in  left' `seq` right' `seq` (left', right')
    bumpIf cond k = if cond then succ k else k

-- |
-- Generate a given number of pairwise shuffled samples, and calculate
-- the test statistic for each shuffle.
--
-- The data points at each index are swapped with a probability of 0.5, so
-- the samples must have the same length. Samples of different lengths
-- make the computation fail with an error message.
approxRandPairStats ::
     TestStatistic          -- ^ Test statistic
  -> Int                    -- ^ Number of shuffled samples to create
  -> Sample                 -- ^ First sample
  -> Sample                 -- ^ Second sample
  -> RandWithError Sample   -- ^ The statistics of the shuffles
approxRandPairStats stat n s1 s2
  | VG.length s1 /= VG.length s2 =
      throwError "Cannot calculate pairwise statistic: samples have different sizes"
  | otherwise = lift $ do
      pairStats <- replicateM n (uncurry stat `liftM` shuffleVectorsPairwise s1 s2)
      return (VG.fromList pairStats)

-- |
-- Generate a given number of shuffled samples, and calculate the test
-- statistic for each shuffle.
--
-- Values are shuffled across both samples while each keeps its original
-- size, so the samples do not need to have an equal length.
approxRandStats ::
     TestStatistic -- ^ Test statistic
  -> Int           -- ^ Number of shuffled samples to create
  -> Sample        -- ^ First sample
  -> Sample        -- ^ Second sample
  -> Rand Sample   -- ^ The statistics of the shuffles
approxRandStats stat n s1 s2 = do
  shuffledStats <- replicateM n (uncurry stat `liftM` shuffleVectors s1 s2)
  return (VG.fromList shuffledStats)

-- | Pair-wise shuffle of two vectors.
--
-- One random coin is drawn per index of the first vector. A 'True' coin
-- keeps both values in place; a 'False' coin exchanges them between the
-- two vectors.
shuffleVectorsPairwise :: (VG.Vector v a, VG.Vector v Bool) =>
  v a -> v a -> Rand (v a, v a)
shuffleVectorsPairwise vec1 vec2 =
  randomVector (VG.length vec1) >>= \coins ->
    return ( VG.zipWith3 pick vec1 vec2 coins
           , VG.zipWith3 pick vec2 vec1 coins )
  where
    -- Take the first candidate on a 'True' coin, the second otherwise.
    pick keep other coin
      | coin      = keep
      | otherwise = other

-- | Build a vector of the given length whose elements are drawn one by one
-- with 'getBool'.
randomVector :: (VG.Vector v Bool) => Int -> Rand (v Bool)
randomVector = (`VG.replicateM` getBool)

-- | Shuffle values amongst two vectors, keeping the original vector lengths.
--
-- The vectors are concatenated and shuffled as a whole. The result is then
-- split at the length of the first vector.
shuffleVectors :: VG.Vector v a => v a -> v a -> Rand (v a, v a)
shuffleVectors v1 v2 = do
  pooled <- shuffleVector (v1 ++ v2)
  return (VG.splitAt (VG.length v1) pooled)

-- | Fisher-Yates shuffle in the 'Rand' monad.
--
-- The monad's generator is threaded through 'shuffleVector'', and the
-- updated generator is handed back to 'Rand'.
shuffleVector :: VG.Vector v a => v a -> Rand (v a)
shuffleVector vec = Rand step
  where
    step gen = case shuffleVector' gen vec of
      (shuffled, gen') -> R shuffled gen'

-- | Fisher-Yates shuffle of a vector using an explicit generator.
--
-- The vector is copied into a mutable vector and shuffled in place inside
-- 'runST'. The shuffled vector is returned together with the advanced
-- generator.
shuffleVector' :: VG.Vector v a => PureMT -> v a -> (v a, PureMT)
shuffleVector' gen v = runST $ do
  mvec     <- VG.thaw v
  finalGen <- go mvec 0 gen
  frozen   <- VG.unsafeFreeze mvec
  return (frozen, finalGen)
  where
    lastIdx = VG.length v - 1
    -- Walk the indices from left to right. Position i is swapped with a
    -- position drawn from [i, lastIdx]. The walk stops before the last
    -- index, which has nothing left to swap with.
    go mvec i g
      | i >= lastIdx = return g
      | otherwise    = do
          let (j, g') = randomIntR g (i, lastIdx)
          GM.unsafeSwap mvec i j
          go mvec (i + 1) g'

-- |
-- A test statistic maps two samples to a single number that measures how
-- much they differ. See 'meanDifference', 'differenceMean' and
-- 'varianceRatio' for the statistics provided by this module.
type TestStatistic = Sample -> Sample -> Double

-- |
-- Calculates the difference mean of two samples (/mean(s1 - s2)/). When the
-- two samples do not have an equal length, the trailing elements of the
-- longer vector are ignored. The mean is then taken over the element-wise
-- differences of the overlapping prefix only.
differenceMean :: TestStatistic
differenceMean v1 v2 =
  VG.sum diffs / fromIntegral (VG.length diffs)
  where
    -- 'subVector' truncates to the shorter sample. Divide by the number of
    -- differences actually summed, not by the length of the first sample,
    -- which over-counts when the first sample is the longer one.
    diffs = subVector v1 v2

-- | Calculates the mean difference of two samples (/mean(s1) - mean(s2)/).
-- Unlike 'differenceMean', the samples may have different lengths; each
-- mean is taken over its whole sample.
meanDifference :: TestStatistic
meanDifference xs ys = mx - my
  where
    mx = mean xs
    my = mean ys

-- | Calculate the arithmetic mean of a sample (sum divided by length).
-- An empty sample yields NaN (0 / 0).
mean :: Sample -> Double
mean xs = VG.sum xs / fromIntegral (VG.length xs)

-- | Calculate the ratio of sample variances (/var(s1) : var(s2)/), using
-- 'variance' from "Statistics.Sample".
varianceRatio :: TestStatistic
varianceRatio v1 v2 = numerator / denominator
  where
    numerator   = variance v1
    denominator = variance v2

-- | Element-wise difference of two vectors. Being built on 'VG.zipWith',
-- the result is as long as the shorter input.
subVector :: (VG.Vector v n, Num n) => v n -> v n -> v n
subVector xs ys = VG.zipWith (\x y -> x - y) xs ys

-- | Subtract two 'Int's and return the difference as a 'Word'. Both operands
-- are converted to 'Word' first, so the subtraction wraps around modulo the
-- word size instead of overflowing 'Int'.
subIIW :: Int -> Int -> Word
subIIW a b = wa - wb
  where
    wa = fromIntegral a
    wb = fromIntegral b
{-# INLINE subIIW #-}

-- | Add a 'Word' offset to an 'Int'. The offset is converted to 'Int'
-- first, so the addition wraps around like ordinary 'Int' arithmetic.
addIWI :: Int -> Word -> Int
addIWI base = (base +) . fromIntegral

-- | Generate an 'Int' in the inclusive range spanned by the two bounds,
-- together with the updated generator. The bounds may be given in either
-- order.
--
-- Draws come from 'randomWord' and are rejected until they fall in a
-- prefix of the 'Word' range that divides evenly into the target range.
-- Assuming 'randomWord' is uniform, every result is equally likely. When
-- the range covers every 'Int', its size wraps to 0 in 'Word' arithmetic,
-- and a plain 'randomInt' draw is returned instead.
randomIntR :: PureMT -> (Int, Int) -> (Int, PureMT)
randomIntR gen (a, b)
  | n == 0    = randomInt gen
  | otherwise = loop gen
  where
    -- Normalise the bounds so that a' <= b'.
    (a', b') = if a < b then (a, b) else (b, a)
    -- Number of different Ints that should be generated. This is computed
    -- in Word arithmetic, so the full Int range wraps around to 0.
    n = 1 + subIIW b' a'
    -- The total range of Word can hold x complete ranges of size n.
    x = maxBound `div` n
    -- Only accept draws below s = x * n, a range that is divisible by n
    -- without remainder; r `div` x then maps it onto [0, n).
    s = x * n
    loop gen'
      | r >= s    = loop gen'' -- r is outside the range, discard it...
      | otherwise = (addIWI a' (r `div` x), gen'')
      where
        (!r, !gen'') = randomWord gen'
{-# INLINE randomIntR #-}
```