{-# LANGUAGE ViewPatterns #-}
module Algorithm.SRTree.Likelihoods
( Distribution (..)
, PVector
, SRMatrix
, sse
, mse
, rmse
, r2
, nll
, predict
, gradNLL
, gradNLLNonUnique
, fisherNLL
, getSErr
, hessianNLL
)
where
import Algorithm.SRTree.AD ( forwardMode, reverseModeUnique )
import Data.Massiv.Array hiding (all, map, read, replicate, tail, take, zip)
import qualified Data.Massiv.Array as M
import Data.Maybe (fromMaybe)
import Data.SRTree (Fix (..), SRTree (..), floatConstsToParam, relabelParams)
import Data.SRTree.Derivative (deriveByParam)
import Data.SRTree.Eval (PVector, SRMatrix, SRVector, compMode, evalTree)
-- | Distribution assumed for the target variable. It selects the likelihood
-- used by 'nll' and the inverse link applied by 'predict':
--
-- * 'Gaussian'  – identity inverse link.
-- * 'Bernoulli' – logistic inverse link; targets must be 0 or 1.
-- * 'Poisson'   – exponential inverse link; targets must be non-negative.
data Distribution
  = Gaussian
  | Bernoulli
  | Poisson
  deriving (Bounded, Enum, Read, Show)
-- | Sum of squared errors of the tree's predictions against the targets,
-- @sum_i (ys_i - yhat_i)^2@, where @yhat = evalTree xss theta tree@.
sse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sse xss ys tree theta = M.sum $ (delay ys - yhat) ^ (2 :: Int)
  where
    -- predictions of the expression tree for the rows of @xss@
    yhat :: SRVector
    yhat = evalTree xss theta tree
-- | Total sum of squares of the targets around their mean,
-- @sum_i (ys_i - mean ys)^2@ (the denominator of 'r2').
--
-- Only @ys@ is used; the matrix, tree and parameter arguments are accepted
-- so the signature lines up with 'sse'.
sseTot :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sseTot _xss ys _tree _theta = M.sum $ (M.map (subtract ym) ys) ^ (2 :: Int)
  where
    (Sz m) = M.size ys
    -- sample mean of the targets
    ym :: Double
    ym = M.sum ys / fromIntegral m
-- | Mean squared error: 'sse' divided by the number of targets.
mse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
mse xss ys tree theta = total / fromIntegral n
  where
    total  = sse xss ys tree theta
    (Sz n) = M.size ys
-- | Root mean squared error, the square root of 'mse'.
rmse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
rmse xss ys tree theta = sqrt (mse xss ys tree theta)
-- | Coefficient of determination, @1 - SSE / SST@, where SSE comes from
-- 'sse' and SST from 'sseTot'.
r2 :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
r2 xss ys tree theta = 1 - residual / total
  where
    residual = sse    xss ys tree theta
    total    = sseTot xss ys tree theta
-- | Logistic (sigmoid) function, @1 / (1 + exp (-x))@. 'predict' applies it
-- to the tree output for the 'Bernoulli' distribution.
logistic :: Floating a => a -> a
logistic = recip . (1 +) . exp . negate
{-# inline logistic #-}
-- | Standard error to plug into a likelihood.
--
-- For 'Gaussian' this is the user-supplied value when present, otherwise the
-- given estimate @est@. Every other distribution always gets @1@.
getSErr :: Num a => Distribution -> a -> Maybe a -> a
getSErr dist est = case dist of
  Gaussian -> fromMaybe est
  _        -> const 1
{-# inline getSErr #-}
-- | Negated sum of all elements of a vector.
negSum :: PVector -> Double
negSum v = negate (M.sum v)
{-# inline negSum #-}
-- | Negative log-likelihood of the targets @ys@ for the expression tree with
-- parameters @theta@, under the chosen 'Distribution'.
--
-- * 'Gaussian': @0.5 * (SSR / s^2 + m * log (2 * pi * s^2))@, where @s@ is the
--   supplied standard error or, when 'Nothing', the residual standard error
--   @sqrt (SSR / (m - p))@ (@m@ targets, @p@ parameters).
-- * 'Bernoulli': @-sum (y * eta - log (1 + exp eta))@ with @eta@ the tree
--   output; calls 'error' unless every target is 0 or 1.
-- * 'Poisson': @-sum (y * eta - y * log y - exp eta)@ with @eta@ the tree
--   output and @0 * log 0@ taken as 0; calls 'error' on a negative target.
nll :: Distribution
    -> Maybe Double
    -> SRMatrix
    -> PVector
    -> Fix SRTree
    -> PVector
    -> Double
nll Gaussian msErr xss ys t theta = 0.5 * (ssr / s2 + m * log (2 * pi * s2))
  where
    (Sz m') = M.size ys
    (Sz p') = M.size theta
    m, p :: Double
    m    = fromIntegral m'
    p    = fromIntegral p'
    ssr  = sse xss ys t theta
    -- Residual standard error sqrt (SSR / (m - p)), used when no standard
    -- error is supplied. NOTE(review): Inf/NaN when m <= p.
    est  = sqrt (ssr / (m - p))
    sErr = getSErr Gaussian est msErr
    s2   = sErr ^ (2 :: Int)
nll Bernoulli _ xss ys tree theta
  | M.any notBinary ys = error "For Bernoulli distribution the output must be either 0 or 1."
  | otherwise          = negate . M.sum $ delay ys * yhat - M.map softplus yhat
  where
    yhat :: SRVector
    yhat = evalTree xss theta tree
    notBinary :: Double -> Bool
    notBinary x = x /= 0 && x /= 1
    -- log (1 + exp x), rewritten as max x 0 + log (1 + exp (-|x|)) so that
    -- exp never overflows to Infinity for large x.
    softplus :: Double -> Double
    softplus x = max x 0 + log (1 + exp (negate (abs x)))
nll Poisson _ xss ys tree theta
  | M.any (< 0) ys = error "For Poisson distribution the output must be non-negative."
  | otherwise      = negate . M.sum $ ys' * yhat - M.map xlogy ys' - exp yhat
  where
    ys' :: SRVector
    ys' = delay ys
    yhat :: SRVector
    yhat = evalTree xss theta tree
    -- y * log y with the limit value 0 at y == 0 (plain 0 * log 0 is NaN).
    xlogy :: Double -> Double
    xlogy y = if y == 0 then 0 else y * log y
-- | Negative log-likelihood from already computed predictions @yhat@ and
-- targets @ys@. Targets are not validated here.
--
-- For 'Bernoulli' and 'Poisson' the formulas treat @yhat@ as the linear
-- predictor (logit / log-rate), mirroring 'nll'. The 'Double' argument is
-- the standard error and is only used by 'Gaussian'.
nll' :: Distribution -> Double -> SRVector -> SRVector -> Double
nll' Gaussian sErr yhat ys = 0.5 * (ssr / s2 + m * log (2 * pi * s2))
  where
    (Sz m') = M.size ys
    m :: Double
    m   = fromIntegral m'
    ssr = M.sum $ (ys - yhat) ^ (2 :: Int)
    s2  = sErr ^ (2 :: Int)
nll' Bernoulli _ yhat ys = negate . M.sum $ ys * yhat - M.map softplus yhat
  where
    -- log (1 + exp x) written to avoid overflow of exp for large x.
    softplus :: Double -> Double
    softplus x = max x 0 + log (1 + exp (negate (abs x)))
nll' Poisson _ yhat ys = negate . M.sum $ ys * yhat - M.map xlogy ys - exp yhat
  where
    -- y * log y with the limit value 0 at y == 0 (plain 0 * log 0 is NaN).
    xlogy :: Double -> Double
    xlogy y = if y == 0 then 0 else y * log y
{-# INLINE nll' #-}
-- | Evaluates the expression tree on @xss@ with parameters @theta@ and maps
-- the result through the distribution's inverse link:
-- identity ('Gaussian'), 'logistic' ('Bernoulli') or 'exp' ('Poisson').
predict :: Distribution -> Fix SRTree -> PVector -> SRMatrix -> SRVector
predict dist tree theta xss = invLink (evalTree xss theta tree)
  where
    invLink :: SRVector -> SRVector
    invLink = case dist of
      Gaussian  -> id
      Bernoulli -> logistic
      Poisson   -> exp
-- | Negative log-likelihood together with its gradient with respect to
-- @theta@, computed by 'reverseModeUnique'.
--
-- * 'Gaussian': the standard error is the supplied value or, when 'Nothing',
--   @sqrt (SSR / (m - p))@; the gradient is divided by its square.
-- * 'Bernoulli' / 'Poisson': targets are validated as in 'nll' ('error' on
--   invalid data); 'logistic' / 'exp' is passed to 'reverseModeUnique'.
gradNLL :: Distribution
        -> Maybe Double
        -> SRMatrix
        -> PVector
        -> Fix SRTree
        -> PVector
        -> (Double, SRVector)
gradNLL Gaussian msErr xss ys tree theta =
    (nll' Gaussian sErr yhat ys', delay grad ./ (sErr * sErr))
  where
    (Sz m)       = M.size ys
    (Sz p)       = M.size theta
    ys'          = delay ys
    (yhat, grad) = reverseModeUnique xss theta ys' id tree
    ssr          = sse xss ys tree theta
    -- Residual standard error sqrt (SSR / (m - p)), used when no standard
    -- error is supplied. NOTE(review): Inf/NaN when m <= p.
    est          = sqrt $ ssr / fromIntegral (m - p)
    sErr         = getSErr Gaussian est msErr
gradNLL Bernoulli _ xss (delay -> ys) tree theta
  | M.any (\x -> x /= 0 && x /= 1) ys = error "For Bernoulli distribution the output must be either 0 or 1."
  | otherwise = (nll' Bernoulli 1.0 yhat ys, delay grad)
  where
    (yhat, grad) = reverseModeUnique xss theta ys logistic tree
gradNLL Poisson _ xss (delay -> ys) tree theta
  | M.any (< 0) ys = error "For Poisson distribution the output must be non-negative."
  | otherwise      = (nll' Poisson 1.0 yhat ys, delay grad)
  where
    (yhat, grad) = reverseModeUnique xss theta ys exp tree
-- | Negative log-likelihood together with its gradient with respect to
-- @theta@, computed by 'forwardMode' seeded with the residuals
-- @predict dist tree theta xss - ys@.
--
-- * 'Gaussian': the standard error is the supplied value or, when 'Nothing',
--   @sqrt (SSR / (m - p))@; the gradient is divided by its square.
-- * 'Bernoulli' / 'Poisson': targets are validated as in 'nll' ('error' on
--   invalid data).
gradNLLNonUnique :: Distribution
                 -> Maybe Double
                 -> SRMatrix
                 -> PVector
                 -> Fix SRTree
                 -> PVector
                 -> (Double, SRVector)
gradNLLNonUnique Gaussian msErr xss ys tree theta =
    (nll' Gaussian sErr yhat ys', delay grad ./ (sErr * sErr))
  where
    (Sz m)       = M.size ys
    (Sz p)       = M.size theta
    ys'          = delay ys
    err          = predict Gaussian tree theta xss - ys'
    (yhat, grad) = forwardMode xss theta err tree
    ssr          = sse xss ys tree theta
    -- Residual standard error sqrt (SSR / (m - p)), used when no standard
    -- error is supplied. NOTE(review): Inf/NaN when m <= p.
    est          = sqrt $ ssr / fromIntegral (m - p)
    sErr         = getSErr Gaussian est msErr
gradNLLNonUnique Bernoulli _ xss (delay -> ys) tree theta
  | M.any (\x -> x /= 0 && x /= 1) ys = error "For Bernoulli distribution the output must be either 0 or 1."
  | otherwise = (nll' Bernoulli 1.0 yhat ys, delay grad)
  where
    -- ys is already delayed by the view pattern
    err          = predict Bernoulli tree theta xss - ys
    (yhat, grad) = forwardMode xss theta err tree
gradNLLNonUnique Poisson _ xss (delay -> ys) tree theta
  | M.any (< 0) ys = error "For Poisson distribution the output must be non-negative."
  | otherwise      = (nll' Poisson 1.0 yhat ys, delay grad)
  where
    err          = predict Poisson tree theta xss - ys
    (yhat, grad) = forwardMode xss theta err tree
-- | Diagonal of the observed Fisher information of the negative
-- log-likelihood: one entry per parameter in @theta@.
--
-- For parameter @i@ the entry is
--
-- > sum (phi' * (df/dtheta_i)^2 - (y - phi) * d2f/dtheta_i^2) / sErr^2
--
-- The pieces of this formula are:
--
-- * @f@ is the tree with float constants turned into parameters.
-- * @phi@ is the model output passed through the distribution's mean
--   function (identity, logistic or exp).
-- * @phi'@ is the derivative of that mean function.
--
-- @sErr@ is obtained from 'getSErr', given @msErr@ and the residual
-- standard error estimate @sqrt (SSE / (m - p))@. Here @m@ is the number
-- of observations and @p@ the number of parameters.
fisherNLL :: Distribution -> Maybe Double -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRVector
fisherNLL dist msErr xss ys tree theta = makeArray cmp (Sz p) build
  where
    -- Fisher information entry for the ix-th parameter.
    build :: Int -> Double
    build ix =
      let dtdix   = deriveByParam ix t'
          d2tdix2 = deriveByParam ix dtdix
          f'      = eval dtdix
          f''     = eval d2tdix2
       in (/ sErr ^ 2) . M.sum $ phi' * f' ^ 2 - res * f''

    cmp    = getComp xss
    (Sz m) = M.size ys
    (Sz p) = M.size theta
    t'     = fst $ floatConstsToParam tree
    eval   = evalTree xss theta

    ssr  = sse xss ys tree theta
    sErr = getSErr dist est msErr
    -- Residual standard error estimate: sqrt (SSE / (m - p)).
    -- The SSE term ties the estimate to the actual fit;
    -- sqrt (m - p) alone would be independent of the data.
    est  = sqrt $ ssr / fromIntegral (m - p)

    yhat = eval t'
    res  = delay ys - phi

    -- (mean function applied to yhat, derivative of the mean function)
    (phi, phi') = case dist of
      Gaussian  -> (yhat, M.replicate cmp (Sz m) 1)
      Bernoulli -> (logistic yhat, phi * (M.replicate cmp (Sz m) 1 - phi))
      Poisson   -> (exp yhat, phi)
-- | Observed Hessian of the negative log-likelihood with respect to the
-- parameters @theta@, as a dense @p x p@ matrix.
--
-- Entry @(i, j)@ is
--
-- > sum (phi' * df/dtheta_i * df/dtheta_j - (y - phi) * d2f/(dtheta_i dtheta_j)) / sErr^2
--
-- The pieces of this formula are:
--
-- * @f@ is the tree as given.
-- * @phi@ is the model output passed through the distribution's mean
--   function (identity, logistic or exp).
-- * @phi'@ is the derivative of that mean function.
--
-- @sErr@ is obtained from 'getSErr', given @msErr@ and the residual
-- standard error estimate @sqrt (SSE / (m - p))@. Here @m@ is the number
-- of observations and @p@ the number of parameters.
hessianNLL :: Distribution -> Maybe Double -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRMatrix
hessianNLL dist msErr xss ys tree theta = makeArray cmp (Sz (p :. p)) build
  where
    -- Hessian entry for the parameter pair (ix, iy).
    build :: Ix2 -> Double
    build (ix :. iy) =
      let dtdix   = deriveByParam ix t'
          dtdiy   = deriveByParam iy t'
          d2tdixy = deriveByParam iy dtdix
          fx      = eval dtdix
          fy      = eval dtdiy
          fxy     = eval d2tdixy
       in (/ sErr ^ 2) . M.sum $ phi' * fx * fy - res * fxy

    cmp    = getComp xss
    (Sz m) = M.size ys
    (Sz p) = M.size theta
    t'     = tree
    eval   = evalTree xss theta

    ssr  = sse xss ys tree theta
    sErr = getSErr dist est msErr
    -- Residual standard error estimate: sqrt (SSE / (m - p)).
    -- The SSE term ties the estimate to the actual fit;
    -- sqrt (m - p) alone would be independent of the data.
    est  = sqrt $ ssr / fromIntegral (m - p)

    yhat = eval t'
    res  = delay ys - phi

    -- (mean function applied to yhat, derivative of the mean function)
    (phi, phi') = case dist of
      Gaussian  -> (yhat, M.replicate cmp (Sz m) 1)
      Bernoulli -> (logistic yhat, phi * (M.replicate cmp (Sz m) 1 - phi))
      Poisson   -> (exp yhat, phi)