{-# LANGUAGE FlexibleContexts #-}

{- |
Module      :  Statistics.Pava.Common
Description :  Auxiliary functions
Copyright   :  (c) Dominik Schrempf, 2020
License     :  GPL-3.0-or-later

Maintainer  :  dominik.schrempf@gmail.com
Stability   :  unstable
Portability :  portable

Creation date: Mon Jun  8 11:03:12 2020.

-}

module Statistics.Pava.Common
  ( slope
  , strictlyOrdered
  , smooth
  , unsafeSmooth
  , reverse3
  )
where

import qualified Data.Vector.Generic           as V
import qualified Data.Vector.Generic.Mutable   as M
import           Data.Vector.Generic            ( Vector )

-- | Calculate the slope between two points.
--
-- @slope x0 x1 y0 y1@ is the rise @y1 - y0@ divided by the run @x1 - x0@.
-- Both differences are computed in the argument types and only then converted
-- to 'Double' with 'realToFrac'.
--
-- No check is made for @x0 == x1@; in that case the result is whatever
-- 'Double' division by zero yields.
slope :: (Real a, Real b) => a -> a -> b -> b -> Double
slope x0 x1 y0 y1 = realToFrac (y1 - y0) / realToFrac (x1 - x0)
{-# SPECIALIZE slope :: Int -> Int -> Double -> Double -> Double #-}
{-# SPECIALIZE slope :: Double -> Double -> Double -> Double -> Double #-}
{-# INLINE slope #-}

-- -- Differences between values in vector.
-- diff :: (Num a, Vector v a) => v a -> v a
-- diff v = V.zipWith (-) (V.tail v) v
-- {-# SPECIALIZE diff :: (Vector v Double) => v Double -> v Double #-}

-- | Check if vector is ordered strictly (<).
--
-- Each element is compared with its successor using '<'; the result is 'True'
-- only if every such comparison holds. Vectors with fewer than two elements
-- have no adjacent pairs and are therefore strictly ordered.
strictlyOrdered :: (Ord a, Vector v a, Vector v Bool) => v a -> Bool
-- 'V.drop' is total (it yields an empty vector when there is nothing to
-- drop), so no separate guard for short vectors is needed: 'V.zipWith' over
-- an empty second argument is empty, and 'V.and' of an empty vector is 'True'.
strictlyOrdered xs = V.and $ V.zipWith (<) xs (V.drop 1 xs)

-- | Fill in missing values of an indexed vector.
--
-- The first argument holds the indices, the second the values observed at
-- those indices. The result has one entry for every integer index from the
-- first to the last given index.
--
-- Calls 'error' if
--
-- - the index and value vectors differ in length, or
-- - the index vector is not strictly ordered.
--
-- Otherwise the work is delegated to 'unsafeSmooth'.
--
-- @
--  smooth [-2, 2, 4, 5] [0.0, 4.0, 10.0, 88.0] = [0.0, 1.0, 2.0, 3.0, 4.0, 7.0, 10.0, 88.0]
-- @
smooth
  :: (Vector v Bool, Vector v Double, Vector v Int)
  => v Int
  -> v Double
  -> v Double
smooth xs ys
  | V.length xs /= V.length ys = error
    "smooth: Index and value vector have different length."
  | not (strictlyOrdered xs) = error
    "smooth: Index vector is not strictly ordered."
  | otherwise = unsafeSmooth xs ys

-- | See 'smooth'.
--
-- Assume that:
--
-- - the lengths of the provided vectors are equal;
-- - the predictors are ordered.
--
-- Neither assumption is checked here. An empty index vector gives an empty
-- result and a single index gives the first value only. Otherwise the result
-- has @last xs - head xs + 1@ entries: positions that coincide with a given
-- index receive the given value, and positions in between are filled using
-- 'slope' between the two neighbouring given points.
unsafeSmooth
  :: (Vector v Bool, Vector v Double, Vector v Int)
  => v Int
  -> v Double
  -> v Double
unsafeSmooth xs ys
  | l == 0 = V.empty
  | l == 1 = V.take 1 ys
  | otherwise = V.create
    (do
      zs <- M.new m
      -- Start at the first output position, inside the first interval.
      go zs 0 1 (bounds 1)
      return zs
    )
 where
  -- Number of given points.
  l = V.length xs
  -- Smallest and largest given index.
  a = V.head xs
  b = V.last xs
  -- Length of the resulting vector.
  m = b - a + 1
  -- Interval whose right end is the j-th given point:
  -- (left index, right index, left value, right value).
  bounds j = (xs V.! (j - 1), xs V.! j, ys V.! (j - 1), ys V.! j)
  -- i: position in the resulting vector, 0 <= i < m while writing;
  --    position i corresponds to the index a + i.
  -- j: position in the given vectors of the right end of the current
  --    interval; starts at 1.
  --
  -- After the last position has been written, 'go' is called once more with
  -- 'bounds' applied to an out-of-range j. The tuple components are lazy and
  -- the first guard returns before any of them is demanded, so the
  -- out-of-range lookups are never performed.
  go zs i j (il, ir, yl, yr)
    | i >= m = return ()
    -- Reached the right end of the interval: copy the given value and move
    -- on to the next interval.
    | a + i >= ir = do
      M.write zs i yr
      go zs (i + 1) (j + 1) (bounds $ j + 1)
    -- Inside the interval: interpolate from the left end.
    | otherwise = do
      M.write zs i (yl + dy)
      go zs (i + 1) j (il, ir, yl, yr)
   where
    -- Distance from the left end of the interval to the current index.
    dx = a + i - il
    dy = fromIntegral dx * slope il ir yl yr

-- | Reverse lists in a three-tuple.
--
-- Each of the three lists is reversed independently; their positions within
-- the tuple are unchanged.
reverse3 :: ([a], [b], [c]) -> ([a], [b], [c])
reverse3 (xs, ys, zs) = (reverse xs, reverse ys, reverse zs)