{-# LANGUAGE ViewPatterns #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.SRTree.Likelihoods 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  ConstraintKinds
--
-- Functions to calculate different likelihood functions, their gradient, and Hessian matrices.
--
-----------------------------------------------------------------------------
module Algorithm.SRTree.Likelihoods
  ( Distribution (..)
  , PVector
  , SRMatrix
  , sse
  , mse
  , rmse
  , r2
  , nll
  , predict
  , gradNLL
  , gradNLLNonUnique
  , fisherNLL
  , getSErr
  , hessianNLL
  )
    where

import Algorithm.SRTree.AD ( forwardMode, reverseModeUnique ) -- ( reverseModeUnique )
import Data.Massiv.Array hiding (all, map, read, replicate, tail, take, zip)
import qualified Data.Massiv.Array as M
import Data.Maybe (fromMaybe)
import Data.SRTree (Fix (..), SRTree (..), floatConstsToParam, relabelParams)
import Data.SRTree.Derivative (deriveByParam)
import Data.SRTree.Eval (PVector, SRMatrix, SRVector, compMode, evalTree)

-- | Distributions supported by the negative log-likelihood functions.
--
-- The constructor order is significant: it fixes the derived 'Enum' and
-- 'Bounded' instances, which enumerate 'Gaussian', 'Bernoulli', 'Poisson'
-- in that order.
data Distribution
  = Gaussian   -- ^ real-valued targets; the tree output is used directly
  | Bernoulli  -- ^ targets in {0, 1}; the tree output is passed through 'logistic'
  | Poisson    -- ^ non-negative targets; the tree output is passed through 'exp'
  deriving (Show, Read, Enum, Bounded)

-- | Sum of squared errors (residual sum of squares): the sum over all
-- samples of @(y - yhat)^2@, where @yhat = evalTree xss theta tree@.
sse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sse xss ys tree theta = M.sum squaredResiduals
  where
    residuals, squaredResiduals :: SRVector
    residuals        = delay ys - evalTree xss theta tree
    squaredResiduals = residuals ^ (2 :: Int)

-- | Total sum of squares of the targets around their mean,
-- @sum ((y - mean ys)^2)@.
--
-- Only @ys@ is inspected; the matrix, tree and parameter arguments are
-- accepted so that the signature lines up with 'sse'.
sseTot :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sseTot _ ys _ _ = M.sum (M.map (\y -> (y - yMean) ^ (2 :: Int)) ys)
  where
    Sz n  = M.size ys
    yMean = M.sum ys / fromIntegral n
        
-- | Mean squared error: 'sse' divided by the number of samples in @ys@.
-- NaN for an empty target vector (0 / 0).
mse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
mse xss ys tree theta = total / fromIntegral n
  where
    total = sse xss ys tree theta
    Sz n  = M.size ys

-- | Root mean squared error: the square root of 'mse'.
rmse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
rmse xss ys tree theta = sqrt (mse xss ys tree theta)

-- | Coefficient of determination, @1 - SSE / SST@, where SSE is 'sse' and
-- SST is the total sum of squares of @ys@ around its mean ('sseTot').
-- When @ys@ is constant SST is 0 and the result is NaN or infinite.
r2 :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
r2 xss ys tree theta = 1 - residualSS / totalSS
  where
    residualSS = sse    xss ys tree theta
    totalSS    = sseTot xss ys tree theta

-- | Logistic (sigmoid) function, @1 / (1 + exp (-x))@.
logistic :: Floating a => a -> a
logistic x = 1 / denominator
  where
    denominator = 1 + exp (negate x)
{-# inline logistic #-}

-- | Select the standard error used by a likelihood.
--
-- For 'Gaussian', a supplied value (@Just s@) takes precedence. When it is
-- 'Nothing', the caller-provided estimate @est@ is returned. Every other
-- distribution uses the constant 1, ignoring both @est@ and the 'Maybe'.
getSErr :: Num a => Distribution -> a -> Maybe a -> a
getSErr dist est msErr =
  case dist of
    Gaussian -> fromMaybe est msErr
    _        -> 1
{-# inline getSErr #-}

-- | Negated sum of all entries of a parameter vector, @- sum v@.
negSum :: PVector -> Double
negSum v = negate (M.sum v)
{-# inline negSum #-}

-- | Negative log-likelihood of the targets @ys@ given the tree output
-- @f = evalTree xss theta tree@, for the chosen 'Distribution'.
--
-- * 'Gaussian':  @0.5 * (SSE / s^2 + m * log (2 * pi * s^2))@, where @m@ is
--   the number of samples, SSE is 'sse', and @s@ is chosen by 'getSErr'
--   (the supplied @Maybe Double@ if present, otherwise a fallback).
-- * 'Bernoulli': @- sum (y * f - log (1 + exp f))@, the log-likelihood of
--   @y@ with success probability @logistic f@. Targets must be 0 or 1.
--   @log (1 + exp f)@ is evaluated in a form that does not overflow for
--   large @f@.
-- * 'Poisson':   @- sum (y * f - y * log y - exp f)@, with rate @exp f@ and
--   the convention @0 * log 0 = 0@. Targets must be non-negative.
--
-- Calls 'error' when a target lies outside the support of the distribution.
nll :: Distribution -> Maybe Double -> SRMatrix -> PVector -> Fix SRTree -> PVector -> Double

-- Gaussian distribution
nll Gaussian msErr xss ys t theta = 0.5*(ssr/s2 + m*log (2*pi*s2))
  where
    (Sz m') = M.size ys
    (Sz p') = M.size theta
    m, p :: Double
    m    = fromIntegral m'
    p    = fromIntegral p'
    ssr  = sse xss ys t theta
    -- Fallback standard error used when msErr is Nothing.
    -- NOTE(review): sqrt (m - p) does not depend on the residuals, whereas
    -- the usual residual-based estimate is sqrt (ssr / (m - p)); confirm
    -- which is intended. gradNLL and gradNLLNonUnique use the same fallback,
    -- so change them together. Degenerate (NaN/Infinity) when m <= p.
    est  = sqrt (m - p)
    sErr = getSErr Gaussian est msErr
    s2   = sErr ^ (2 :: Int)

-- Bernoulli distribution: with phi = logistic f, the per-sample
-- log-likelihood y log phi + (1 - y) log (1 - phi) simplifies to
-- y * f - log (1 + exp f), assuming y is 0 or 1.
nll Bernoulli _ xss ys tree theta
  | notValid ys = error "For Bernoulli distribution the output must be either 0 or 1."
  | otherwise   = negate . M.sum $ delay ys * yhat - M.map softplus yhat
  where
    yhat :: SRVector
    yhat     = evalTree xss theta tree
    notValid :: PVector -> Bool
    notValid = M.any (\x -> x /= 0 && x /= 1)
    -- log (1 + exp x) rewritten so exp never overflows for large x
    softplus :: Double -> Double
    softplus x = max x 0 + log (1 + exp (negate (abs x)))

-- Poisson distribution with log link (rate = exp f).
nll Poisson _ xss ys tree theta
  | notValid ys = error "For Poisson distribution the output must be non-negative."
  -- M.any isNaN yhat = error $ "NaN predictions " <> show theta
  | otherwise   = negate . M.sum $ ys' * yhat - M.map xlogx ys' - exp yhat
  where
    ys' :: SRVector
    ys'      = delay ys
    yhat :: SRVector
    yhat     = evalTree xss theta tree
    notValid :: PVector -> Bool
    notValid = M.any (< 0)
    -- y * log y with the limit 0 * log 0 = 0; the plain product is NaN at y = 0
    xlogx :: Double -> Double
    xlogx y = if y == 0 then 0 else y * log y

-- | Negative log-likelihood from precomputed tree outputs @yhat@ and
-- targets @ys@. It uses the same formulas as 'nll' but does not validate
-- the targets.
-- The 'Double' argument is the standard error and is only used for 'Gaussian'.
-- For 'Poisson', @0 * log 0@ is taken as 0 so zero counts do not yield NaN.
-- For 'Bernoulli', @log (1 + exp yhat)@ is computed without overflow.
nll' :: Distribution -> Double -> SRVector -> SRVector -> Double
nll' Gaussian sErr yhat ys = 0.5*(ssr/s2 + m*log (2*pi*s2))
  where
    (Sz m') = M.size ys
    m :: Double
    m    = fromIntegral m'
    ssr  = M.sum $ (ys - yhat)^(2 :: Int)
    s2   = sErr ^ (2 :: Int)
nll' Bernoulli _ yhat ys = negate . M.sum $ ys * yhat - M.map softplus yhat
  where
    -- log (1 + exp x) rewritten so exp never overflows for large x
    softplus :: Double -> Double
    softplus x = max x 0 + log (1 + exp (negate (abs x)))
nll' Poisson _ yhat ys = negate . M.sum $ ys * yhat - M.map xlogx ys - exp yhat
  where
    -- y * log y with the limit 0 * log 0 = 0; the plain product is NaN at y = 0
    xlogx :: Double -> Double
    xlogx y = if y == 0 then 0 else y * log y
{-# INLINE nll' #-}

-- | Prediction on the scale of the target. The tree output
-- @evalTree xss theta tree@ is returned as-is for 'Gaussian', passed
-- through 'logistic' for 'Bernoulli', and through 'exp' for 'Poisson'.
predict :: Distribution -> Fix SRTree -> PVector -> SRMatrix -> SRVector
predict dist tree theta xss = inverseLink (evalTree xss theta tree)
  where
    inverseLink :: SRVector -> SRVector
    inverseLink = case dist of
      Gaussian  -> id
      Bernoulli -> logistic
      Poisson   -> exp

-- | Negative log-likelihood together with its gradient with respect to
-- @theta@, computed with 'reverseModeUnique'.
--
-- Inside the reverse-mode pass, the tree output goes through @id@
-- ('Gaussian'), 'logistic' ('Bernoulli') or 'exp' ('Poisson'). The value
-- is computed by 'nll''. For 'Gaussian', the gradient is divided by the
-- squared standard error selected by 'getSErr'. Calls 'error' when a target
-- lies outside the support of the distribution.
gradNLL :: Distribution -> Maybe Double -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (Double, SRVector)
gradNLL dist msErr xss ys tree theta =
  case dist of
    Gaussian ->
      (nll' Gaussian sErr yhat target, delay grad ./ (sErr * sErr))
    Bernoulli
      | M.any notBinary target -> error "For Bernoulli distribution the output must be either 0 or 1."
      | otherwise              -> (nll' Bernoulli 1.0 yhat target, delay grad)
    Poisson
      | M.any (< 0) target -> error "For Poisson distribution the output must be non-negative."
      | otherwise          -> (nll' Poisson 1.0 yhat target, delay grad)
  where
    target :: SRVector
    target = delay ys

    -- function applied to the tree output inside the reverse-mode pass
    link :: SRVector -> SRVector
    link = case dist of
      Gaussian  -> id
      Bernoulli -> logistic
      Poisson   -> exp

    (yhat, grad) = reverseModeUnique xss theta target link tree

    notBinary :: Double -> Bool
    notBinary y = y /= 0 && y /= 1

    -- Gaussian only: the fallback standard error when msErr is Nothing.
    -- NOTE(review): sqrt (m - p) does not depend on the residuals, whereas
    -- the usual residual-based estimate is sqrt (ssr / (m - p)); confirm
    -- which is intended, and keep it in sync with nll. NaN when there are
    -- more parameters than samples.
    Sz nSamples = M.size ys
    Sz nParams  = M.size theta
    sErr = getSErr Gaussian (sqrt (fromIntegral (nSamples - nParams))) msErr

-- | Negative log-likelihood together with its gradient with respect to
-- @theta@, computed with 'forwardMode'.
--
-- 'forwardMode' receives the residual @predict dist tree theta xss - ys@.
-- The value is computed by 'nll''. For 'Gaussian', the gradient is divided
-- by the squared standard error selected by 'getSErr'. Calls 'error' when a
-- target lies outside the support of the distribution.
gradNLLNonUnique :: Distribution -> Maybe Double -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (Double, SRVector)
gradNLLNonUnique dist msErr xss ys tree theta =
  case dist of
    Gaussian ->
      (nll' Gaussian sErr yhat target, delay grad ./ (sErr * sErr))
    Bernoulli
      | M.any notBinary target -> error "For Bernoulli distribution the output must be either 0 or 1."
      | otherwise              -> (nll' Bernoulli 1.0 yhat target, delay grad)
    Poisson
      | M.any (< 0) target -> error "For Poisson distribution the output must be non-negative."
      | otherwise          -> (nll' Poisson 1.0 yhat target, delay grad)
  where
    target :: SRVector
    target = delay ys

    -- prediction on the target scale minus the observed target
    residual :: SRVector
    residual = predict dist tree theta xss - target

    (yhat, grad) = forwardMode xss theta residual tree

    notBinary :: Double -> Bool
    notBinary y = y /= 0 && y /= 1

    -- Gaussian only: the fallback standard error when msErr is Nothing.
    -- NOTE(review): sqrt (m - p) does not depend on the residuals, whereas
    -- the usual residual-based estimate is sqrt (ssr / (m - p)); confirm
    -- which is intended, and keep it in sync with nll. NaN when there are
    -- more parameters than samples.
    Sz nSamples = M.size ys
    Sz nParams  = M.size theta
    sErr = getSErr Gaussian (sqrt (fromIntegral (nSamples - nParams))) msErr

-- | Fisher information of the negative log-likelihood (one entry per parameter).
--
-- For parameter @i@ the returned value is
--
-- > sum (phi' * (df/dtheta_i)^2 - (y - phi) * d2f/dtheta_i^2) / sErr^2
--
-- where @f@ is the tree evaluated on @xss@ with parameters @theta@, @phi@ is
-- @f@, @logistic f@ or @exp f@ for 'Gaussian', 'Bernoulli' and 'Poisson'
-- respectively, and @phi'@ is the derivative of @phi@ with respect to @f@.
--
-- @sErr@ comes from 'getSErr', given the distribution, the optional
-- user-supplied value @msErr@ and the residual standard error estimate
-- @sqrt (SSE / (m - p))@ (@m@ = number of samples, @p@ = number of parameters).
-- NOTE(review): when @m <= p@ that estimate is not finite.
fisherNLL :: Distribution -> Maybe Double -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRVector
fisherNLL dist msErr xss ys tree theta = makeArray cmp (Sz p) build
  where
    -- Diagonal entry for parameter ix.
    build :: Int -> Double
    build ix = let dtdix   = deriveByParam ix t'
                   d2tdix2 = deriveByParam ix dtdix
                   f'      = eval dtdix
                   f''     = eval d2tdix2
               in (/ sErr ^ 2) . M.sum $ phi' * f' ^ 2 - res * f''

    cmp :: Comp
    cmp    = getComp xss
    (Sz m) = M.size ys
    (Sz p) = M.size theta

    -- NOTE(review): hessianNLL evaluates the tree as-is, whereas here float
    -- constants are first converted to parameters; confirm this is intended.
    t' :: Fix SRTree
    t'     = fst $ floatConstsToParam tree

    eval :: Fix SRTree -> Array D Int Double
    eval   = evalTree xss theta

    ssr :: Double
    ssr    = sse xss ys tree theta

    sErr :: Double
    sErr   = getSErr dist est msErr

    -- Residual standard error: sqrt of the SSE divided by the degrees of
    -- freedom (previously this was sqrt (m - p), which ignored ssr entirely).
    est :: Double
    est    = sqrt $ ssr / fromIntegral (m - p)

    yhat :: Array D Int Double
    yhat   = eval t'

    -- Residual between the observations and the distribution mean.
    res :: Array D Int Double
    res    = delay ys - phi

    -- phi: mean as a function of the tree output; phi': its derivative.
    (phi, phi') = case dist of
                    Gaussian  -> (yhat, M.replicate cmp (Sz m) 1)
                    Bernoulli -> (logistic yhat, phi * (M.replicate cmp (Sz m) 1 - phi))
                    Poisson   -> (exp yhat, phi)

-- | Hessian of the negative log-likelihood with respect to @theta@ (a @p x p@ matrix).
--
-- Entry @(i, j)@ is
--
-- > sum (phi' * df/dtheta_i * df/dtheta_j - (y - phi) * d2f/(dtheta_i dtheta_j)) / sErr^2
--
-- where @f@ is the tree evaluated on @xss@ with parameters @theta@, @phi@ is
-- @f@, @logistic f@ or @exp f@ for 'Gaussian', 'Bernoulli' and 'Poisson'
-- respectively, and @phi'@ is the derivative of @phi@ with respect to @f@.
--
-- @sErr@ comes from 'getSErr', given the distribution, the optional
-- user-supplied value @msErr@ and the residual standard error estimate
-- @sqrt (SSE / (m - p))@. NOTE(review): when @m <= p@ that estimate is not finite.
--
-- Note: although the Fisher information ('fisherNLL') is just the diagonal of
-- this matrix, the two are kept as separate functions for efficiency: the
-- diagonal needs only @p@ entries instead of @p * p@.
hessianNLL :: Distribution -> Maybe Double -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRMatrix
hessianNLL dist msErr xss ys tree theta = makeArray cmp (Sz (p :. p)) build
  where
    -- Entry (ix, iy) of the Hessian.
    build :: Ix2 -> Double
    build (ix :. iy) = let dtdix   = deriveByParam ix t'
                           dtdiy   = deriveByParam iy t'
                           d2tdixy = deriveByParam iy dtdix
                           fx      = eval dtdix
                           fy      = eval dtdiy
                           fxy     = eval d2tdixy
                       in (/ sErr ^ 2) . M.sum $ phi' * fx * fy - res * fxy

    cmp :: Comp
    cmp    = getComp xss
    (Sz m) = M.size ys
    (Sz p) = M.size theta

    -- The tree is used as-is (no relabelParams / floatConstsToParam).
    t' :: Fix SRTree
    t'     = tree

    eval :: Fix SRTree -> Array D Int Double
    eval   = evalTree xss theta

    ssr :: Double
    ssr    = sse xss ys tree theta

    sErr :: Double
    sErr   = getSErr dist est msErr

    -- Residual standard error: sqrt of the SSE divided by the degrees of
    -- freedom (previously this was sqrt (m - p), which ignored ssr entirely).
    est :: Double
    est    = sqrt $ ssr / fromIntegral (m - p)

    yhat :: Array D Int Double
    yhat   = eval t'

    -- Residual between the observations and the distribution mean.
    res :: Array D Int Double
    res    = delay ys - phi

    -- phi: mean as a function of the tree output; phi': its derivative.
    (phi, phi') = case dist of
                    Gaussian  -> (yhat, M.replicate cmp (Sz m) 1)
                    Bernoulli -> (logistic yhat, phi * (M.replicate cmp (Sz m) 1 - phi))
                    Poisson   -> (exp yhat, phi)