-- |
-- Module      : Numeric.LBFGS.Internal
-- Copyright   : (c) 2010 Daniël de Kok, 2016 Ian-Woo.Kim
-- License     : Apache 2
--
-- Maintainer  : Daniël de Kok
-- Stability   : experimental
--
-- Internal glue between the high-level option types of "Numeric.LBFGS.Types"
-- and the raw FFI records of "Numeric.LBFGS.Raw". Includes helpers that fold
-- user options into a 'CLBFGSParameter' and that map a raw 'CLBFGSResult'
-- status code back to an 'LBFGSResult'.
module Numeric.LBFGS.Internal where

import Foreign.C.Types (CDouble, CInt)

import qualified Numeric.LBFGS.Raw as R
import Numeric.LBFGS.Raw
  ( CEvaluateFun
  , CProgressFun
  , CLBFGSParameter(..)
  , defaultCParam
  , CLBFGSResult(..)
  , c_lbfgs_malloc
  , c_lbfgs_free
  , c_lbfgs_evaluate_t_wrap
  , c_lbfgs_progress_t_wrap
  , c_lbfgs
  )
import Numeric.LBFGS.Types

-- | Write the chosen line-search algorithm into the raw parameter record.
--
-- The two Wolfe-based backtracking variants carry a coefficient. It is
-- converted with 'realToFrac' and stored in the @wolfe@ field. All other
-- variants leave @wolfe@ untouched.
mergeLineSearchAlgorithm :: LineSearchAlgorithm -> CLBFGSParameter -> CLBFGSParameter
mergeLineSearchAlgorithm DefaultLineSearch p =
  p { R.linesearch = R.defaultLineSearch }
mergeLineSearchAlgorithm MoreThuente p =
  p { R.linesearch = R.moreThuente }
mergeLineSearchAlgorithm BacktrackingArmijo p =
  p { R.linesearch = R.backtrackingArmijo }
mergeLineSearchAlgorithm Backtracking p =
  p { R.linesearch = R.backtracking }
mergeLineSearchAlgorithm (BacktrackingWolfe c) p =
  p { R.linesearch = R.backtrackingWolfe, R.wolfe = realToFrac c }
mergeLineSearchAlgorithm (BacktrackingStrongWolfe c) p =
  p { R.linesearch = R.backtrackingStrongWolfe, R.wolfe = realToFrac c }

-- | Enable L1 regularisation (OWL-QN) when a coefficient is supplied.
--
-- * 'Nothing': the record is returned unchanged.
-- * @'Just' l1@: sets @orthantwise_c = l1@ and applies the penalty to the
--   variable index range @[0, n)@. It also forces the plain backtracking
--   line search, overriding any earlier 'mergeLineSearchAlgorithm' choice.
--   This is presumably because liblbfgs' OWL-QN mode accepts only that line
--   search; see the liblbfgs documentation.
--
-- The second argument @n@ is assumed to be the number of variables being
-- optimised. NOTE(review): confirm against the caller.
mergeL1NormCoefficient :: L1NormCoefficient -> CInt -> CLBFGSParameter -> CLBFGSParameter
mergeL1NormCoefficient Nothing _ p = p
mergeL1NormCoefficient (Just l1) n p =
  p { R.linesearch = R.backtracking
    , R.orthantwise_c = realToFrac l1
    , R.orthantwise_start = 0
      -- liblbfgs treats orthantwise_end as an exclusive bound (valid range
      -- 0 < end <= N, loops run i < end). Passing n - 1 would silently
      -- exclude the last variable from the L1 penalty.
    , R.orthantwise_end = n
    }

-- | Configure liblbfgs' delta-based convergence test.
--
-- * 'Nothing': sets @past = 0@, which liblbfgs documents as disabling the
--   test. The @delta@ argument is ignored and the record's @delta@ field is
--   left as it was.
-- * @'Just' k@: sets @past = k@ (the distance, in iterations, used to
--   measure the objective's rate of decrease) and @delta@ to the given
--   threshold.
mergePast :: Maybe Int -> Double -> CLBFGSParameter -> CLBFGSParameter
mergePast Nothing _ p = p { R.past = 0 }
mergePast (Just iterations) delta p =
  p { R.past = fromIntegral iterations, R.delta = realToFrac delta }

-- | Translate a raw liblbfgs status code into the corresponding
-- 'LBFGSResult' constructor.
--
-- NOTE(review): no fallback guard is visible. A status code not listed
-- below would fall through every guard and fail at runtime. Confirm that
-- this list covers every code the C library can return.
deriveResult :: CLBFGSResult -> LBFGSResult
deriveResult r
  | r == R.lbfgsSuccess = Success
  | r == R.lbfgsStop = Stop
  | r == R.lbfgsAlreadyMinimized = AlreadyMinimized
  | r == R.lbfgserrUnknownerror = UnknownError
  | r == R.lbfgserrLogicerror = LogicError
  | r == R.lbfgserrOutofmemory = OutOfMemory
  | r == R.lbfgserrCanceled = Canceled
  | r == R.lbfgserrInvalidN = InvalidN
  | r == R.lbfgserrInvalidNSse = InvalidNSSE
  | r == R.lbfgserrInvalidXSse = InvalidXSSE
  | r == R.lbfgserrInvalidEpsilon = InvalidEpsilon
  | r == R.lbfgserrInvalidTestperiod = InvalidTestPeriod
  | r == R.lbfgserrInvalidDelta = InvalidDelta
  | r == R.lbfgserrInvalidLinesearch = InvalidLineSearch
  | r == R.lbfgserrInvalidMinstep = InvalidMinStep
  | r == R.lbfgserrInvalidMaxstep = InvalidMaxStep
  | r == R.lbfgserrInvalidFtol = InvalidFtol
  | r == R.lbfgserrInvalidWolfe = InvalidWolfe
  | r == R.lbfgserrInvalidGtol = InvalidGtol
  | r == R.lbfgserrInvalidXtol = InvalidXtol
  | r == R.lbfgserrInvalidMaxlinesearch = InvalidMaxLineSearch
  | r == R.lbfgserrInvalidOrthantwise = InvalidOrthantwise
  | r == R.lbfgserrInvalidOrthantwiseStart = InvalidOrthantwiseStart
  | r == R.lbfgserrInvalidOrthantwiseEnd = InvalidOrthantwiseEnd
  | r == R.lbfgserrOutofinterval = OutOfInterval
  | r == R.lbfgserrIncorrectTminmax = IncorrectTMinMax
  | r == R.lbfgserrRoundingError = RoundingError
  | r == R.lbfgserrMinimumstep = MinimumStep
  | r == R.lbfgserrMaximumstep = MaximumStep
  | r == R.lbfgserrMaximumlinesearch = MaximumLineSearch
  | r == R.lbfgserrMaximumiteration = MaximumIteration
  | r == R.lbfgserrWidthtoosmall = WidthTooSmall
  | r == R.lbfgserrInvalidparameters = InvalidParameters
  | r == R.lbfgserrIncreasegradient = IncreaseGradient