{-# LANGUAGE FlexibleContexts #-}

{- |
Module      :  Statistics.Lcm
Description :  Compute least concave majorants
Copyright   :  2021 Dominik Schrempf
License     :  GPL-3.0-or-later

Maintainer  :  dominik.schrempf@gmail.com
Stability   :  unstable
Portability :  portable

Creation date: Sat Jun  6 16:36:32 2020.

For general help, please refer to the README distributed with the library.

-}

module Statistics.Lcm
  ( lcm
  , unsafeLcm
  )
where

import           Prelude                 hiding ( lcm )
import qualified Data.Vector.Generic           as V
import           Data.Vector.Generic            ( Vector )

import           Statistics.Pava.Common

-- Add the point (x2, y2) to the current majorant. Before adding it, pool
-- (i.e., remove) trailing points until the majorant stays concave, that is,
-- until the retained slopes are strictly decreasing.
--
-- The majorant is stored in reverse order. The heads of @xs@ and @ys@ are the
-- most recently accepted point. The head of @ss@ is the slope of the segment
-- ending at that point. Hence @ss@ has one element fewer than @xs@ and @ys@.
-- The very first point is never removed (see the second equation).
pool ::
  (Real a, Show a, Real b, Show b) =>
  [a] ->
  [b] ->
  [Double] ->
  a ->
  b ->
  ([a], [b], [Double])
-- Points are ordered x0, x1, x2; and y0, y1, y2; s01 is slope from x0 to x1,
-- and so on.
pool xs@(x1 : xs') ys@(y1 : ys') ss@(s01 : ss') x2 y2
  -- Slopes still decrease strictly: keep (x1, y1) and append the new point.
  | s01 > s12 = (x2 : xs, y2 : ys, s12 : ss)
  -- Otherwise (x1, y1) violates concavity. Assuming 'slope' is
  -- (y2 - y1) / (x2 - x1), the point lies on or below the chord from (x0, y0)
  -- to (x2, y2). Drop it and retry against the previous point.
  | otherwise = pool xs' ys' ss' x2 y2
  where
    s12 = slope x1 x2 y1 y2
-- Initialization, and fallback if all points but the first one are removed
-- during pooling.
pool [x] [y] [] x' y' = (x' : [x], y' : [y], [slope x x' y y'])
-- Only reachable if the length invariant between xs, ys, and ss is broken.
pool xs ys ss x y =
  error $
    "pool: xs, ys, ss, x, y: "
      ++ show xs
      ++ ", "
      ++ show ys
      ++ ", "
      ++ show ss
      ++ ", "
      ++ show x
      ++ ", "
      ++ show y
      ++ "."
{-# SPECIALIZE pool :: [Int] -> [Double] -> [Double] -> Int -> Double
                    -> ([Int], [Double], [Double]) #-}
{-# SPECIALIZE pool :: [Double] -> [Double] -> [Double] -> Double -> Double
                    -> ([Double], [Double], [Double]) #-}

-- | Least concave majorant. Uses the Pool Adjacent Violators Algorithm (PAVA).
--
-- The predictors must be strictly ordered (no ties), and both vectors must
-- have the same length. Otherwise, an error is thrown.
--
-- Usage:
--
-- @
--  lcm predictors responses = (xs, ys, slopes)
-- @
--
-- The results are:
--
-- - @xs@ and @ys@: the predictor and response values of the points (taken from
--   the input) that form the vertices of the majorant;
--
-- - @slopes@: the slopes of the segments between consecutive vertices, one
--   fewer than the number of vertices.
lcm ::
  (Real a, Real b, Show a, Vector v a, Show b, Vector v b, Vector v Bool) =>
  v a ->
  v b ->
  ([a], [b], [Double])
lcm ps rs
  | lPs /= lRs =
      error $
        "lcm: Number of predictors is "
          ++ show lPs
          ++ ", but number of responses is "
          ++ show lRs
          ++ "."
  | not (strictlyOrdered ps) =
      error "lcm: The predictors are not strictly ordered."
  -- Both preconditions of 'unsafeLcm' hold now.
  | otherwise = unsafeLcm ps rs
  where
    lPs = V.length ps
    lRs = V.length rs

-- | See 'lcm'.
--
-- Skips the input checks performed by 'lcm' and assumes that:
--
-- - the lengths of the provided vectors are equal;
--
-- - the predictors are strictly ordered (no ties).
--
-- If these assumptions are violated, the result is unspecified. For example,
-- an out-of-bounds indexing error occurs if the predictors are shorter than
-- the responses.
unsafeLcm ::
  (Real a, Real b, Show a, Vector v a, Show b, Vector v b) =>
  v a ->
  v b ->
  ([a], [b], [Double])
unsafeLcm ps rs
  | l == 0 = ([], [], [])
  | l == 1 = start
  -- 'go' accumulates the majorant in reverse order; restore input order.
  | otherwise = reverse3 $ go start (1 :: Int)
  where
    -- Number of points; the predictors are assumed to have the same length.
    l = V.length rs
    -- The first point is never removed by 'pool', so it seeds the majorant.
    start = ([V.head ps], [V.head rs], [])
    -- Add the remaining points one at a time.
    --
    -- xs and ys: x and y values of the majorant found so far, most recently
    -- added point first.
    -- ss: slopes between consecutive points, also most recent first.
    -- i: index of the next point of ps and rs to add.
    go (xs, ys, ss) i
      | i >= l = (xs, ys, ss)
      | otherwise = go (pool xs ys ss x' y') (i + 1)
      where
        x' = ps V.! i
        y' = rs V.! i