{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.SRTree.ModelSelection 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  ConstraintKinds
--
-- Helper functions for model selection criteria
--
-----------------------------------------------------------------------------

module Algorithm.SRTree.ModelSelection where

import Algorithm.Massiv.Utils ( det )
import Algorithm.SRTree.Likelihoods
    ( PVector, SRMatrix, fisherNLL, hessianNLL, nll, Distribution )
import Data.Massiv.Array (Ix2 (..), Sz (..), (!-!))
import qualified Data.Massiv.Array as A
import Data.SRTree
import Data.SRTree.Eval (evalTree)
import Data.SRTree.Recursion (cata)
import qualified Data.Vector.Storable as VS


-- | Bayesian information criterion (BIC).
--
-- Computed as @(p + 1) * log n + 2 * nll@, where @p@ is the number of
-- entries in @theta@, @n@ is the number of entries in @ys@, and @nll@ is the
-- negative log-likelihood of @tree@ evaluated at @theta@ (forwarding @dist@,
-- @mSErr@ and @xss@ unchanged).
--
-- NOTE(review): the @+ 1@ counts one parameter beyond @theta@; presumably
-- the noise term of the distribution — confirm.
bic :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Double
bic dist mSErr xss ys theta tree = complexityPenalty + 2 * negLogLik
  where
    negLogLik = nll dist mSErr xss ys tree theta
    complexityPenalty = (nParams + 1) * log nObs

    A.Sz paramCount = A.size theta
    A.Sz obsCount   = A.size ys

    nParams, nObs :: Double
    nParams = fromIntegral paramCount
    nObs    = fromIntegral obsCount
{-# INLINE bic #-}

-- | Akaike information criterion (AIC).
--
-- Computed as @2 * (p + 1) + 2 * nll@, where @p@ is the number of entries
-- in @theta@ and @nll@ is the negative log-likelihood of @tree@ evaluated at
-- @theta@ (forwarding @dist@, @mSErr@ and @xss@ unchanged).
--
-- NOTE(review): the @+ 1@ counts one parameter beyond @theta@; presumably
-- the noise term of the distribution — confirm.
aic :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Double
aic dist mSErr xss ys theta tree = 2 * (p + 1) + 2 * nll dist mSErr xss ys tree theta
  where
    -- The sample size does not appear in this formula, so only the
    -- parameter count is extracted (no unused, Integer-defaulted binding).
    (A.Sz (fromIntegral -> p)) = A.size theta
{-# INLINE aic #-}

-- | Evidence.
--
-- Computed as @(1 - b) * nll - (p / 2) * log b@ with @b = 1 / sqrt n@, where
-- @p@ is the number of entries in @theta@, @n@ the number of entries in @ys@,
-- and @nll@ the negative log-likelihood of @tree@ evaluated at @theta@.
--
-- NOTE(review): this looks like a fractional-Bayes-factor approximation of
-- the negative log-evidence — confirm against the intended reference.
evidence :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Double
evidence dist mSErr xss ys theta tree =
    let A.Sz paramCount = A.size theta
        A.Sz obsCount   = A.size ys
        p, n, b :: Double
        p = fromIntegral paramCount
        n = fromIntegral obsCount
        b = 1 / sqrt n
        weightedFit = (1 - b) * nll dist mSErr xss ys tree theta
     in weightedFit - p / 2 * log b
{-# INLINE evidence #-}

-- | MDL as described in
-- Bartlett, Deaglan J., Harry Desmond, and Pedro G. Ferreira. "Exhaustive symbolic regression." IEEE Transactions on Evolutionary Computation (2023).
--
-- Sum of three codelength terms:
--
-- * the negative log-likelihood of @tree@ evaluated at a pruned copy of
--   @theta@, in which every parameter failing the significance test is
--   replaced by @0@;
-- * 'logFunctional' of @tree@;
-- * 'logParameters' evaluated on the original (unpruned) @theta@.
--
-- A parameter @t@ with matching 'fisherNLL' entry @f@ is significant when
-- @|t / sqrt (12 / f)| >= 1@. For @f <= 0@ or NaN the test is always false
-- (the divisor is infinite or NaN), so such parameters are zeroed.
mdl :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Double
mdl dist mSErr xss ys theta tree =
    fitLength + structureLength + paramLength
  where
    fitLength       = nll' dist mSErr xss ys prunedTheta tree
    structureLength = logFunctional tree
    paramLength     = logParameters dist mSErr xss ys theta tree

    -- per-parameter values returned by fisherNLL at the unpruned theta
    fisherValues = fisherNLL dist mSErr xss ys tree theta

    prunedTheta :: PVector
    prunedTheta = A.computeAs A.S (A.zipWith keepIfSignificant theta fisherValues)

    keepIfSignificant :: Double -> Double -> Double
    keepIfSignificant t f
      | abs (t / sqrt (12 / f)) >= 1 = t
      | otherwise                    = 0.0
{-# INLINE mdl #-}

-- | MDL Lattice as described in
-- Bartlett, Deaglan, Harry Desmond, and Pedro Ferreira. "Priors for symbolic regression." Proceedings of the Companion Conference on Genetic and Evolutionary Computation. 2023.
--
-- Identical to 'mdl' except that the parameter codelength comes from
-- 'logParametersLatt' instead of 'logParameters'. The likelihood term is
-- evaluated at a copy of @theta@ in which every parameter @t@ whose
-- 'fisherNLL' entry @f@ fails @|t / sqrt (12 / f)| >= 1@ is set to @0@.
mdlLatt :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Double
mdlLatt dist mSErr xss ys theta tree =
    let fisherValues = fisherNLL dist mSErr xss ys tree theta
        significant t f = abs (t / sqrt (12 / f)) >= 1
        pruned :: PVector
        pruned = A.computeAs A.S
               $ A.zipWith (\t f -> if significant t f then t else 0.0) theta fisherValues
     in nll' dist mSErr xss ys pruned tree
          + logFunctional tree
          + logParametersLatt dist mSErr xss ys theta tree
{-# INLINE mdlLatt #-}

-- | Same as 'mdl' but weighting the functional structure by frequency
-- (see 'logFunctionalFreq'), with frequencies calculated from a wiki of
-- physics and engineering functions.
--
-- As in 'mdl', the likelihood term is evaluated at a copy of @theta@ in
-- which every parameter @t@ whose 'fisherNLL' entry @f@ fails
-- @|t / sqrt (12 / f)| >= 1@ is set to @0@. 'logParameters' does not charge
-- those parameters any codelength. Evaluating the fit at the full @theta@
-- would therefore let them improve the likelihood without paying for it.
mdlFreq :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Double
mdlFreq dist mSErr xss ys theta tree = nll dist mSErr xss ys tree theta'
                                     + logFunctionalFreq tree
                                     + logParameters dist mSErr xss ys theta tree
  where
    fisher = fisherNLL dist mSErr xss ys tree theta

    -- theta with insignificant parameters zeroed, exactly as in 'mdl'
    theta' :: PVector
    theta' = A.computeAs A.S
           $ A.zipWith (\t f -> if isSignificant t f then t else 0.0) theta fisher

    isSignificant :: Double -> Double -> Bool
    isSignificant v f = abs (v / sqrt (12 / f)) >= 1
{-# INLINE mdlFreq #-}

-- | Log of the functional complexity of @tree@.
--
-- Sum of three terms:
--
-- * @countNodes tree * log (countUniqueTokens tree')@, where @tree'@ is
--   @tree@ after 'floatConstsToParam';
-- * @log |c|@ for every constant @c@ returned by 'getIntConsts';
-- * @log 2@ for every such constant.
--
-- A constant equal to @0@ contributes @0@ (the same as @±1@) to the second
-- term. Without this guard it would contribute @log 0 = -Infinity@, making
-- the whole codelength @-Infinity@.
--
-- NOTE(review): the @log 2@ is charged to every integer constant,
-- presumably for its sign, whether or not the constant is actually
-- negative. Charging it only to negative constants is a possible
-- alternative — confirm which is intended.
logFunctional :: Fix SRTree -> Double
logFunctional tree = countNodes tree * log (countUniqueTokens tree')
                   + foldr (\c acc -> logAbs c + acc) 0 consts
                   + log 2 * numberOfConsts
  where
    tree' :: Fix SRTree
    tree' = fst $ floatConstsToParam tree

    consts :: [Double]
    consts = getIntConsts tree

    numberOfConsts :: Double
    numberOfConsts = fromIntegral $ length consts

    -- log |c|, with 0 mapped to 0 instead of -Infinity
    logAbs :: Double -> Double
    logAbs c
      | c == 0    = 0
      | otherwise = log (abs c)
{-# INLINE logFunctional #-}

-- | Like 'logFunctional', but the structure term is weighted by frequency.
--
-- Sum of three terms:
--
-- * 'treeToNat' of @tree'@, where @tree'@ is @tree@ after
--   'floatConstsToParam';
-- * @log |c|@ for every constant @c@ returned by 'getIntConsts'. A constant
--   equal to @0@ contributes @0@ instead of @-Infinity@;
-- * @countVarNodes tree * log (numberOfVars tree)@. This term is @0@ when
--   the tree has no variable nodes; otherwise it could evaluate to
--   @0 * log 0 = NaN@.
logFunctionalFreq  :: Fix SRTree -> Double
logFunctionalFreq tree = treeToNat tree'
                       + foldr (\c acc -> logAbs c + acc) 0 consts
                       + varTerm
  where
    tree' :: Fix SRTree
    tree' = fst $ floatConstsToParam tree

    consts :: [Double]
    consts = getIntConsts tree

    -- log |c|, with 0 mapped to 0 instead of -Infinity
    logAbs :: Double -> Double
    logAbs c
      | c == 0    = 0
      | otherwise = log (abs c)

    varNodes :: Double
    varNodes = countVarNodes tree

    -- the guard only changes the result when the product would be NaN
    varTerm :: Double
    varTerm
      | varNodes == 0 = 0
      | otherwise     = varNodes * log (numberOfVars tree)
{-# INLINE logFunctionalFreq #-}

-- | Log of the parameters complexity.
--
-- Only significant parameters contribute. A pair @(v, f)@ of a @theta@
-- entry and its 'fisherNLL' entry is significant when
-- @|v / sqrt (12 / f)| >= 1@. That test can only hold for @f > 0@, so
-- @log f@ below is always taken of a positive number.
--
-- Result: @-(p / 2) * log 3 + 0.5 * sum (log f) + sum (log |v|)@, with both
-- sums over the significant pairs and @p@ their count.
logParameters :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Double
logParameters dist mSErr xss ys theta tree =
    negate (p / 2 * log 3) + 0.5 * sumLogFisher + sumLogTheta
  where
    fisherValues = fisherNLL dist mSErr xss ys tree theta

    significant :: [(Double, Double)]
    significant = filter (uncurry isSignificant)
                $ zip (A.toList theta) (A.toList fisherValues)

    -- Right folds keep the summation order of a single right fold over all
    -- pairs (insignificant ones simply contribute nothing).
    sumLogTheta, sumLogFisher, p :: Double
    sumLogTheta  = foldr (\(v, _) acc -> acc + log (abs v)) 0 significant
    sumLogFisher = foldr (\(_, f) acc -> acc + log f) 0 significant
    p            = fromIntegral (length significant)

    isSignificant :: Double -> Double -> Bool
    isSignificant v f = abs (v / sqrt (12 / f)) >= 1

-- | Log of the parameters complexity for the lattice variant (used by
-- 'mdlLatt').
--
-- Computed as @0.5 * p * (1 - log 3) + 0.5 * log (det H)@, where:
--
-- * @H@ is 'hessianNLL' evaluated at @theta@;
-- * @p@ is the number of parameters @v@ passing @|v / sqrt (12 / f)| >= 1@,
--   with @f@ the matching 'fisherNLL' entry.
--
-- NOTE(review): @det H@ covers every parameter, whereas @p@ counts only the
-- significant ones. A non-positive determinant makes the result NaN or
-- @-Infinity@. Confirm both behaviours are intended.
logParametersLatt :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Double
logParametersLatt dist mSErr xss ys theta tree = 0.5 * p * (1 - log 3) + 0.5 * log detFisher
  where
    fisher = fisherNLL dist mSErr xss ys tree theta

    detFisher :: Double
    detFisher = det $ hessianNLL dist mSErr xss ys tree theta

    -- Only the count of significant parameters enters this formula, so the
    -- log-sums that 'logParameters' accumulates are not computed here.
    p :: Double
    p = fromIntegral . length . filter (uncurry isSignificant)
      $ zip (A.toList theta) (A.toList fisher)

    isSignificant :: Double -> Double -> Bool
    isSignificant v f = abs (v / sqrt (12 / f)) >= 1

-- | 'nll' with its last two arguments swapped: the parameter vector comes
-- before the expression tree, matching the argument order used by 'bic',
-- 'aic', 'mdl' and the other criteria in this module.
nll' :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Double
nll' dist mSErr xss ys = flip (nll dist mSErr xss ys)
{-# INLINE nll' #-}

-- | Sum of per-node weights over @tree@.
--
-- * Each unary function node adds its weight from @funToNat@.
-- * Each binary operator node adds its weight from @opToNat@.
-- * Every other kind of node adds @0.6610799229372109@.
--
-- NOTE(review): these constants are presumably codelengths derived from
-- function frequencies (see 'logFunctionalFreq') — confirm their source.
treeToNat :: Fix SRTree -> Double
treeToNat = cata nodeWeight
  where
    nodeWeight :: SRTree Double -> Double
    nodeWeight node = case node of
      Uni fun child  -> funToNat fun + child
      Bin op lhs rhs -> opToNat op + lhs + rhs
      _              -> 0.6610799229372109

    opToNat :: Op -> Double
    opToNat op = case op of
      Add      -> 2.500842464597881
      Sub      -> 2.500842464597881
      Mul      -> 1.720356134912558
      Div      -> 2.60436883851265
      Power    -> 2.527957363394847
      PowerAbs -> 2.527957363394847
      AQ       -> 2.60436883851265

    funToNat :: Function -> Double
    funToNat fun = case fun of
      Sqrt -> 4.780867285331753
      Log  -> 4.765599813200964
      Exp  -> 4.788589331425663
      Abs  -> 6.352564869783006
      Sin  -> 5.9848400896576885
      Cos  -> 5.474014465891698
      Sinh -> 8.038963823353235
      Cosh -> 8.262107374667444
      Tanh -> 7.85664226655928
      Tan  -> 8.262107374667444
      -- Factorial -> 7.702491586732021 (disabled)
      _    -> 8.262107374667444
{-# INLINE treeToNat #-}