{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeApplications #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.SRTree.Likelihoods 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  ConstraintKinds
--
-- Functions to calculate different likelihood functions, their gradient, and Hessian matrices.
--
-----------------------------------------------------------------------------
module Algorithm.SRTree.Likelihoods
  ( Distribution (..)
  , PVector
  , SRMatrix
  , sse
  , mse
  , rmse
  , r2
  , nll
  , predict
  , buildNLL
  , gradNLL
  , gradNLLArr
  , gradNLLGraph
  , fisherNLL
  , getSErr
  , hessianNLL
  , tree2arr
  )
    where

import Algorithm.SRTree.AD ( reverseModeArr, reverseModeGraph )
import Data.Massiv.Array hiding (all, map, read, replicate, tail, take, zip)
import qualified Data.Massiv.Array as M
import qualified Data.Massiv.Array.Mutable as Mut
import Data.Maybe (fromMaybe)
import Data.SRTree
import Data.SRTree.Recursion ( cata, accu )
import Data.SRTree.Derivative (deriveByParam, deriveByVar, derivative)
import Data.SRTree.Eval
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Vector.Storable as VS
import GHC.IO (unsafePerformIO)
import Data.Maybe

import Debug.Trace
import Data.SRTree.Print

-- | Supported distributions for the negative log-likelihood ('nll').
--
-- * 'MSE': mean squared error (not an actual distribution).
-- * 'Gaussian': Gaussian with a single variance; 'nll' expects theta to
--   carry one extra value after the tree parameters for it.
-- * 'HGaussian': Gaussian with heteroscedasticity; the measured error of
--   the target must be provided.
-- * 'Bernoulli': binary targets (0 or 1), with the model output used as
--   the logit.
-- * 'Poisson': non-negative targets, with the model output used as the
--   log-rate.
-- * 'ROXY': requires the x,y-errors and three extra parameters after the
--   tree parameters.
data Distribution = MSE | Gaussian | HGaussian | Bernoulli | Poisson | ROXY
    deriving (Show, Read, Enum, Bounded, Eq)

-- | Sum of squared errors (sum of squared residuals).
--
-- Evaluates @tree@ with parameters @theta@ on the data @xss@ and returns the
-- sum of @(y - yhat)^2@ over all samples in @ys@.
sse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sse xss ys tree theta = M.sum $ (delay ys - yhat) ^ (2 :: Int)
  where
    yhat = evalTree xss theta tree

-- | Weighted sum of squared errors.
--
-- Each squared residual @(y - yhat)^2@ is divided by the matching entry of
-- @yErr@. The division is by @yErr@ itself, not by its square, so @yErr@
-- effectively plays the role of a per-sample variance.
sseError :: SRMatrix -> PVector -> PVector -> Fix SRTree -> PVector -> Double
sseError xss ys yErr tree theta = M.sum $ (delay ys - yhat) ^ (2 :: Int) / delay yErr
  where
    yhat = evalTree xss theta tree

-- | Total sum of squares: the sum of @(y - mean ys)^2@ over @ys@.
--
-- The data matrix, tree and parameters are not used. They are accepted,
-- presumably, so the signature lines up with 'sse'.
sseTot :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
sseTot _ ys _ _ = M.sum $ M.map (subtract ym) ys ^ (2 :: Int)
  where
    (Sz m) = M.size ys
    ym     = M.sum ys / fromIntegral m
        
-- | Mean squared error: 'sse' divided by the number of samples in @ys@.
mse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
mse xss ys tree theta = sse xss ys tree theta / fromIntegral nSamples
  where
    (Sz nSamples) = M.size ys

-- | Root mean squared error: the square root of 'mse'.
rmse :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
rmse xss ys tree theta = sqrt (mse xss ys tree theta)

-- | Coefficient of determination, @1 - SSE / SST@.
--
-- SSE is the residual sum of squares ('sse'). SST is the total sum of
-- squares of @ys@ around its mean ('sseTot').
r2 :: SRMatrix -> PVector -> Fix SRTree -> PVector -> Double
r2 xss ys tree theta = 1 - residualSS / totalSS
  where
    residualSS = sse xss ys tree theta
    totalSS    = sseTot xss ys tree theta

-- | Logistic (sigmoid) function, @1 / (1 + exp (-x))@.
logistic :: Floating a => a -> a
logistic = (1 /) . (1 +) . exp . negate
{-# INLINE logistic #-}

-- | Standard error to use for a given distribution.
--
-- For 'Gaussian': returns the given value when it is 'Just'. Otherwise it
-- returns the fallback @est@, which the caller supplies (presumably an
-- estimate from the residuals). For every other distribution the result is
-- the constant 1.
getSErr :: Num a => Distribution -> a -> Maybe a -> a
getSErr dist est = case dist of
  Gaussian -> fromMaybe est
  _        -> const 1
{-# INLINE getSErr #-}

-- | Negation of the sum of the values in a vector.
negSum :: PVector -> Double
negSum v = negate (M.sum v)
{-# INLINE negSum #-}

-- | Negative log-likelihood of the targets @ys@ under the given
-- 'Distribution'. The model is @tree@, evaluated on @xss@ with
-- parameters @theta@. Inputs that are invalid for the chosen distribution
-- raise an 'error'.
nll :: Distribution   -- ^ likelihood to use
    -> Maybe PVector  -- ^ measured errors of the target (required by 'HGaussian' and 'ROXY')
    -> SRMatrix       -- ^ input data
    -> PVector        -- ^ observed targets
    -> Fix SRTree     -- ^ expression tree
    -> PVector        -- ^ parameter values
    -> Double

-- Mean squared error (not a distribution); the measured errors are ignored.
nll MSE _ xss ys t theta = mse xss ys t theta

-- Gaussian distribution with a single variance. theta must contain one
-- extra value after the tree parameters. Its last entry is used directly as
-- the variance s:
--   nll = 0.5 * (sse / s + m * log (2 * pi * s)),  m = number of samples.
-- A theta with no room for that extra value (including one that is shorter
-- than the tree parameters) is rejected.
-- NOTE(review): buildNLL Gaussian squares its extra parameter, i.e. treats
-- it as a standard deviation, while this clause treats it as a variance.
-- Confirm which convention callers rely on.
nll Gaussian _ xss ys t theta
  | p' <= nParams = error "For Gaussian distribution theta must contain the variance as its last value."
  | otherwise     = 0.5 * (sse xss ys t theta / s + m * log (2 * pi * s))
  where
    (Sz m') = M.size ys
    (Sz p') = M.size theta
    nParams = countParams t :: Int
    s       = theta M.! (p' - 1)
    m       = fromIntegral m' :: Double

-- Gaussian with heteroscedasticity. It needs the measured errors e_i of the
-- target (mYerr), used as per-sample variances:
--   nll = 0.5 * (sum (r_i^2 / e_i) + sum (log (2 * pi * e_i)))
nll HGaussian mYerr xss ys t theta =
  case mYerr of
    Nothing   -> error "For HGaussian, you must provide the measured error for the target variable."
    Just yErr -> 0.5 * (sseError xss ys yErr t theta + M.sum (M.map (log . (2*) . (pi*)) yErr))

-- Bernoulli distribution with the model output f = f(x; theta) used as the
-- logit. With phi = 1 / (1 + exp (-f)), the log-likelihood is
--   y log phi + (1-y) log (1 - phi) = y f - log (1 + exp f),
-- assuming y \in {0,1}.
nll Bernoulli _ xss ys tree theta
  | notValid ys = error "For Bernoulli distribution the output must be either 0 or 1."
  | otherwise   = negate . M.sum $ delay ys * yhat - M.map softplus yhat
  where
    yhat     = evalTree xss theta tree
    notValid = M.any (\x -> x /= 0 && x /= 1)
    -- log (1 + exp z), rewritten as max z 0 + log (1 + exp (-|z|)).
    -- Computing exp z directly would overflow to Infinity for large z.
    softplus z = max z 0 + log (1 + exp (negate (abs z)))

-- Poisson distribution with the model output used as the log-rate
-- (rate = exp f(x; theta)). The y log y term does not depend on theta. It
-- uses the convention 0 log 0 = 0; otherwise a zero count (which is valid)
-- would give 0 * (-Infinity) = NaN.
nll Poisson _ xss ys tree theta
  | notValid ys = error "For Poisson distribution the output must be non-negative."
  -- disabled check: M.any isNaN yhat = error $ "NaN predictions " <> show theta
  | otherwise   = negate . M.sum $ ys' * yhat - M.map xlogx ys - exp yhat
  where
    ys'      = delay ys
    yhat     = evalTree xss theta tree
    notValid = M.any (<0)
    xlogx y  = if y == 0 then 0 else y * log y

-- ROXY likelihood. Requirements:
--   * mYerr must be Just, although this clause never reads its contents.
--   * theta must hold the tree parameters followed by three more values:
--     sig, muGauss and wGauss.
--   * The data matrix must have 1 or 5 columns. The computation reads
--     columns 0..4 as x, logX, logY, logXErr and logYErr.
-- The targets ys are not used by this clause. A NaN result is reported as
-- +Infinity.
-- NOTE(review): with a single column, the slices for columns 1..4 appear to
-- go out of bounds once negLL is forced. Confirm how the cached columns are
-- supposed to be provided.
nll ROXY mYerr xss _ tree theta
  | isNothing mYerr          = error "Can't calculate ROXY nll without x,y-errors."
  | nTheta < nTreeParams + 3 = error "We need 3 additional parameters for ROXY."
  | nCols /= 1 && nCols /= 5 = error "For ROXY dataset must contain a single variable, or 1 variable + 4 cached data."
  | isNaN negLL              = 1.0 / 0.0
  | otherwise                = negLL
  where
    (Sz nTheta)       = M.size theta
    (Sz2 nRows nCols) = M.size xss
    nTreeParams       = countParams tree :: Int

    -- Column views of the data matrix.
    column :: Int -> SRVector
    column j = xss <! j
    x0       = column 0
    logX     = column 1
    logY     = column 2
    logXErr  = column 3
    logYErr  = column 4

    -- Constant vectors with one entry per row of the data matrix.
    constVec :: Double -> SRVector
    constVec = M.replicate compMode (Sz nRows)
    one      = constVec 1
    zero     = constVec 0

    -- The three extra parameters stored right after the tree parameters.
    sig     = theta ! nTreeParams
    muGauss = theta ! (nTreeParams + 1)
    wGauss  = theta ! (nTreeParams + 2)

    -- Forward-mode derivative of a binary operator, given the values (l, r)
    -- and derivatives (dl, dr) of its operands.
    applyDer :: Op -> SRVector -> SRVector -> SRVector -> SRVector -> SRVector
    applyDer Add      _ dl _ dr = dl + dr
    applyDer Sub      _ dl _ dr = dl - dr
    applyDer Mul      l dl r dr = l * dr + r * dl
    applyDer Div      l dl r dr = (dl * r - dr * l) / (r ^ 2)
    applyDer Power    l dl r dr = l ** (r .- 1) * (r * dl + l * log l * dr)
    applyDer PowerAbs l dl r dr = (abs l ** r) * (dr * log (abs l) + r * dl / l)
    applyDer AQ       l dl r dr = ((1 +. r * r) * dl - l * r * dr) / M.map (** 1.5) (1 +. r * r)

    -- Prediction and its derivative with respect to the input. Every
    -- variable node reads column 0 and has derivative 1.
    (yhat, grad) = cata alg tree
      where
        alg :: SRTree (SRVector, SRVector) -> (SRVector, SRVector)
        alg (Var _)           = (x0, one)
        alg (Param ix)        = (constVec (theta M.! ix), zero)
        alg (Const c)         = (constVec c, zero)
        alg (Uni fun (v, dv)) = (M.map (evalFun fun) v, M.map (derivative fun) v * dv)
        alg (Bin op (vl, dl) (vr, dr)) =
          (M.zipWith (evalOp op) vl vr, applyDer op vl dl vr dr)

    -- log10 |yhat|, and what appears to be its derivative w.r.t. log10 x
    -- (x * yhat' / yhat, written with explicit ln 10 factors).
    logYhat = M.map (logBase 10) (abs yhat)
    fprime  = grad / (log 10 *. yhat) * x0 .* log 10

    -- Per-sample negative log-likelihood terms.
    wGauss2 = wGauss ^ 2
    s2      = logYErr .+ sig ^ 2
    den     = fprime ^ 2 .* wGauss2 * logXErr + s2 * (wGauss2 +. logXErr)
    neglogP =
      log (2 * pi) +. log den
        + ( wGauss2 *. (logYhat - logY) * (logYhat - logY)
          + logXErr * (fprime * (muGauss -. logX) + logYhat - logY) ^ 2
          + s2 * (logX .- muGauss) ^ 2
          ) / den
    negLL = 0.5 * M.sum neglogP

-- | Negative log-likelihood of @tree@ for a distribution, built as a new
-- expression tree (one term per sample; the caller sums the terms).
--
-- WARNING: pass tree with parameters. Distribution-specific parameters are
-- appended after the tree's own ones, i.e. they start at index
-- @countParams tree@ ('Gaussian' uses one extra parameter, 'ROXY' three).
--
-- The 'Double' argument @m@ is only used by 'MSE' (the squared error is
-- divided by it) and by 'HGaussian' (it multiplies the log-normalisation term).
--
-- NOTE(review): variable layout inferred from usage, confirm against the
-- evaluator: @var (-1)@ looks like the observed target, @var (-2)@ the
-- per-sample error of 'HGaussian', and 'ROXY' reads @var 0@ .. @var 4@ (see
-- the named bindings below).
--
-- TODO: handle error similar to ROXY
buildNLL :: Distribution -> Double -> Fix SRTree -> Fix SRTree
buildNLL dist m tree =
  case dist of
    MSE       -> ((tree - target) ** 2) / constv m
    Gaussian  -> (square (tree - target) / sigma2) + log sigma2
    HGaussian -> (tree - target) ** 2 / var (-2) + constv m * log (2 * pi * var (-2))
    -- NOTE(review): for a target of 0, @target * log target@ is
    -- 0 * (-Infinity) = NaN in IEEE arithmetic unless the evaluator
    -- special-cases it; confirm.
    Poisson   -> target * log target + exp tree - target * tree
    Bernoulli -> log (1 + exp (negate tree)) + (1 - target) * tree
    ROXY      -> roxyNLL
  where
    target :: Fix SRTree
    target = var (-1)

    -- index of the first parameter appended after the tree's own parameters
    nParams :: Int
    nParams = countParams tree

    square :: Fix SRTree -> Fix SRTree
    square = Fix . Uni Square

    -- Gaussian variance: square of the extra parameter
    sigma2 :: Fix SRTree
    sigma2 = square (param nParams)

    -- ROXY inputs
    xVar, logX, logY, logXErr, logYErr :: Fix SRTree
    xVar    = var 0
    logX    = var 1
    logY    = var 2
    logXErr = var 3
    logYErr = var 4

    ln10 :: Fix SRTree
    ln10 = log 10

    -- ROXY extra parameters: sigma, then mu and w of the Gaussian prior
    sigParam, muGauss, wGauss2 :: Fix SRTree
    sigParam = param nParams
    muGauss  = param (nParams + 1)
    wGauss2  = param (nParams + 2) ** 2

    -- model in log10 space and its slope with respect to logX
    logModel, slope :: Fix SRTree
    logModel = log (abs tree) / ln10
    slope    = deriveByVar 0 tree / (ln10 * tree) * xVar * ln10

    s2, denom :: Fix SRTree
    s2    = logYErr + sigParam ** 2
    denom = slope ** 2 * wGauss2 * logXErr + s2 * (wGauss2 + logXErr)

    roxyNLL :: Fix SRTree
    roxyNLL = log (2 * pi)
            + log denom
            + ( wGauss2 * (logModel - logY) * (logModel - logY)
              + logXErr * (slope * (muGauss - logX) + logModel - logY) ** 2
              + s2 * (logX - muGauss) ** 2
              ) / denom

-- | Prediction for different distributions.
--
-- Evaluates @tree@ on @xss@ with parameters @theta@ and applies the
-- distribution's inverse link:
--
--   * identity for 'MSE', 'Gaussian', 'HGaussian' and 'ROXY';
--   * logistic for 'Bernoulli';
--   * exponential for 'Poisson'.
--
-- Every 'Distribution' constructor has an equation, so the function is total
-- over distributions.
predict :: Distribution -> Fix SRTree -> PVector -> SRMatrix -> SRVector
predict MSE       tree theta xss = evalTree xss theta tree
predict Gaussian  tree theta xss = evalTree xss theta tree
-- heteroscedastic Gaussian: same identity link, only the error model differs
predict HGaussian tree theta xss = evalTree xss theta tree
predict Bernoulli tree theta xss = logistic $ evalTree xss theta tree
predict Poisson   tree theta xss = exp $ evalTree xss theta tree
predict ROXY      tree theta xss = evalTree xss theta tree

-- | Gradient of the negative log-likelihood, approximated by forward finite
-- differences with a step of @1e-8@.
--
-- Returns the objective value at @theta@ together with the gradient.
--
-- * For 'MSE' the objective is the mean squared error of @predict MSE@
--   against @ys@. The mean is taken over the rows of @xss@.
-- * For every other distribution the objective is 'nll' with that
--   distribution and the optional error vector @mYerr@.
gradNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (Double, SRVector)
gradNLL dist mYerr xss ys tree theta = (f, delay grad)
  where
    (Sz p)    = M.size theta
    (Sz2 m _) = M.size xss

    -- objective being differentiated, as a function of the parameter vector
    objective :: PVector -> Double
    objective t = case dist of
      MSE -> (/ fromIntegral m) . M.sum . M.map (^2) $ predict MSE tree t xss - delay ys
      _   -> nll dist mYerr xss ys tree t

    f :: Double
    f = objective theta

    eps :: Double
    eps = 1e-8

    -- copy of theta with the ix-th parameter shifted by eps
    disturb :: Int -> PVector
    disturb ix = M.fromList M.Seq $ Prelude.zipWith (\iy v -> if iy == ix then v + eps else v) [0..] (M.toList theta)

    finitediff :: Int -> Double
    finitediff ix = (objective (disturb ix) - f) / eps

    grad :: PVector
    grad = M.fromList M.Seq [finitediff ix | ix <- [0 .. p - 1]]

    {-
    -- EXAMPLE OF FINITE DIFFERENCE
    -- Implement for debugging
gradNLL ROXY mXerr mYerr xss ys tree theta =
   (f, delay grad)
  where
    (Sz p) = M.size theta
    (Sz2 m n) = M.size xss
    yhat   = predict Gaussian tree theta xss
    f      = nll ROXY mXerr mYerr xss ys tree theta
    grad   = makeArray @S (getComp xss) (Sz p) finiteDiff
    eps    = 1e-8

    finiteDiff ix = unsafePerformIO $ do
                      theta' <- Mut.thaw theta
                      v <- Mut.readM theta' ix
                      Mut.writeM theta' ix (v + eps)
                      theta'' <- Mut.freezeS theta'
                      let f'= nll ROXY mXerr mYerr xss ys tree theta''
                          g = (f' - f)/eps
                      pure $ if isNaN g then (1/0) else g
                      -}

-- | Hook for sanitising gradient components.
--
-- Despite its name this is currently the identity. The replacement
-- @if isNaN x || isInfinite x then 0 else x@ is disabled, so NaN and
-- Infinity pass through unchanged.
nanTo0 :: a -> a
nanTo0 = id
{-# INLINE nanTo0 #-}

-- | Gradient of the negative log-likelihood using reverse-mode automatic
-- differentiation ('reverseModeArr') over the array form of the tree.
--
-- Arguments, in order:
--
-- * the distribution;
-- * the data matrix @xss@ and the targets @ys@;
-- * the optional error vector @mYerr@;
-- * the tree as a list of @(index, node)@ pairs;
-- * the index map @j2ix@;
-- * the parameters @theta@.
--
-- All arguments except the distribution are forwarded to 'reverseModeArr'.
--
-- Returns the sum of the per-sample values together with the gradient, with
-- 'nanTo0' applied to each gradient component. For 'ROXY' both results are
-- scaled by 0.5. 'HGaussian' is handled exactly like 'Gaussian'.
--
-- Calls 'error' when any 'Bernoulli' target is not 0 or 1, or when any
-- 'Poisson' target is negative.
gradNLLArr :: Distribution
           -> SRMatrix
           -> PVector
           -> Maybe PVector
           -> [(Int, (Int, Int, Int, Double))]
           -> IntMap.IntMap Int
           -> VS.Vector Double
           -> (Double, SRVector)
gradNLLArr dist xss ys mYerr tree j2ix theta =
  case dist of
    MSE       -> summed
    Gaussian  -> summed
    HGaussian -> summed
    Bernoulli
      | M.any (\x -> x /= 0 && x /= 1) ys -> error "For Bernoulli distribution the output must be either 0 or 1."
      | otherwise                         -> summed
    Poisson
      | M.any (< 0) ys -> error "For Poisson distribution the output must be non-negative."
      | otherwise      -> summed
    ROXY      -> (M.sum yhat * 0.5, M.map (* 0.5) grad')
  where
    (yhat, grad) = reverseModeArr xss ys mYerr theta tree j2ix

    grad' :: SRVector
    grad' = M.map nanTo0 grad

    -- result shared by every distribution except ROXY
    summed :: (Double, SRVector)
    summed = (M.sum yhat, grad')

-- | Gradient of the negative log-likelihood using reverse-mode automatic
-- differentiation ('reverseModeGraph') directly on the expression tree.
--
-- Arguments, in order:
--
-- * the distribution;
-- * the data matrix @xss@ and the targets @ys@;
-- * the optional error vector @mYerr@;
-- * the tree;
-- * the parameters @theta@.
--
-- All arguments except the distribution are forwarded to 'reverseModeGraph'.
--
-- Returns the sum of the per-sample values together with the gradient, with
-- 'nanTo0' applied to each gradient component. For 'ROXY' both results are
-- scaled by 0.5. 'HGaussian' is handled exactly like 'Gaussian'.
--
-- Calls 'error' when any 'Bernoulli' target is not 0 or 1, or when any
-- 'Poisson' target is negative.
gradNLLGraph :: Distribution
             -> SRMatrix
             -> PVector
             -> Maybe PVector
             -> Fix SRTree
             -> VS.Vector Double
             -> (Double, VS.Vector Double)
gradNLLGraph dist xss ys mYerr tree theta =
  case dist of
    MSE       -> summed
    Gaussian  -> summed
    HGaussian -> summed
    Bernoulli
      | M.any (\x -> x /= 0 && x /= 1) ys -> error "For Bernoulli distribution the output must be either 0 or 1."
      | otherwise                         -> summed
    Poisson
      | M.any (< 0) ys -> error "For Poisson distribution the output must be non-negative."
      | otherwise      -> summed
    ROXY      -> (M.sum yhat * 0.5, VS.map (* 0.5) grad')
  where
    (yhat, grad) = reverseModeGraph xss ys mYerr theta tree

    grad' :: VS.Vector Double
    grad' = VS.map nanTo0 grad

    -- result shared by every distribution except ROXY
    summed :: (Double, VS.Vector Double)
    summed = (M.sum yhat, grad')

-- | Fisher information of negative log-likelihood
fisherNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRVector
fisherNLL :: Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> SRVector
fisherNLL Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta = Comp -> Sz Int -> (Int -> Double) -> SRVector
forall r ix e.
Load r ix e =>
Comp -> Sz ix -> (ix -> e) -> Array r ix e
makeArray Comp
cmp (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
p) Int -> Double
finiteDiff
  where
    cmp :: Comp
cmp    = SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss
    (Sz Int
m) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
ys
    (Sz Int
p) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
theta
    f :: Double
f      = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta
    eps :: Double
eps = Double
1e-6
    finiteDiff :: Int -> Double
finiteDiff Int
ix = IO Double -> Double
forall a. IO a -> a
unsafePerformIO (IO Double -> Double) -> IO Double -> Double
forall a b. (a -> b) -> a -> b
$ do
                      theta' <- PVector -> IO (MArray RealWorld S Int Double)
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, MonadIO m) =>
Array r ix e -> m (MArray RealWorld r ix e)
Mut.thaw PVector
theta
                      v <- Mut.readM theta' ix
                      Mut.writeM theta' ix (v + eps)
                      thetaPlus <- Mut.freezeS theta'
                      Mut.writeM theta' ix (v - eps)
                      thetaMinus <- Mut.freezeS theta'
                      let fPlus     = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaPlus
                          fMinus    = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
ROXY Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaMinus
                      pure $ (fPlus + fMinus - 2*f)/(eps*eps)
fisherNLL Distribution
Gaussian Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta = Comp -> Sz Int -> (Int -> Double) -> SRVector
forall r ix e.
Load r ix e =>
Comp -> Sz ix -> (ix -> e) -> Array r ix e
makeArray Comp
cmp (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
p) Int -> Double
finiteDiff
  where
    cmp :: Comp
cmp    = SRMatrix -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp SRMatrix
xss
    (Sz Int
m) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
ys
    (Sz Int
p) = PVector -> Sz Int
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size PVector
theta
    f :: Double
f      = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
Gaussian Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
theta
    eps :: Double
eps = Double
1e-6
    finiteDiff :: Int -> Double
finiteDiff Int
ix = IO Double -> Double
forall a. IO a -> a
unsafePerformIO (IO Double -> Double) -> IO Double -> Double
forall a b. (a -> b) -> a -> b
$ do
                      theta' <- PVector -> IO (MArray RealWorld S Int Double)
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, MonadIO m) =>
Array r ix e -> m (MArray RealWorld r ix e)
Mut.thaw PVector
theta
                      v <- Mut.readM theta' ix
                      Mut.writeM theta' ix (v + eps)
                      thetaPlus <- Mut.freezeS theta'
                      Mut.writeM theta' ix (v - eps)
                      thetaMinus <- Mut.freezeS theta'
                      let fPlus     = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
Gaussian Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaPlus
                          fMinus    = Distribution
-> Maybe PVector
-> SRMatrix
-> PVector
-> Fix SRTree
-> PVector
-> Double
nll Distribution
Gaussian Maybe PVector
mYerr SRMatrix
xss PVector
ys Fix SRTree
tree PVector
thetaMinus
                      pure $ (fPlus + fMinus - 2*f)/(eps*eps)
-- General clause of 'fisherNLL': entry @ix@ of the result is
--
-- > sum (phi' * (df/dtheta_ix)^2 - (ys - phi) * d2f/dtheta_ix^2)
--
-- where @f@ is the tree (after 'floatConstsToParam') evaluated on @xss@ with
-- @theta@, and @(phi, phi')@ depend on the distribution (see the @where@
-- clause). Only 'MSE', 'Gaussian', 'Bernoulli' and 'Poisson' are handled here;
-- any other distribution reaching this clause raises an error.
--
-- NOTE(review): unlike the 'Gaussian' branch of 'hessianNLL', no division by
-- the error vector happens here, and the tree is passed through
-- 'floatConstsToParam' while 'hessianNLL' uses it as-is, so this is not
-- literally the Hessian diagonal. An earlier (commented-out) variant divided the
-- Gaussian case by @theta ! (p-1)@. Confirm which scaling is intended.
fisherNLL dist _mYerr xss ys tree theta = makeArray cmp (Sz p) build
  where
    build :: Int -> Double
    build ix =
      let dtdix   = deriveByParam ix t'      -- first derivative w.r.t. theta_ix
          d2tdix2 = deriveByParam ix dtdix   -- second derivative w.r.t. theta_ix
          f'      = eval dtdix
          f''     = eval d2tdix2
      in M.sum $ phi' * f' ^ (2 :: Int) - res * f''

    cmp    = getComp xss
    (Sz m) = M.size ys
    (Sz p) = M.size theta
    t'     = fst $ floatConstsToParam tree
    eval   = evalTree xss theta
    yhat   = eval t'
    res    = delay ys - phi
    ones   = M.replicate cmp (Sz m) 1

    -- phi is the model prediction on the response scale and phi' is its
    -- derivative with respect to the linear predictor.
    (phi, phi') = case dist of
                    MSE       -> (yhat, ones)
                    Gaussian  -> (yhat, ones)
                    Bernoulli -> (logistic yhat, phi * (ones - phi))
                    Poisson   -> (exp yhat, phi)
                    _         -> error $ "fisherNLL: unsupported distribution " ++ show dist

-- | Hessian of negative log-likelihood
--
-- Note, though the Fisher is just the diagonal of the return of this function
-- it is better to keep them as different functions for efficiency
--
-- Entry @(i, j)@ of the @p x p@ result (@p@ = number of parameters) is
--
-- > sum (phi' * df/dtheta_i * df/dtheta_j - (ys - phi) * d2f/(dtheta_i dtheta_j))
--
-- where @f@ is the tree evaluated on @xss@ with @theta@, and @(phi, phi')@
-- depend on the distribution (see the @where@ clause).
--
-- For 'Gaussian', every term is additionally divided by @yErr@. That is the
-- supplied error vector, or the constant @m - p@ (@m@ = number of samples)
-- when none is given.
--
-- 'MSE' uses the same link as 'Gaussian' but without the @yErr@ scaling.
-- 'ROXY' and 'HGaussian' are not supported and raise an error.
hessianNLL :: Distribution -> Maybe PVector -> SRMatrix -> PVector -> Fix SRTree -> PVector -> SRMatrix
hessianNLL ROXY _ _ _ _ _ = error "hessianNLL: the ROXY distribution is not supported"
hessianNLL dist mYerr xss ys tree theta = makeArray cmp (Sz (p :. p)) build
  where
    build :: Ix2 -> Double
    build (ix :. iy) =
      let dtdix   = deriveByParam ix t'
          dtdiy   = deriveByParam iy t'
          d2tdixy = deriveByParam iy dtdix   -- mixed second derivative
          fx      = eval dtdix
          fy      = eval dtdiy
          fxy     = eval d2tdixy
          terms   = phi' * fx * fy - res * fxy
      in case dist of
           Gaussian -> M.sum (terms / delay yErr)
           _        -> M.sum terms

    cmp    = getComp xss
    (Sz m) = M.size ys
    (Sz p) = M.size theta
    -- The tree is used as-is; earlier variants applied relabelParams or
    -- floatConstsToParam here.
    t'     = tree
    eval   = evalTree xss theta
    -- Per-sample error: user-supplied, otherwise the constant (m - p).
    yErr   = fromMaybe (M.replicate cmp (Sz m) est) mYerr
    est    = fromIntegral (m - p)
    yhat   = eval t'
    res    = delay ys - phi
    ones   = M.replicate cmp (Sz m) 1

    -- phi is the model prediction on the response scale and phi' is its
    -- derivative with respect to the linear predictor.
    (phi, phi') = case dist of
                    MSE       -> (yhat, ones)
                    Gaussian  -> (yhat, ones)
                    Bernoulli -> (logistic yhat, phi * (ones - phi))
                    Poisson   -> (exp yhat, phi)
                    _         -> error $ "hessianNLL: unsupported distribution " ++ show dist

-- | Flatten an expression tree into an 'IntMap' keyed by heap-style positions.
--
-- The root has key 0. The (first) child of the node at key @i@ has key
-- @2*i+1@, and the right child of a binary node has key @2*i+2@.
--
-- Each value is a tuple @(arity, code, index, value)@:
--
--   * @Var ix@      becomes @(0, 0, ix, -1)@
--   * @Param ix@    becomes @(0, 1, ix, -1)@
--   * @Const x@     becomes @(0, 2, -1, x)@
--   * @Uni f _@     becomes @(1, fromEnum f, -1, -1)@
--   * @Bin op _ _@  becomes @(2, fromEnum op, -1, -1)@
--
-- NOTE(review): keys grow like 2^depth, so extremely deep trees would overflow
-- 'Int' keys. Confirm that tree depth is bounded by callers.
tree2arr :: Fix SRTree -> IntMap.IntMap (Int, Int, Int, Double)
tree2arr tree = IntMap.fromList (accu indexer convert tree 0)
  where
    -- Tag each child with its heap position, derived from its parent's key.
    indexer :: SRTree a -> Int -> SRTree (a, Int)
    indexer (Var ix)     _  = Var ix
    indexer (Param ix)   _  = Param ix
    indexer (Const x)    _  = Const x
    indexer (Uni f t)    iy = Uni f (t, 2*iy + 1)
    indexer (Bin op l r) iy = Bin op (l, 2*iy + 1) (r, 2*iy + 2)

    -- Emit this node's entry followed by the already-converted children.
    convert :: SRTree [(Int, (Int, Int, Int, Double))] -> Int -> [(Int, (Int, Int, Int, Double))]
    convert (Var ix)     iy = [(iy, (0, 0, ix, -1))]
    convert (Const x)    iy = [(iy, (0, 2, -1, x))]
    convert (Param ix)   iy = [(iy, (0, 1, ix, -1))]
    convert (Uni f t)    iy = (iy, (1, fromEnum f, -1, -1)) : t
    convert (Bin op l r) iy = (iy, (2, fromEnum op, -1, -1)) : (l <> r)
{-# INLINE tree2arr #-}