{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeApplications #-}
module Algorithm.SRTree.Likelihoods
( Distribution (..)
, PVector
, SRMatrix
, sse
, mse
, rmse
, r2
, nll
, predict
, buildNLL
, gradNLL
, gradNLLArr
, gradNLLGraph
, fisherNLL
, getSErr
, hessianNLL
, tree2arr
)
where
import Algorithm.SRTree.AD ( reverseModeArr, reverseModeGraph )
import Data.Massiv.Array hiding (all, map, read, replicate, tail, take, zip)
import qualified Data.Massiv.Array as M
import qualified Data.Massiv.Array.Mutable as Mut
import Data.Maybe (fromMaybe)
import Data.SRTree
import Data.SRTree.Recursion ( cata, accu )
import Data.SRTree.Derivative (deriveByParam, deriveByVar, derivative)
import Data.SRTree.Eval
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Vector.Storable as VS
import GHC.IO (unsafePerformIO)
import Data.Maybe
import Debug.Trace
import Data.SRTree.Print
-- | Likelihood families understood by 'nll'.
data Distribution
  = MSE        -- ^ plain mean squared error
  | Gaussian   -- ^ Gaussian; theta carries the variance as its last entry
  | HGaussian  -- ^ Gaussian with per-observation errors supplied separately
  | Bernoulli  -- ^ binary targets, each 0 or 1
  | Poisson    -- ^ non-negative targets
  | ROXY       -- ^ needs x/y errors and three extra parameters in theta
  deriving (Eq, Show, Read, Enum, Bounded)
-- | Sum of squared errors between the targets @ys@ and the predictions of
-- @tree@ evaluated on @xss@ with parameters @theta@:
-- @sum ((ys - yhat) ^ 2)@.
sse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sse xss ys tree theta = M.sum $ (delay ys - yhat) ^ (2 :: Int)
  where
    yhat :: SRVector
    yhat = evalTree xss theta tree
-- | Weighted sum of squared errors: each squared residual between @ys@ and
-- the predictions of @tree@ is divided by the matching entry of @yErr@,
-- i.e. @sum ((ys - yhat) ^ 2 / yErr)@.
--
-- NOTE(review): @yErr@ is presumably a per-observation variance (it is
-- divided into a squared residual) — confirm with callers.
sseError :: SRMatrix -> PVector -> PVector -> Fix SRTree -> PVector -> Double
sseError xss ys yErr tree theta =
  M.sum $ (delay ys - yhat) ^ (2 :: Int) / delay yErr
  where
    yhat :: SRVector
    yhat = evalTree xss theta tree
-- | Total sum of squares of the targets around their mean,
-- @sum ((ys - mean ys) ^ 2)@.
--
-- The matrix, tree and parameter arguments are not used; the signature
-- simply mirrors 'sse'.
sseTot :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sseTot _ ys _ _ = M.sum $ (M.map (subtract ym) ys) ^ (2 :: Int)
  where
    (Sz m) = M.size ys
    ym :: Double
    ym = M.sum ys / fromIntegral m
-- | Mean squared error: 'sse' divided by the number of targets.
mse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
mse xss ys tree theta = total / fromIntegral nObs
  where
    total     = sse xss ys tree theta
    (Sz nObs) = M.size ys
-- | Root mean squared error: the square root of 'mse'.
rmse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
rmse xss ys tree theta = sqrt (mse xss ys tree theta)
-- | Coefficient of determination, @1 - SSE / SST@, built from 'sse' and
-- 'sseTot'. When the targets are (nearly) constant SST is at or near zero,
-- so the result may be non-finite.
r2 :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
r2 xss ys tree theta = 1 - residual / total
  where
    residual = sse xss ys tree theta
    total    = sseTot xss ys tree theta
-- | Logistic (sigmoid) function, @1 / (1 + e^(-x))@.
logistic :: Floating a => a -> a
logistic x =
  let e = exp (negate x)
  in  1 / (1 + e)
{-# INLINE logistic #-}
-- | Standard-error value to use for a distribution.
--
-- For 'Gaussian' this is the supplied value when present and the estimate
-- @est@ otherwise. Every other distribution yields 1.
getSErr :: Num a => Distribution -> a -> Maybe a -> a
getSErr dist est mErr =
  case dist of
    Gaussian -> fromMaybe est mErr
    _        -> 1
{-# INLINE getSErr #-}
-- | Negated sum of all entries of a vector.
negSum :: PVector -> Double
negSum v = negate (M.sum v)
{-# INLINE negSum #-}
-- | Negative log-likelihood of the targets @ys@ given the predictions of
-- @tree@ evaluated on @xss@ with parameters @theta@, under the chosen
-- 'Distribution'.
--
-- The @Maybe PVector@ argument carries measured errors:
--
-- * 'HGaussian' reads it.
-- * 'ROXY' only requires that it is present.
-- * The other distributions ignore it.
--
-- Invalid inputs are reported with 'error'; see the individual clauses.
nll :: Distribution
    -> Maybe PVector
    -> SRMatrix
    -> PVector
    -> Fix SRTree
    -> PVector
    -> Double

-- Plain mean squared error.
nll MSE _ xss ys t theta = mse xss ys t theta

-- Homoscedastic Gaussian. The last entry of theta is used directly as the
-- variance s:  0.5 * (SSE / s + m * log (2 * pi * s)).
-- theta must hold the tree's parameters plus that extra entry, so a vector
-- no longer than the parameter count is rejected up front.
nll Gaussian _ xss ys t theta
  | p <= nParams = error "For Gaussian distribution theta must contain the variance as its last value."
  | otherwise    = 0.5 * (sse xss ys t theta / s + m * log (2 * pi * s))
  where
    (Sz nObs) = M.size ys
    (Sz p)    = M.size theta
    nParams :: Int
    nParams = countParams t
    m :: Double
    m = fromIntegral nObs
    s :: Double
    s = theta M.! (p - 1)

-- Gaussian with per-observation errors:
--   0.5 * (sum ((y - yhat)^2 / e) + sum (log (2 * pi * e))).
nll HGaussian mYerr xss ys t theta =
  case mYerr of
    Nothing   -> error "For HGaussian, you must provide the measured error for the target variable."
    Just yErr -> 0.5 * ( sseError xss ys yErr t theta
                       + M.sum (M.map (log . (2 *) . (pi *)) yErr) )

-- Bernoulli with the tree output as the linear predictor eta:
--   -sum (y * eta - log (1 + exp eta)).
nll Bernoulli _ xss ys tree theta
  | notValid ys = error "For Bernoulli distribution the output must be either 0 or 1."
  | otherwise   = negate . M.sum $ delay ys * yhat - M.map softplus yhat
  where
    yhat :: SRVector
    yhat = evalTree xss theta tree
    notValid :: PVector -> Bool
    notValid = M.any (\x -> x /= 0 && x /= 1)
    -- log (1 + exp x), rearranged so that exp never overflows for large x
    -- (the naive form returns Infinity once exp x overflows).
    softplus :: Double -> Double
    softplus x = max x 0 + log (1 + exp (negate (abs x)))

-- Poisson with the tree output as the log-rate eta:
--   -sum (y * eta - y * log y - exp eta).
nll Poisson _ xss ys tree theta
  | notValid ys = error "For Poisson distribution the output must be non-negative."
  | otherwise   = negate . M.sum $ ys' * yhat - M.map xlogx ys' - exp yhat
  where
    ys' :: SRVector
    ys' = delay ys
    yhat :: SRVector
    yhat = evalTree xss theta tree
    notValid :: PVector -> Bool
    notValid = M.any (< 0)
    -- y * log y using the limit value 0 at y == 0. Zero counts are valid,
    -- but IEEE arithmetic gives 0 * log 0 = 0 * (-Infinity) = NaN, which
    -- would poison the whole sum.
    xlogx :: Double -> Double
    xlogx y
      | y == 0    = 0
      | otherwise = y * log y

-- ROXY likelihood. xss must provide the columns read below, and theta must
-- hold three extra values after the tree's own parameters.
nll ROXY mYerr xss _ tree theta
  | isNothing mYerr  = error "Can't calculate ROXY nll without x,y-errors."
  | p < nParams + 3  = error "We need 3 additional parameters for ROXY."
  | n /= 1 && n /= 5 = error "For ROXY dataset must contain a single variable, or 1 variable + 4 cached data."
  | otherwise        = if isNaN negLL then 1 / 0 else negLL
  where
    (Sz p)    = M.size theta
    (Sz2 m n) = M.size xss
    nParams :: Int
    nParams = countParams tree

    -- Columns of xss used by this clause.
    -- NOTE(review): the guard above also accepts n == 1, but columns 1..4
    -- are always indexed here, so that case fails when they are forced.
    -- Confirm the intended single-variable handling.
    x0, logX, logY, logXErr, logYErr :: SRVector
    x0      = xss <! 0
    logX    = xss <! 1
    logY    = xss <! 2
    logXErr = xss <! 3
    logYErr = xss <! 4

    one, zero :: SRVector
    one  = M.replicate compMode (Sz m) 1
    zero = M.replicate compMode (Sz m) 0

    -- The three entries of theta that follow the tree's own parameters.
    sig, muGauss, wGauss :: Double
    sig     = theta M.! nParams
    muGauss = theta M.! (nParams + 1)
    wGauss  = theta M.! (nParams + 2)

    -- Derivative of a binary operator's result, given both operand values
    -- and their derivatives with respect to the input variable.
    applyDer :: Op -> SRVector -> SRVector -> SRVector -> SRVector -> SRVector
    applyDer Add      _ dl _ dr = dl + dr
    applyDer Sub      _ dl _ dr = dl - dr
    applyDer Mul      l dl r dr = l * dr + r * dl
    applyDer Div      l dl r dr = (dl * r - dr * l) / (r ^ (2 :: Int))
    applyDer Power    l dl r dr = l ** (r .- 1) * (r * dl + l * log l * dr)
    applyDer PowerAbs l dl r dr = (abs l ** r) * (dr * log (abs l) + r * dl / l)
    applyDer AQ       l dl r dr =
      ((1 +. r * r) * dl - l * r * dr) / M.map (** 1.5) (1 +. r * r)

    -- Model output and its derivative with respect to the input variable,
    -- computed together in one bottom-up pass (every Var maps to x0).
    yhat, grad :: SRVector
    (yhat, grad) = cata alg tree
      where
        alg :: SRTree (SRVector, SRVector) -> (SRVector, SRVector)
        alg (Var _)         = (x0, one)
        alg (Param ix)      = (M.replicate compMode (Sz m) (theta M.! ix), zero)
        alg (Const c)       = (M.replicate compMode (Sz m) c, zero)
        alg (Uni g (v, dv)) = (M.map (evalFun g) v, M.map (derivative g) v * dv)
        alg (Bin op (vl, dl) (vr, dr)) =
          (M.zipWith (evalOp op) vl vr, applyDer op vl dl vr dr)

    -- f = log10 |yhat|. fprime scales d yhat / d x0 by x0 * log 10.
    -- This presumably is the derivative with respect to logX, assuming
    -- x0 == 10 ** logX. TODO confirm.
    f :: SRVector
    f = M.map (logBase 10) (abs yhat)

    fprime :: SRVector
    fprime = grad / (log 10 *. yhat) * x0 .* log 10

    wGauss2 :: Double
    wGauss2 = wGauss ^ (2 :: Int)

    s2 :: SRVector
    s2 = delay $ logYErr .+ sig ^ (2 :: Int)

    den :: SRVector
    den = fprime ^ (2 :: Int) .* wGauss2 * logXErr + s2 * (wGauss2 +. logXErr)

    neglogP :: SRVector
    neglogP = log (2 * pi) +. log den
            + ( wGauss2 *. (f - logY) * (f - logY)
              + logXErr * (fprime * (muGauss -. logX) + f - logY) ^ (2 :: Int)
              + s2 * (logX .- muGauss) ^ (2 :: Int)
              ) / den

    negLL :: Double
    negLL = 0.5 * M.sum neglogP
-- | Build the symbolic (per-sample) negative log-likelihood expression of a
-- model @tree@ under the given 'Distribution'.
--
-- Conventions used by the expressions below:
--
--   * @var (-1)@ stands for the observed target value;
--   * @var (-2)@ ('HGaussian' only) is the per-sample variance term;
--   * extra likelihood parameters (the Gaussian scale, ROXY's sigma, mu and
--     width) are 'param' indices appended after the model's own parameters,
--     i.e. starting at @countParams tree@.
--
-- The @m@ argument only enters the 'MSE' and 'HGaussian' expressions.
-- The ROXY expression carries no 0.5 factor.
buildNLL :: Distribution -> Double -> Fix SRTree -> Fix SRTree
buildNLL dist m tree =
  case dist of
    MSE       -> ((tree - target) ** 2) / constv m
    Gaussian  -> (squareOf (tree - target) / scale2) + log scale2
    HGaussian -> (tree - target) ** 2 / hetVar + constv m * log (2 * pi * hetVar)
    Poisson   -> target * log target + exp tree - target * tree
    Bernoulli -> log (1 + exp (negate tree)) + (1 - target) * tree
    ROXY      -> roxyNegLogP
  where
    target :: Fix SRTree
    target = var (-1)

    hetVar :: Fix SRTree
    hetVar = var (-2)

    -- Index of the first parameter appended after the model's own ones.
    nParams :: Int
    nParams = countParams tree

    squareOf :: Fix SRTree -> Fix SRTree
    squareOf = Fix . Uni Square

    -- Squared Gaussian scale parameter.
    scale2 :: Fix SRTree
    scale2 = squareOf (param nParams)

    -- ROXY inputs: var 0 is x; var 1..4 are log x, log y and the two
    -- error terms (named logXErr / logYErr in the original formulation).
    ln10, logX, logY, logXErr, logYErr :: Fix SRTree
    ln10    = log 10
    logX    = var 1
    logY    = var 2
    logXErr = var 3
    logYErr = var 4

    sigma, muGauss, wGauss2 :: Fix SRTree
    sigma   = param nParams
    muGauss = param (nParams + 1)
    wGauss2 = param (nParams + 2) ** 2

    -- log10 |f| and its log-log slope d log10 f / d log10 x = f' * x / f.
    fLog, slope :: Fix SRTree
    fLog  = log (abs tree) / ln10
    slope = deriveByVar 0 tree / (ln10 * tree) * var 0 * ln10

    s2, den, resid :: Fix SRTree
    s2    = logYErr + sigma ** 2
    den   = slope ** 2 * wGauss2 * logXErr + s2 * (wGauss2 + logXErr)
    resid = fLog - logY

    roxyNegLogP :: Fix SRTree
    roxyNegLogP =
      log (2 * pi) + log den
        + ( wGauss2 * resid * resid
          + logXErr * (slope * (muGauss - logX) + fLog - logY) ** 2
          + s2 * (logX - muGauss) ** 2
          ) / den
-- | Evaluate the model on @xss@ with parameters @theta@ and map the raw
-- output through the inverse link function of the 'Distribution':
--
--   * 'Bernoulli' -> 'logistic' (probability of the positive class);
--   * 'Poisson'   -> 'exp' (expected count);
--   * 'MSE', 'Gaussian', 'HGaussian', 'ROXY' -> identity (the model output
--     is used directly as the mean, as in 'buildNLL').
--
-- Every constructor is matched explicitly, so adding a new distribution
-- triggers an incomplete-pattern warning instead of a runtime crash.
predict :: Distribution -> Fix SRTree -> PVector -> SRMatrix -> SRVector
predict dist tree theta xss = inverseLink dist (evalTree xss theta tree)
  where
    inverseLink :: Distribution -> SRVector -> SRVector
    inverseLink MSE       = id
    inverseLink Gaussian  = id
    inverseLink HGaussian = id
    inverseLink ROXY      = id
    inverseLink Bernoulli = logistic
    inverseLink Poisson   = exp
-- | One-sided finite-difference gradient of the mean squared error with
-- respect to the parameter vector @theta@.
--
-- Returns @(mse, gradient)@, where
-- @mse = sum ((predict MSE tree theta xss - ys) ^ 2) / m@, @m@ is the number
-- of rows of @xss@, and gradient entry @i@ is
-- @(mse(theta + eps * e_i) - mse(theta)) / eps@.
--
-- NOTE(review): the objective is always the MSE of the 'MSE' prediction.
-- The 'Distribution' and optional error-vector arguments do not influence
-- the result. Confirm whether callers rely on this before switching the
-- objective to 'nll'.
gradNLL :: Distribution
        -> Maybe PVector
        -> SRMatrix
        -> PVector
        -> Fix SRTree
        -> PVector
        -> (Double, SRVector)
gradNLL _dist _mYerr xss ys tree theta = (f0, delay grad)
  where
    Sz2 m _ = M.size xss
    Sz p    = M.size theta

    -- Step size of the forward difference.
    eps :: Double
    eps = 1e-8

    -- Mean squared error of the MSE prediction at parameter vector @t@.
    mseAt :: PVector -> Double
    mseAt t =
      M.sum (M.map (^ (2 :: Int)) (predict MSE tree t xss - delay ys))
        / fromIntegral m

    f0 :: Double
    f0 = mseAt theta

    -- Copy of @theta@ with component @ix@ shifted by @eps@.
    disturb :: Int -> PVector
    disturb ix =
      M.fromList M.Seq
        [ if iy == ix then v + eps else v
        | (iy, v) <- Prelude.zip [0 ..] (M.toList theta)
        ]

    grad :: PVector
    grad = M.fromList M.Seq [ (mseAt (disturb ix) - f0) / eps | ix <- [0 .. p - 1] ]
-- | Returns its argument unchanged.
--
-- NOTE(review): despite the name, this does NOT replace NaN with 0; NaN
-- values pass straight through. Confirm whether the sanitising behaviour
-- (e.g. @if isNaN x then 0 else x@) was disabled on purpose.
nanTo0 :: a -> a
nanTo0 = id
{-# INLINE nanTo0 #-}
-- | Summed negative log-likelihood and its gradient, computed by reverse-mode
-- AD ('reverseModeArr') over the array-encoded expression @tree@, with
-- column map @j2ix@.
--
-- Target validation (raises 'error'):
--
--   * 'Bernoulli' targets must be exactly 0 or 1;
--   * 'Poisson' targets must be non-negative.
--
-- For 'ROXY' both the value and the gradient are multiplied by 0.5. This is
-- presumably because the ROXY expression from 'buildNLL' omits that factor;
-- confirm with callers. Every other distribution, including 'HGaussian',
-- returns the raw sum and gradient. Each gradient entry goes through 'nanTo0'.
gradNLLArr :: Distribution
           -> SRMatrix
           -> PVector
           -> Maybe PVector
           -> [(Int, (Int, Int, Int, Double))]
           -> IntMap.IntMap Int
           -> VS.Vector Double
           -> (Double, SRVector)
gradNLLArr dist xss ys mYerr tree j2ix theta =
  case dist of
    Bernoulli
      | M.any (\x -> x /= 0 && x /= 1) ys ->
          error "For Bernoulli distribution the output must be either 0 or 1."
    Poisson
      | M.any (< 0) ys ->
          error "For Poisson distribution the output must be non-negative."
    ROXY -> (nllSum * 0.5, M.map (* 0.5) grad')
    -- MSE, Gaussian, HGaussian and validated Bernoulli / Poisson targets.
    _    -> (nllSum, grad')
  where
    (yhat, grad) = reverseModeArr xss ys mYerr theta tree j2ix

    nllSum :: Double
    nllSum = M.sum yhat

    grad' :: SRVector
    grad' = M.map nanTo0 grad
-- | Summed negative log-likelihood and its gradient, computed by reverse-mode
-- AD ('reverseModeGraph') directly over the expression tree.
--
-- Target validation (raises 'error'):
--
--   * 'Bernoulli' targets must be exactly 0 or 1;
--   * 'Poisson' targets must be non-negative.
--
-- For 'ROXY' both the value and the gradient are multiplied by 0.5. This is
-- presumably because the ROXY expression from 'buildNLL' omits that factor;
-- confirm with callers. Every other distribution, including 'HGaussian',
-- returns the raw sum and gradient. Each gradient entry goes through 'nanTo0'.
gradNLLGraph :: Distribution
             -> SRMatrix
             -> PVector
             -> Maybe PVector
             -> Fix SRTree
             -> VS.Vector Double
             -> (Double, VS.Vector Double)
gradNLLGraph dist xss ys mYerr tree theta =
  case dist of
    Bernoulli
      | M.any (\x -> x /= 0 && x /= 1) ys ->
          error "For Bernoulli distribution the output must be either 0 or 1."
    Poisson
      | M.any (< 0) ys ->
          error "For Poisson distribution the output must be non-negative."
    ROXY -> (nllSum * 0.5, VS.map (* 0.5) grad')
    -- MSE, Gaussian, HGaussian and validated Bernoulli / Poisson targets.
    _    -> (nllSum, grad')
  where
    (yhat, grad) = reverseModeGraph xss ys mYerr theta tree

    nllSum :: Double
    nllSum = M.sum yhat

    grad' :: VS.Vector Double
    grad' = VS.map nanTo0 grad
fisherNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRVector
fisherNLL :: Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> SRVector
fisherNLL Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta = Comp -> Sz Int -> (Int -> Double) -> SRVector
forall r ix e.
Load r ix e =>
Comp -> Sz ix -> (ix -> e) -> Array r ix e
makeArray Comp
cmp (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
p) Int -> Double
finiteDiff
where
cmp :: Comp
cmp = SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss
(Sz Int
m) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
ys
(Sz Int
p) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
theta
f :: Double
f = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta
eps :: Double
eps = Double
1e-6
finiteDiff :: Int -> Double
finiteDiff Int
ix = IO Double -> Double
forall a. IO a -> a
unsafePerformIO (IO Double -> Double) -> IO Double -> Double
forall a b. (a -> b) -> a -> b
$ do
theta' <- PVector -> IO (MArray RealWorld S Int Double)
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, MonadIO m) =>
Array r ix e -> m (MArray RealWorld r ix e)
Mut.thaw PVector
theta
v <- Mut.readM theta' ix
Mut.writeM theta' ix (v + eps)
thetaPlus <- Mut.freezeS theta'
Mut.writeM theta' ix (v - eps)
thetaMinus <- Mut.freezeS theta'
let fPlus = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaPlus
fMinus = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaMinus
pure $ (fPlus + fMinus - 2*f)/(eps*eps)
-- Gaussian: diagonal of the observed Fisher information, i.e. the second
-- derivative of @nll Gaussian@ with respect to each parameter, estimated
-- with a central finite difference:
--
--   d2f/dtheta_i^2 ~ (f(theta + h e_i) + f(theta - h e_i) - 2 f(theta)) / h^2
--
-- A second difference loses roughly @eps_machine * |f| / h^2@ to rounding,
-- so the step must be much larger than for a first derivative. A step
-- near the fourth root of machine epsilon (~1e-4), scaled by the parameter
-- magnitude, balances that rounding against the O(h^2) truncation error;
-- a step such as 1e-6 would amplify rounding by ~1e12.
fisherNLL Gaussian mYerr xss ys tree theta = makeArray cmp (Sz p) finiteDiff
  where
    cmp :: Comp
    cmp = getComp xss

    Sz p = M.size theta

    -- NLL at the unperturbed parameters, shared by every diagonal entry.
    f :: Double
    f = nll Gaussian mYerr xss ys tree theta

    -- Relative finite-difference step, ~ eps_machine ** (1/4).
    relStep :: Double
    relStep = 1e-4

    finiteDiff :: Int -> Double
    finiteDiff ix = (fPlus + fMinus - 2 * f) / (h * h)
      where
        h      = relStep * max 1 (abs (theta M.! ix))
        fPlus  = nll Gaussian mYerr xss ys tree (shift h)
        fMinus = nll Gaussian mYerr xss ys tree (shift (negate h))

        -- Copy of theta with only coordinate ix moved by d. Built purely,
        -- so no mutable thaw/freeze under unsafePerformIO is needed.
        shift :: Double -> PVector
        shift d = M.computeAs M.S (M.imap (\j v -> if j == ix then v + d else v) theta)
-- General case: diagonal of the observed Fisher information (the second
-- derivative of the negative log-likelihood with respect to each
-- parameter), computed from symbolic derivatives of the tree. Entry i is
--
--   sum (phi' * f_i^2 - (ys - phi) * f_ii)
--
-- where f_i / f_ii are the first / second derivatives of the model output
-- with respect to parameter i, and (phi, phi') are the mean and its
-- derivative under the distribution's link (MSE/Gaussian: identity,
-- Bernoulli: logistic, Poisson: exp). The error vector argument is not
-- used by this equation. Distributions without a link case below raise an
-- 'error' naming the distribution.
fisherNLL dist _ xss ys tree theta = makeArray cmp (Sz p) build
  where
    build :: Int -> Double
    build ix =
      let dtdix   = deriveByParam ix t'
          d2tdix2 = deriveByParam ix dtdix
          f'      = eval dtdix
          f''     = eval d2tdix2
       in M.sum (phi' * (f' * f') - res * f'')

    cmp :: Comp
    cmp = getComp xss

    Sz m = M.size ys
    Sz p = M.size theta

    -- Numeric constants are turned into parameters; the extracted constant
    -- values are discarded and the tree is evaluated with theta.
    -- NOTE(review): assumes the parameter indices assigned by
    -- floatConstsToParam line up with theta -- confirm against callers.
    t' :: Fix SRTree
    t' = fst (floatConstsToParam tree)

    eval :: Fix SRTree -> SRVector
    eval = evalTree xss theta

    yhat, res, ones :: SRVector
    yhat = eval t'
    res  = delay ys - phi
    ones = M.replicate compMode (Sz m) 1

    -- phi refers to itself in the Bernoulli arm; this is fine because the
    -- tuple binding is lazy and phi' only needs the first component.
    (phi, phi') = case dist of
      MSE       -> (yhat, ones)
      -- Normally unreachable: an earlier fisherNLL equation matches Gaussian.
      Gaussian  -> (yhat, ones)
      Bernoulli -> (logistic yhat, phi * (ones - phi))
      Poisson   -> (exp yhat, phi)
      _         -> error ("fisherNLL: unsupported distribution " <> show dist)
-- | Hessian of the negative log-likelihood with respect to the parameter
-- vector @theta@, built from symbolic parameter derivatives of @tree@.
--
-- Entry @(i, j)@ is
--
-- > sum (phi' * f_i * f_j - (ys - phi) * f_ij)
--
-- where @f_i@, @f_j@ and @f_ij@ are the model's first and mixed second
-- derivatives with respect to @theta@, and @(phi, phi')@ are the mean and
-- its derivative under the distribution's link ('MSE' / 'Gaussian':
-- identity, 'Bernoulli': logistic, 'Poisson': exp). For 'Gaussian' every
-- per-sample term is divided by @yErr@: the supplied @mYerr@, or @m - p@
-- (samples minus parameters) replicated when it is 'Nothing'.
--
-- Raises an 'error' for 'ROXY' and for any distribution without a link
-- case below (e.g. 'HGaussian').
hessianNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRMatrix
hessianNLL ROXY _ _ _ _ _ = error "hessianNLL: the ROXY distribution is not supported"
hessianNLL dist mYerr xss ys tree theta = makeArray cmp (Sz (p :. p)) build
  where
    build :: Ix2 -> Double
    build (ix :. iy) =
      let dtdix   = deriveByParam ix t'
          dtdiy   = deriveByParam iy t'
          d2tdixy = deriveByParam iy dtdix
          terms   = phi' * eval dtdix * eval dtdiy - res * eval d2tdixy
       in case dist of
            Gaussian -> M.sum (terms / delay yErr)
            _        -> M.sum terms

    cmp :: Comp
    cmp = getComp xss

    Sz m = M.size ys
    Sz p = M.size theta

    t' :: Fix SRTree
    t' = tree

    eval :: Fix SRTree -> SRVector
    eval = evalTree xss theta

    -- Per-sample divisor for the Gaussian case.
    yErr :: PVector
    yErr = fromMaybe (M.replicate compMode (Sz m) (fromIntegral (m - p))) mYerr

    yhat, res, ones :: SRVector
    yhat = eval t'
    res  = delay ys - phi
    ones = M.replicate cmp (Sz m) 1

    -- phi refers to itself in the Bernoulli arm; this is fine because the
    -- tuple binding is lazy and phi' only needs the first component.
    -- MSE uses the identity link, matching fisherNLL; without this arm
    -- hessianNLL MSE failed with a non-exhaustive pattern.
    (phi, phi') = case dist of
      MSE       -> (yhat, ones)
      Gaussian  -> (yhat, ones)
      Bernoulli -> (logistic yhat, phi * (ones - phi))
      Poisson   -> (exp yhat, phi)
      _         -> error ("hessianNLL: unsupported distribution " <> show dist)
-- | Flattens an expression tree into an 'IntMap.IntMap' keyed by
-- heap-style positions. The root is at 0; a node at position @i@ stores
-- its only/left child at @2*i + 1@ and its right child at @2*i + 2@.
--
-- Each entry is @(arity, code, index, value)@:
--
--   * @Var ix@     -> @(0, 0, ix, -1)@
--   * @Param ix@   -> @(0, 1, ix, -1)@
--   * @Const x@    -> @(0, 2, -1, x)@
--   * @Uni f _@    -> @(1, fromEnum f, -1, -1)@
--   * @Bin op _ _@ -> @(2, fromEnum op, -1, -1)@
--
-- Positions grow like @2^depth@, so a tree deeper than about 62 levels
-- would overflow 'Int' keys.
-- NOTE(review): presumably fine for realistic expression depths -- confirm.
tree2arr :: Fix SRTree -> IntMap.IntMap (Int, Int, Int, Double)
tree2arr tree = IntMap.fromList (accu indexer convert tree 0)
  where
    -- Tag each child with its heap position, derived from the parent's.
    indexer :: SRTree x -> Int -> SRTree (x, Int)
    indexer (Var ix)     _ = Var ix
    indexer (Const x)    _ = Const x
    indexer (Param ix)   _ = Param ix
    indexer (Uni f t)    i = Uni f (t, 2 * i + 1)
    indexer (Bin op l r) i = Bin op (l, 2 * i + 1) (r, 2 * i + 2)

    -- Emit this node's entry, followed by the entries of its children.
    convert :: SRTree [(Int, (Int, Int, Int, Double))] -> Int -> [(Int, (Int, Int, Int, Double))]
    convert (Var ix)     i = [(i, (0, 0, ix, -1))]
    convert (Const x)    i = [(i, (0, 2, -1, x))]
    convert (Param ix)   i = [(i, (0, 1, ix, -1))]
    convert (Uni f t)    i = (i, (1, fromEnum f, -1, -1)) : t
    convert (Bin op l r) i = (i, (2, fromEnum op, -1, -1)) : (l <> r)
{-# INLINE tree2arr #-}