-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.SRTree.Opt 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  ConstraintKinds
--
-- Functions to optimize the parameters of an expression.
--
-----------------------------------------------------------------------------
module Algorithm.SRTree.Opt
    where

import Algorithm.SRTree.Likelihoods
import Algorithm.SRTree.NonlinearOpt
import Data.Bifunctor (bimap, second)
import Data.Massiv.Array
import Data.SRTree (Fix (..), SRTree (..), floatConstsToParam, relabelParams)
import Data.SRTree.Eval (evalTree, compMode)
import qualified Data.Vector.Storable as VS

-- | Minimizes the negative log-likelihood of an expression with L-BFGS.
--
-- Arguments, in order:
--
-- * @dist@  - distribution handed to 'gradNLL'.
-- * @msErr@ - optional error estimate, forwarded unchanged to 'gradNLL'.
-- * @niter@ - maximum number of objective evaluations; when it is not
--   positive no optimization is attempted.
-- * @xss@, @ys@ - data set, forwarded unchanged to 'gradNLL'.
-- * @tree@  - expression whose parameters are fitted.
-- * @t0@    - initial parameter vector.
--
-- Returns the parameter vector together with the negative log-likelihood
-- of @tree@ evaluated at /that same/ vector. When there is nothing to
-- optimize (@niter <= 0@ or @t0@ is empty), or when the optimizer reports a
-- failure, the vector is @t0@ and the value is the one at @t0@.
minimizeNLL :: Distribution -> Maybe Double -> Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizeNLL dist msErr niter xss ys tree t0
  -- A negative budget would wrap around when converted to 'Word' below,
  -- so it is treated like "no iterations".
  | niter <= 0 || nParams == 0 = (t0, objectiveAt t0)
  | otherwise                  = (thetaOpt, objectiveAt thetaOpt)
  where
    (Sz nParams) = size t0

    -- Objective only (gradient discarded), always reported for the tree the
    -- caller passed in so that the value matches the returned parameters.
    objectiveAt = fst . gradNLL dist msErr xss ys tree

    -- The optimizer works on the relabelled tree and on storable vectors.
    relabelled  = relabelParams tree
    theta0      = toStorableVector t0
    funAndGrad  = second (toStorableVector . computeAs S)
                . gradNLL dist msErr xss ys relabelled
                . fromStorableVector compMode

    algorithm   = LBFGS funAndGrad Nothing
    stop        = ObjectiveRelativeTolerance 1e-10 :| [MaximumEvaluations (fromIntegral niter)]
    problem     = LocalProblem (fromIntegral nParams) stop algorithm

    -- On failure ('Left') fall back to the starting point.
    thetaOpt    = fromStorableVector compMode
                $ either (const theta0) solutionParams (minimizeLocal problem theta0)

-- | Minimizes the negative log-likelihood with L-BFGS for expressions in
-- which the same parameter may occur more than once.
--
-- It differs from the unique-parameter variant in that the objective and
-- gradient come from 'gradNLLNonUnique', the tree is used as given (no
-- relabelling), and the relative objective tolerance is @1e-5@.
--
-- Arguments, in order: distribution, optional error estimate, maximum
-- number of objective evaluations, data set (@xss@, @ys@), expression and
-- initial parameter vector.
--
-- Returns the parameter vector together with the negative log-likelihood
-- evaluated at /that same/ vector. When @niter <= 0@, when @t0@ is empty,
-- or when the optimizer reports a failure, the vector is @t0@ and the value
-- is the one at @t0@.
minimizeNLLNonUnique :: Distribution -> Maybe Double -> Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizeNLLNonUnique dist msErr niter xss ys tree t0
  -- A negative budget would wrap around when converted to 'Word' below,
  -- so it is treated like "no iterations".
  | niter <= 0 || nParams == 0 = (t0, objectiveAt t0)
  | otherwise                  = (thetaOpt, objectiveAt thetaOpt)
  where
    (Sz nParams) = size t0

    -- Objective and gradient on massiv vectors; 'objectiveAt' drops the
    -- gradient so the reported value always matches the returned vector.
    objAndGrad  = gradNLLNonUnique dist msErr xss ys tree
    objectiveAt = fst . objAndGrad

    -- The optimizer works on storable vectors.
    theta0      = toStorableVector t0
    funAndGrad  = second (toStorableVector . computeAs S)
                . objAndGrad
                . fromStorableVector compMode

    algorithm   = LBFGS funAndGrad Nothing
    stop        = ObjectiveRelativeTolerance 1e-5 :| [MaximumEvaluations (fromIntegral niter)]
    problem     = LocalProblem (fromIntegral nParams) stop algorithm

    -- On failure ('Left') fall back to the starting point.
    thetaOpt    = fromStorableVector compMode
                $ either (const theta0) solutionParams (minimizeLocal problem theta0)

