module Numeric.LBFGS.Internal where
import Foreign.C.Types (CDouble, CInt)
import qualified Numeric.LBFGS.Raw as R
import Numeric.LBFGS.Raw (CEvaluateFun, CProgressFun, CLBFGSParameter(..),
defaultCParam, CLBFGSResult(..),
c_lbfgs_malloc, c_lbfgs_free,
c_lbfgs_evaluate_t_wrap, c_lbfgs_progress_t_wrap,
c_lbfgs
)
import Numeric.LBFGS.Types
-- | Record the selected line search algorithm in a raw parameter record.
--
-- Every constructor overwrites the @linesearch@ field with the matching raw
-- constant. The two Wolfe variants also carry a coefficient, which is
-- converted with 'realToFrac' and written to the @wolfe@ field. No other
-- field of the record is modified.
mergeLineSearchAlgorithm
  :: LineSearchAlgorithm
  -> CLBFGSParameter
  -> CLBFGSParameter
mergeLineSearchAlgorithm algorithm params =
  case algorithm of
    DefaultLineSearch ->
      params { R.linesearch = R.defaultLineSearch }
    MoreThuente ->
      params { R.linesearch = R.moreThuente }
    BacktrackingArmijo ->
      params { R.linesearch = R.backtrackingArmijo }
    Backtracking ->
      params { R.linesearch = R.backtracking }
    BacktrackingWolfe coeff ->
      params { R.linesearch = R.backtrackingWolfe
             , R.wolfe      = realToFrac coeff
             }
    BacktrackingStrongWolfe coeff ->
      params { R.linesearch = R.backtrackingStrongWolfe
             , R.wolfe      = realToFrac coeff
             }
-- | Apply an optional L1-norm coefficient to a raw parameter record.
--
-- * With 'Nothing' the record is returned unchanged.
--
-- * With @'Just' l1@ the line search is forced to @R.backtracking@,
--   @orthantwise_c@ is set to the coefficient (converted with 'realToFrac'),
--   @orthantwise_start@ is set to @0@ and @orthantwise_end@ to @n - 1@,
--   where @n@ is the second argument (presumably the number of variables;
--   verify against the caller).
mergeL1NormCoefficient :: L1NormCoefficient -> CInt -> CLBFGSParameter ->
                          CLBFGSParameter
mergeL1NormCoefficient Nothing   _ p = p
mergeL1NormCoefficient (Just l1) n p =
  p { R.linesearch        = R.backtracking
    , R.orthantwise_c     = realToFrac l1
    , R.orthantwise_start = 0
      -- NOTE(review): liblbfgs appears to treat orthantwise_end as an
      -- exclusive upper index; if so, @n@ rather than @n - 1@ would be
      -- needed to cover the last variable. TODO confirm against the
      -- liblbfgs documentation before changing.
    , R.orthantwise_end   = n - 1
    }
-- | Set the @past@ (and, when applicable, @delta@) fields of a raw
-- parameter record.
--
-- * 'Nothing' sets @past@ to @0@; the supplied delta is ignored and the
--   record's existing @delta@ field is kept.
--
-- * @'Just' iterations@ sets @past@ to that count (via 'fromIntegral') and
--   @delta@ to the supplied value (via 'realToFrac').
--
-- NOTE(review): presumably @past = 0@ switches off liblbfgs's delta-based
-- convergence test; confirm against the liblbfgs documentation.
mergePast :: Maybe Int -> Double -> CLBFGSParameter -> CLBFGSParameter
mergePast window threshold params =
  case window of
    Nothing ->
      params { R.past = 0 }
    Just count ->
      params { R.past  = fromIntegral count
             , R.delta = realToFrac threshold
             }
-- | Translate a raw result code into the corresponding 'LBFGSResult'
-- constructor.
--
-- The code is compared, in order, against each raw result constant exported
-- by "Numeric.LBFGS.Raw"; the first equal constant decides the constructor.
-- No fallback guard is present in the visible part of this definition, so a
-- code equal to none of the listed constants falls through every guard.
deriveResult :: CLBFGSResult -> LBFGSResult
deriveResult code
  | code == R.lbfgsSuccess                    = Success
  | code == R.lbfgsStop                       = Stop
  | code == R.lbfgsAlreadyMinimized           = AlreadyMinimized
  | code == R.lbfgserrUnknownerror            = UnknownError
  | code == R.lbfgserrLogicerror              = LogicError
  | code == R.lbfgserrOutofmemory             = OutOfMemory
  | code == R.lbfgserrCanceled                = Canceled
  | code == R.lbfgserrInvalidN                = InvalidN
  | code == R.lbfgserrInvalidNSse             = InvalidNSSE
  | code == R.lbfgserrInvalidXSse             = InvalidXSSE
  | code == R.lbfgserrInvalidEpsilon          = InvalidEpsilon
  | code == R.lbfgserrInvalidTestperiod       = InvalidTestPeriod
  | code == R.lbfgserrInvalidDelta            = InvalidDelta
  | code == R.lbfgserrInvalidLinesearch       = InvalidLineSearch
  | code == R.lbfgserrInvalidMinstep          = InvalidMinStep
  | code == R.lbfgserrInvalidMaxstep          = InvalidMaxStep
  | code == R.lbfgserrInvalidFtol             = InvalidFtol
  | code == R.lbfgserrInvalidWolfe            = InvalidWolfe
  | code == R.lbfgserrInvalidGtol             = InvalidGtol
  | code == R.lbfgserrInvalidXtol             = InvalidXtol
  | code == R.lbfgserrInvalidMaxlinesearch    = InvalidMaxLineSearch
  | code == R.lbfgserrInvalidOrthantwise      = InvalidOrthantwise
  | code == R.lbfgserrInvalidOrthantwiseStart = InvalidOrthantwiseStart
  | code == R.lbfgserrInvalidOrthantwiseEnd   = InvalidOrthantwiseEnd
  | code == R.lbfgserrOutofinterval           = OutOfInterval
  | code == R.lbfgserrIncorrectTminmax        = IncorrectTMinMax
  | code == R.lbfgserrRoundingError           = RoundingError
  | code == R.lbfgserrMinimumstep             = MinimumStep
  | code == R.lbfgserrMaximumstep             = MaximumStep
  | code == R.lbfgserrMaximumlinesearch       = MaximumLineSearch
  | code == R.lbfgserrMaximumiteration        = MaximumIteration
  | code == R.lbfgserrWidthtoosmall           = WidthTooSmall
  | code == R.lbfgserrInvalidparameters       = InvalidParameters
  | code == R.lbfgserrIncreasegradient        = IncreaseGradient