{-# LANGUAGE BangPatterns #-}
module Algorithm.SRTree.Opt
where
import Algorithm.SRTree.Likelihoods
import Algorithm.SRTree.NonlinearOpt
import Data.Bifunctor (bimap, second)
import Data.Massiv.Array
import Data.SRTree (Fix (..), SRTree (..), floatConstsToParam, relabelParams, countNodes, convertProtectedOps)
import Data.SRTree.Eval (evalTree, compMode)
import qualified Data.Vector.Storable as VS
import qualified Data.IntMap.Strict as IntMap
import Data.SRTree.Recursion
import Debug.Trace
-- | Fit the numerical parameters of an expression tree by minimizing its
-- negative log-likelihood (NLL) with a local optimizer.
--
-- Arguments, in order:
--
--   * @alg@   – builds the local algorithm from the objective/gradient
--               function and an optional 'VectorStorage' hint.
--   * @dist@  – likelihood used to build the NLL.
--   * @mYerr@ – optional error vector, forwarded unchanged to the
--               likelihood functions.
--   * @niter@ – budget passed as 'MaximumEvaluations'; @0@ skips optimization.
--   * @xss@   – input data matrix.
--   * @ys@    – target vector; its length is passed to 'buildNLL'.
--   * @tree@  – expression whose parameters are fitted.
--   * @t0@    – initial parameter vector.
--
-- Returns @(params, nll, evaluations)@.
--
--   * When @niter == 0@ or @t0@ is empty, it returns @t0@ and @0@ evaluations.
--     In that case the NLL comes from 'gradNLLGraph' on the built NLL tree.
--   * Otherwise the NLL is recomputed with 'nll' on the original @tree@.
--   * If the optimizer reports an error, the initial parameters are kept and
--     @0@ evaluations are reported.
--
-- NOTE(review): a negative @niter@ is converted to 'Word' with 'fromIntegral'
-- and wraps to a very large evaluation budget. Confirm callers never pass one.
minimizeNLL' :: (ObjectiveD -> Maybe VectorStorage -> LocalAlgorithm) -> Distribution -> Maybe PVector -> Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double, Int)
minimizeNLL' alg dist mYerr niter xss ys tree t0
  | niter == 0 = (t0, f, 0)
  | n == 0     = (t0, f, 0)
  | otherwise  = (t_opt', nll dist mYerr xss ys tree t_opt', nEvs)
  where
    -- The optimizer works on the NLL expression built from a copy of the tree
    -- whose parameters have been renumbered by 'relabelParams'.
    tree' = buildNLL dist (fromIntegral m) $ relabelParams tree
    t0'   = toStorableVector t0
    (Sz n) = size t0
    (Sz m) = size ys

    -- Objective value and gradient, evaluated on the NLL expression.
    funAndGrad = gradNLLGraph dist xss ys mYerr tree'
    -- NLL at the starting point; only forced by the early-exit guards.
    (f, _) = funAndGrad t0'

    algorithm = alg funAndGrad (Just $ VectorStorage $ fromIntegral n)
    stop = ObjectiveRelativeTolerance 1e-6
             :| [ ObjectiveAbsoluteTolerance 1e-6
                , MaximumEvaluations (fromIntegral niter)
                ]
    problem = LocalProblem (fromIntegral n) stop algorithm

    (t_opt, nEvs) = case minimizeLocal problem t0' of
      Right sol -> (solutionParams sol, nEvals sol)
      Left _    -> (t0', 0) -- optimizer failure: fall back to the initial guess
    t_opt' = fromStorableVector compMode t_opt
-- | 'minimizeNLL'' run with the 'TNEWTON' local algorithm.
--
-- The arguments are the likelihood, an optional error vector, the evaluation
-- budget, the input matrix, the targets, the expression and the initial
-- parameters. The result is the @(params, nll, evaluations)@ triple produced
-- by 'minimizeNLL''.
minimizeNLL
  :: Distribution
  -> Maybe PVector
  -> Int
  -> SRMatrix
  -> PVector
  -> Fix SRTree
  -> PVector
  -> (PVector, Double, Int)
minimizeNLL dist mYerr niter xss ys tree t0 =
  minimizeNLL' TNEWTON dist mYerr niter xss ys tree t0
-- | Like 'minimizeNLL'', but keeps the parameter at index @ix@ out of the fit.
--
-- The gradient returned to the optimizer always has its @ix@-th component
-- set to @0@, so a gradient-driven algorithm leaves that parameter at its
-- value in @t0@. Only the optimized parameter vector is returned.
--
--   * When @niter == 0@ or @t0@ is empty, @t0@ is returned unchanged.
--   * When the optimizer reports an error, the initial parameters are
--     returned.
--
-- NOTE(review): @ix@ must be a valid index into @t0@. 'VS.//' is
-- bounds-checked and presumably raises an error on the first gradient
-- evaluation otherwise — confirm callers guarantee this.
minimizeNLLWithFixedParam' :: (ObjectiveD -> Maybe VectorStorage -> LocalAlgorithm) -> Distribution -> Maybe PVector -> Int -> SRMatrix -> PVector -> Fix SRTree -> Int -> PVector -> PVector
minimizeNLLWithFixedParam' alg dist mYerr niter xss ys tree ix t0
  | niter == 0 = t0
  | n == 0     = t0
  | otherwise  = t_opt'
  where
    -- NLL expression built from a copy of the tree with renumbered parameters.
    tree' = buildNLL dist (fromIntegral m) $ relabelParams tree
    t0'   = toStorableVector t0
    (Sz n) = size t0
    (Sz m) = size ys

    -- Array form of the NLL expression, plus a map from each array key to its
    -- position, both consumed by 'gradNLLArr'.
    treeArr = IntMap.toAscList $ tree2arr tree'
    j2ix    = IntMap.fromList $ Prelude.zip (Prelude.map fst treeArr) [0..]

    -- Zero the gradient entry of the fixed parameter.
    setTo0 = (VS.// [(ix, 0.0)])

    funAndGrad = second (setTo0 . toStorableVector . computeAs S)
               . gradNLLArr dist xss ys mYerr treeArr j2ix

    algorithm = alg funAndGrad Nothing
    stop = ObjectiveRelativeTolerance 1e-8
             :| [ ObjectiveAbsoluteTolerance 1e-8
                , MaximumEvaluations (fromIntegral niter)
                ]
    problem = LocalProblem (fromIntegral n) stop algorithm

    t_opt = case minimizeLocal problem t0' of
      Right sol -> solutionParams sol
      Left _    -> t0' -- optimizer failure: fall back to the initial guess
    t_opt' = fromStorableVector compMode t_opt
-- | 'minimizeNLLWithFixedParam'' run with the 'TNEWTON' local algorithm.
--
-- Fits every parameter except the one at the given index. That parameter's
-- gradient component is zeroed before it reaches the optimizer. Returns the
-- optimized parameter vector.
minimizeNLLWithFixedParam
  :: Distribution
  -> Maybe PVector
  -> Int
  -> SRMatrix
  -> PVector
  -> Fix SRTree
  -> Int
  -> PVector
  -> PVector
minimizeNLLWithFixedParam = minimizeNLLWithFixedParam' TNEWTON
-- | Fit expression parameters under the 'Gaussian' likelihood, with no error
-- vector. Delegates to 'minimizeNLL' and returns its
-- @(params, nll, evaluations)@ triple.
minimizeGaussian
  :: Int        -- ^ evaluation budget (@0@ skips optimization)
  -> SRMatrix   -- ^ input data
  -> PVector    -- ^ targets
  -> Fix SRTree -- ^ expression to fit
  -> PVector    -- ^ initial parameters
  -> (PVector, Double, Int)
minimizeGaussian niter xss ys tree t0 =
  minimizeNLL Gaussian Nothing niter xss ys tree t0
-- | Fit expression parameters under the 'Bernoulli' likelihood, with no error
-- vector. Note that the distribution passed is 'Bernoulli', despite the
-- function's name. Delegates to 'minimizeNLL' and returns its
-- @(params, nll, evaluations)@ triple.
minimizeBinomial
  :: Int        -- ^ evaluation budget (@0@ skips optimization)
  -> SRMatrix   -- ^ input data
  -> PVector    -- ^ targets
  -> Fix SRTree -- ^ expression to fit
  -> PVector    -- ^ initial parameters
  -> (PVector, Double, Int)
minimizeBinomial niter xss ys tree t0 =
  minimizeNLL Bernoulli Nothing niter xss ys tree t0
-- | Fit expression parameters under the 'Poisson' likelihood, with no error
-- vector. Delegates to 'minimizeNLL' and returns its
-- @(params, nll, evaluations)@ triple.
minimizePoisson
  :: Int        -- ^ evaluation budget (@0@ skips optimization)
  -> SRMatrix   -- ^ input data
  -> PVector    -- ^ targets
  -> Fix SRTree -- ^ expression to fit
  -> PVector    -- ^ initial parameters
  -> (PVector, Double, Int)
minimizePoisson niter xss ys tree t0 =
  minimizeNLL Poisson Nothing niter xss ys tree t0