{-# LANGUAGE BangPatterns #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.SRTree.Opt 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  ConstraintKinds
--
-- Functions to optimize the parameters of an expression.
--
-----------------------------------------------------------------------------
module Algorithm.SRTree.Opt
    where

import Algorithm.SRTree.Likelihoods
import Algorithm.SRTree.NonlinearOpt
import Data.Bifunctor (bimap, second)
import Data.Massiv.Array
import Data.SRTree (Fix (..), SRTree (..), floatConstsToParam, relabelParams, countNodes, convertProtectedOps)
import Data.SRTree.Eval (evalTree, compMode)
import qualified Data.Vector.Storable as VS
import qualified Data.IntMap.Strict as IntMap
import Data.SRTree.Recursion

import Debug.Trace



-- | Minimizes the negative log-likelihood (NLL) of an expression with respect
-- to its parameters, using the NLopt local algorithm built by @alg@.
--
-- The objective and its gradient come from 'gradNLLGraph' applied to the NLL
-- tree produced by 'buildNLL'. The search stops when the relative or absolute
-- change of the objective drops below @1e-6@, or after @niter@ objective
-- evaluations.
--
-- Returns @(parameters, nll, evaluations)@:
--
-- * when @niter == 0@ or the expression has no parameters, @t0@ is returned
--   with the NLL of @t0@ and @0@ evaluations;
-- * when the optimizer reports a failure, @t0@ is returned with
--   'nll' evaluated at @t0@ and @0@ evaluations;
-- * otherwise, the optimized parameters, 'nll' evaluated at them, and the
--   number of evaluations reported by the optimizer.
minimizeNLL'
  :: (ObjectiveD -> Maybe VectorStorage -> LocalAlgorithm)
  -- ^ builds the local algorithm from the objective/gradient function and an
  -- optional vector-storage size
  -> Distribution   -- ^ distribution of the likelihood
  -> Maybe PVector  -- ^ optional error vector (presumably per-observation errors of the targets — confirm)
  -> Int            -- ^ maximum number of objective evaluations (@0@ disables optimization)
  -> SRMatrix       -- ^ input data
  -> PVector        -- ^ target values; their count is passed to 'buildNLL'
  -> Fix SRTree     -- ^ expression whose parameters are optimized
  -> PVector        -- ^ initial parameter values
  -> (PVector, Double, Int)
minimizeNLL' alg dist mYerr niter xss ys tree t0
  | niter == 0 || n == 0 = (t0, f0, 0)
  | otherwise            = (tOpt, nll dist mYerr xss ys tree tOpt, nEvs)
  where
    -- NLL expression of the relabeled tree for m observations.
    -- NOTE(review): convertProtectedOps is imported but not applied here.
    tree' :: Fix SRTree
    tree' = buildNLL dist (fromIntegral m) $ relabelParams tree

    t0' :: VS.Vector Double
    t0' = toStorableVector t0

    (Sz n) = size t0
    (Sz m) = size ys

    funAndGrad :: ObjectiveD
    funAndGrad = gradNLLGraph dist xss ys mYerr tree'

    -- NLL at the starting point; only reported when no optimization runs
    (f0, _) = funAndGrad t0'

    -- the number of parameters is used as the NLopt vector-storage size
    algorithm :: LocalAlgorithm
    algorithm = alg funAndGrad (Just . VectorStorage $ fromIntegral n)

    -- NOTE(review): a negative niter wraps around when converted to Word by
    -- fromIntegral; confirm callers never pass one.
    stop = ObjectiveRelativeTolerance 1e-6
        :| [ObjectiveAbsoluteTolerance 1e-6, MaximumEvaluations (fromIntegral niter)]

    problem = LocalProblem (fromIntegral n) stop algorithm

    -- fall back to the initial point if NLopt reports a failure
    (tOptS, nEvs) = case minimizeLocal problem t0' of
      Right sol -> (solutionParams sol, nEvals sol)
      Left _    -> (t0', 0)

    tOpt :: PVector
    tOpt = fromStorableVector compMode tOptS

-- | 'minimizeNLL'' specialized to NLopt's truncated-Newton method ('TNEWTON').
--
-- Arguments: distribution, optional error vector, maximum number of objective
-- evaluations, input data, targets, expression and initial parameters.
-- Returns the optimized parameters, their NLL and the number of evaluations.
minimizeNLL :: Distribution -> Maybe PVector -> Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double, Int)
minimizeNLL = minimizeNLL' TNEWTON

-- | Minimizes the negative log-likelihood of an expression while keeping the
-- parameter at index @ix@ fixed (used to calculate the profile likelihood).
--
-- The parameter is held fixed by zeroing its component of the gradient that
-- is handed to the optimizer. The objective and gradient come from
-- 'gradNLLArr' over the flattened NLL tree. The search stops when the
-- relative or absolute change of the objective drops below @1e-8@, or after
-- @niter@ objective evaluations.
--
-- @ix@ must be a valid index into @t0@: 'VS.//' fails on an out-of-range
-- index as soon as the first gradient is evaluated.
--
-- Returns the optimized parameters, or @t0@ unchanged when @niter == 0@, the
-- expression has no parameters, or the optimizer reports a failure.
minimizeNLLWithFixedParam'
  :: (ObjectiveD -> Maybe VectorStorage -> LocalAlgorithm)
  -- ^ builds the local algorithm from the objective/gradient function and an
  -- optional vector-storage size
  -> Distribution   -- ^ distribution of the likelihood
  -> Maybe PVector  -- ^ optional error vector (presumably per-observation errors of the targets — confirm)
  -> Int            -- ^ maximum number of objective evaluations (@0@ disables optimization)
  -> SRMatrix       -- ^ input data
  -> PVector        -- ^ target values; their count is passed to 'buildNLL'
  -> Fix SRTree     -- ^ expression whose parameters are optimized
  -> Int            -- ^ index of the parameter kept fixed
  -> PVector        -- ^ initial parameter values
  -> PVector
minimizeNLLWithFixedParam' alg dist mYerr niter xss ys tree ix t0
  | niter == 0 || n == 0 = t0
  | otherwise            = fromStorableVector compMode tOpt
  where
    tree' :: Fix SRTree
    tree' = buildNLL dist (fromIntegral m) $ relabelParams tree

    t0' :: VS.Vector Double
    t0' = toStorableVector t0

    -- flattened NLL tree (ascending keys) and each key's position in that
    -- list, as required by gradNLLArr
    treeArr = IntMap.toAscList $ tree2arr tree'
    j2ix    = IntMap.fromList $ Prelude.zip (Prelude.map fst treeArr) [0..]

    (Sz n) = size t0
    (Sz m) = size ys

    -- zero the gradient entry of the fixed parameter
    zeroFixed :: VS.Vector Double -> VS.Vector Double
    zeroFixed g = g VS.// [(ix, 0.0)]

    funAndGrad :: ObjectiveD
    funAndGrad = second (zeroFixed . toStorableVector . computeAs S)
               . gradNLLArr dist xss ys mYerr treeArr j2ix

    algorithm :: LocalAlgorithm
    algorithm = alg funAndGrad Nothing

    stop = ObjectiveRelativeTolerance 1e-8
        :| [ObjectiveAbsoluteTolerance 1e-8, MaximumEvaluations (fromIntegral niter)]

    problem = LocalProblem (fromIntegral n) stop algorithm

    -- fall back to the initial point if NLopt reports a failure
    tOpt = case minimizeLocal problem t0' of
      Right sol -> solutionParams sol
      Left _    -> t0'

-- | 'minimizeNLLWithFixedParam'' specialized to NLopt's truncated-Newton
-- method ('TNEWTON').
--
-- Arguments: distribution, optional error vector, maximum number of objective
-- evaluations, input data, targets, expression, index of the fixed parameter
-- and initial parameters. Returns the optimized parameters.
minimizeNLLWithFixedParam :: Distribution -> Maybe PVector -> Int -> SRMatrix -> PVector -> Fix SRTree -> Int -> PVector -> PVector
minimizeNLLWithFixedParam = minimizeNLLWithFixedParam' TNEWTON

-- | Minimizes the NLL under the 'Gaussian' distribution with no error vector.
--
-- Arguments: maximum number of objective evaluations, input data, targets,
-- expression and initial parameters (see 'minimizeNLL').
minimizeGaussian :: Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double, Int)
minimizeGaussian = minimizeNLL Gaussian Nothing

-- | Minimizes the NLL under the 'Bernoulli' distribution with no error vector.
--
-- The function is named after the binomial family, but the likelihood it
-- uses is 'Bernoulli'.
--
-- Arguments: maximum number of objective evaluations, input data, targets,
-- expression and initial parameters (see 'minimizeNLL').
minimizeBinomial :: Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double, Int)
minimizeBinomial = minimizeNLL Bernoulli Nothing

-- | Minimizes the NLL under the 'Poisson' distribution with no error vector.
--
-- Arguments: maximum number of objective evaluations, input data, targets,
-- expression and initial parameters (see 'minimizeNLL').
minimizePoisson :: Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double, Int)
minimizePoisson = minimizeNLL Poisson Nothing