module Algorithm.SRTree.Opt
where
import Algorithm.SRTree.Likelihoods
import Algorithm.SRTree.NonlinearOpt
import Data.Bifunctor (bimap, second)
import Data.Massiv.Array
import Data.SRTree (Fix (..), SRTree (..), floatConstsToParam, relabelParams)
import Data.SRTree.Eval (evalTree, compMode)
import qualified Data.Vector.Storable as VS
-- | Fits the parameters of an expression by minimizing its negative
-- log-likelihood (NLL) with L-BFGS via 'minimizeLocal'.
--
-- Arguments, in order: the likelihood family @dist@; an optional noise
-- scale @msErr@ (forwarded unchanged to 'gradNLL'); the evaluation budget
-- @niter@ (used as 'MaximumEvaluations'); the data matrix @xss@; the
-- observed targets @ys@; the expression @tree@; and the starting
-- parameters @t0@.
--
-- Returns the fitted parameters paired with the NLL evaluated at those
-- same parameters. The starting point @t0@ is returned (with its own NLL)
-- when @niter == 0@, when there are no parameters, or when 'minimizeLocal'
-- returns 'Left'.
minimizeNLL :: Distribution -> Maybe Double -> Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizeNLL dist msErr niter xss ys tree t0
  | niter == 0 = (t0, nllAt t0)
  | n == 0     = (t0, nllAt t0)
  | otherwise  = (tOpt, nllAt tOpt)
  where
    (Sz n) = size t0

    -- The optimizer runs on a relabeled copy of the tree.
    tree' = relabelParams tree

    -- Objective reported to the caller, evaluated on the original @tree@
    -- exactly as the starting-point NLL is, so both branches agree.
    nllAt :: PVector -> Double
    nllAt = fst . gradNLL dist msErr xss ys tree

    -- NLopt works on storable vectors: convert at the boundary and
    -- materialize the delayed gradient array with 'computeAs'.
    funAndGrad :: VS.Vector Double -> (Double, VS.Vector Double)
    funAndGrad = second (toStorableVector . computeAs S)
               . gradNLL dist msErr xss ys tree'
               . fromStorableVector compMode

    stop    = ObjectiveRelativeTolerance 1e-10 :| [MaximumEvaluations (fromIntegral niter)]
    problem = LocalProblem (fromIntegral n) stop (LBFGS funAndGrad Nothing)

    -- On optimizer failure fall back to the starting point.
    t0'  = toStorableVector t0
    tOpt = fromStorableVector compMode $ case minimizeLocal problem t0' of
             Right sol -> solutionParams sol
             Left _    -> t0'
-- | Variant of 'minimizeNLL' that uses 'gradNLLNonUnique' as the
-- objective, optimizes @tree@ as given (no 'relabelParams'), and stops
-- at a looser relative objective tolerance (@1e-5@).
--
-- Arguments, in order: @dist@, @msErr@ (forwarded unchanged to
-- 'gradNLLNonUnique'), the evaluation budget @niter@, the data matrix
-- @xss@, the targets @ys@, the expression @tree@ and the starting
-- parameters @t0@.
--
-- Returns the fitted parameters paired with the NLL at those parameters.
-- @t0@ is returned (with its own NLL) when @niter == 0@, when there are
-- no parameters, or when 'minimizeLocal' returns 'Left'.
minimizeNLLNonUnique :: Distribution -> Maybe Double -> Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizeNLLNonUnique dist msErr niter xss ys tree t0
  | niter == 0 = (t0, nllAt t0)
  | n == 0     = (t0, nllAt t0)
  | otherwise  = (tOpt, nllAt tOpt)
  where
    (Sz n) = size t0

    -- Objective value reported to the caller.
    nllAt :: PVector -> Double
    nllAt = fst . gradNLLNonUnique dist msErr xss ys tree

    -- Same objective, adapted to NLopt's storable-vector interface.
    funAndGrad :: VS.Vector Double -> (Double, VS.Vector Double)
    funAndGrad = second (toStorableVector . computeAs S)
               . gradNLLNonUnique dist msErr xss ys tree
               . fromStorableVector compMode

    stop    = ObjectiveRelativeTolerance 1e-5 :| [MaximumEvaluations (fromIntegral niter)]
    problem = LocalProblem (fromIntegral n) stop (LBFGS funAndGrad Nothing)

    -- On optimizer failure fall back to the starting point.
    t0'  = toStorableVector t0
    tOpt = fromStorableVector compMode $ case minimizeLocal problem t0' of
             Right sol -> solutionParams sol
             Left _    -> t0'