-- | Minimizes the negative log-likelihood while keeping the parameter at
-- index @ix@ fixed (used to calculate the profile).
--
-- The parameter is held in place by overwriting component @ix@ of every
-- gradient handed to L-BFGS with zero. @ix@ must therefore be a valid index
-- into @t0@, since it is used in a 'VS.//' update of the gradient vector.
--
-- Arguments, in order: distribution, optional error estimate, maximum
-- number of objective evaluations, data set (@xss@, @ys@), expression, index
-- of the fixed parameter and initial parameter vector.
--
-- @t0@ is returned untouched when @niter <= 0@, when @t0@ is empty, when
-- there are more parameters than entries in @ys@, or when the optimizer
-- reports a failure.
minimizeNLLWithFixedParam :: Distribution -> Maybe Double -> Int -> SRMatrix -> PVector -> Fix SRTree -> Int -> PVector -> PVector
minimizeNLLWithFixedParam dist msErr niter xss ys tree ix t0
  -- A negative budget would wrap around when converted to 'Word' below,
  -- so it is treated like "no iterations".
  | niter <= 0 || nParams == 0 || nParams > nObs = t0
  | otherwise                                   = fromStorableVector compMode thetaOpt
  where
    (Sz nParams) = size t0
    (Sz nObs)    = size ys
    theta0       = toStorableVector t0

    -- Zero the gradient along the fixed coordinate.
    freezeIx     = (VS.// [(ix, 0.0)])
    funAndGrad   = second (freezeIx . toStorableVector . computeAs S)
                 . gradNLLNonUnique dist msErr xss ys tree
                 . fromStorableVector compMode

    algorithm    = LBFGS funAndGrad Nothing
    stop         = ObjectiveRelativeTolerance 1e-5 :| [MaximumEvaluations (fromIntegral niter)]
    problem      = LocalProblem (fromIntegral nParams) stop algorithm

    -- On failure ('Left') fall back to the starting point.
    thetaOpt     = either (const theta0) solutionParams (minimizeLocal problem theta0)

-- | Fits the parameters under the 'Gaussian' distribution.
--
-- Shorthand for 'minimizeNLL' with the distribution set to 'Gaussian' and
-- no error estimate supplied ('Nothing'); the remaining arguments (maximum
-- number of evaluations, data set, expression, initial parameters) are
-- passed through unchanged.
minimizeGaussian :: Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizeGaussian niter xss ys tree t0 =
  minimizeNLL Gaussian Nothing niter xss ys tree t0

-- | Fits the parameters under the 'Bernoulli' distribution.
--
-- Despite the name, the distribution handed to 'minimizeNLL' is
-- 'Bernoulli'. No error estimate is supplied ('Nothing'); the remaining
-- arguments (maximum number of evaluations, data set, expression, initial
-- parameters) are passed through unchanged.
minimizeBinomial :: Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizeBinomial niter xss ys tree t0 =
  minimizeNLL Bernoulli Nothing niter xss ys tree t0

-- | Fits the parameters under the 'Poisson' distribution.
--
-- Shorthand for 'minimizeNLL' with the distribution set to 'Poisson' and
-- no error estimate supplied ('Nothing'); the remaining arguments (maximum
-- number of evaluations, data set, expression, initial parameters) are
-- passed through unchanged.
minimizePoisson :: Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizePoisson niter xss ys tree t0 =
  minimizeNLL Poisson Nothing niter xss ys tree t0

-- | Estimates the standard error when it was not provided.
--
-- * 'Gaussian' with no error given: the parameters are refitted with
--   'minimizeNLL' (using a unit error, at most @nIter@ evaluations,
--   starting from @theta0@) and the estimate is
--   @sqrt (ssr / (m - p))@, where @ssr@ is the value of 'sse' at the
--   refitted parameters, @m@ is the size of @ys@ and @p@ the number of
--   parameters. When @m - p@ is not positive the formula has no finite
--   value, so 'Nothing' is returned instead of a NaN or infinity.
-- * An error that was supplied (@Just s@) is returned unchanged, whatever
--   the distribution.
-- * Any other combination yields 'Nothing'.
estimateSErr :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Int -> Maybe Double
estimateSErr Gaussian Nothing xss ys theta0 t nIter
  | dof > 0   = Just (sqrt (ssr / fromIntegral dof))
  | otherwise = Nothing
  where
    -- Only the fitted parameters are needed from the optimizer's result.
    theta  = fst (minimizeNLL Gaussian (Just 1) nIter xss ys t theta0)
    (Sz m) = size ys
    (Sz p) = size theta
    -- Degrees of freedom left after fitting p parameters to m observations.
    dof    = m - p
    ssr    = sse xss ys t theta
estimateSErr _ (Just s) _ _ _ _ _ = Just s
estimateSErr _ _        _ _ _ _ _ = Nothing