-- | Like 'minimizeNLLNonUnique', but keeps the parameter at index @ix@
-- pinned at its starting value and returns only the parameter vector.
--
-- The parameter is held fixed by zeroing its gradient component before it
-- is handed to L-BFGS, so the optimizer never moves along that axis.
--
-- @t0@ is returned unchanged when @niter == 0@, when there are no
-- parameters, when there are more parameters than observations
-- (@n > m@), or when 'minimizeLocal' returns 'Left'.
--
-- NOTE(review): an @ix@ outside @[0, n)@ makes 'VS.//' fail on the first
-- gradient evaluation; callers must pass a valid index.
minimizeNLLWithFixedParam :: Distribution -> Maybe Double -> Int -> SRMatrix -> PVector -> Fix SRTree -> Int -> PVector -> PVector
minimizeNLLWithFixedParam dist msErr niter xss ys tree ix t0
  | niter == 0 = t0
  | n == 0     = t0
  | n > m      = t0
  | otherwise  = fromStorableVector compMode tOpt
  where
    (Sz n) = size t0
    (Sz m) = size ys

    -- Zero the gradient entry of the pinned parameter.
    setTo0 :: VS.Vector Double -> VS.Vector Double
    setTo0 = (VS.// [(ix, 0.0)])

    funAndGrad :: VS.Vector Double -> (Double, VS.Vector Double)
    funAndGrad = second (setTo0 . toStorableVector . computeAs S)
               . gradNLLNonUnique dist msErr xss ys tree
               . fromStorableVector compMode

    stop    = ObjectiveRelativeTolerance 1e-5 :| [MaximumEvaluations (fromIntegral niter)]
    problem = LocalProblem (fromIntegral n) stop (LBFGS funAndGrad Nothing)

    -- On optimizer failure fall back to the starting point.
    t0'  = toStorableVector t0
    tOpt = case minimizeLocal problem t0' of
             Right sol -> solutionParams sol
             Left _    -> t0'
-- | Fits @tree@'s parameters with 'minimizeNLL' under the 'Gaussian'
-- family, passing 'Nothing' as the optional noise scale.
--
-- Takes the evaluation budget, data matrix, targets, expression and
-- starting parameters, and returns 'minimizeNLL''s (parameters, NLL) pair.
minimizeGaussian :: Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizeGaussian niter xss ys tree t0 =
  minimizeNLL Gaussian Nothing niter xss ys tree t0
-- | Fits @tree@'s parameters with 'minimizeNLL' under the 'Bernoulli'
-- family (despite the function's name), passing 'Nothing' as the optional
-- noise scale.
--
-- Takes the evaluation budget, data matrix, targets, expression and
-- starting parameters, and returns 'minimizeNLL''s (parameters, NLL) pair.
minimizeBinomial :: Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizeBinomial niter xss ys tree t0 =
  minimizeNLL Bernoulli Nothing niter xss ys tree t0
-- | Fits @tree@'s parameters with 'minimizeNLL' under the 'Poisson'
-- family, passing 'Nothing' as the optional noise scale.
--
-- Takes the evaluation budget, data matrix, targets, expression and
-- starting parameters, and returns 'minimizeNLL''s (parameters, NLL) pair.
minimizePoisson :: Int -> SRMatrix -> PVector -> Fix SRTree -> PVector -> (PVector, Double)
minimizePoisson niter xss ys tree t0 =
  minimizeNLL Poisson Nothing niter xss ys tree t0
-- | Determines the noise scale to use for a likelihood.
--
-- * 'Gaussian' with no supplied scale: refits @t@ from @theta0@ with
--   'minimizeNLL' (scale pinned to @1@, @nIter@ evaluations), then returns
--   @sqrt (SSE / (m - p))@, where @m@ is the number of targets and @p@ the
--   number of fitted parameters.
-- * Any family with a supplied scale: returns that scale unchanged.
-- * Otherwise: 'Nothing'.
--
-- NOTE(review): when @m <= p@ the denominator is zero or negative, so the
-- Gaussian estimate is @Infinity@ or @NaN@ rather than 'Nothing'.
estimateSErr :: Distribution -> Maybe Double -> SRMatrix -> PVector -> PVector -> Fix SRTree -> Int -> Maybe Double
estimateSErr Gaussian Nothing xss ys theta0 t nIter = Just (sqrt (ssr / dof))
  where
    -- Refit with the noise scale fixed at 1; only the parameters are kept.
    (thetaFit, _) = minimizeNLL Gaussian (Just 1) nIter xss ys t theta0

    Sz nObs   = size ys
    Sz nParam = size thetaFit

    ssr, dof :: Double
    ssr = sse xss ys t thetaFit
    dof = fromIntegral (nObs - nParam)
estimateSErr _ (Just s) _ _ _ _ _ = Just s
estimateSErr _ _        _ _ _ _ _ = Nothing