module Numeric.LBFGS (LineSearchAlgorithm(..), EvaluateFun,
ProgressFun, LBFGSResult, lbfgs) where
import Data.Array.Storable (StorableArray,
unsafeForeignPtrToStorableArray)
import Foreign.C.Types (CDouble, CInt)
import Foreign.ForeignPtr (newForeignPtr_)
import Foreign.Marshal.Alloc (malloc, free)
import Foreign.Ptr (Ptr, freeHaskellFunPtr, nullPtr, plusPtr)
import Foreign.StablePtr (StablePtr, deRefStablePtr, newStablePtr,
freeStablePtr)
import Foreign.Storable (Storable(..), peek, poke, sizeOf)
import qualified Numeric.LBFGS.Raw as R
import Numeric.LBFGS.Raw (CEvaluateFun, CProgressFun, CLBFGSParameter(..),
defaultCParam, CLBFGSResult(..),
c_lbfgs_malloc, c_lbfgs_free,
c_lbfgs_evaluate_t_wrap, c_lbfgs_progress_t_wrap,
c_lbfgs
)
-- | Line-search strategy passed to 'lbfgs'.
--
-- Each constructor selects the corresponding libLBFGS line-search constant
-- (the translation is done by 'mergeLineSearchAlgorithm').  The two Wolfe
-- variants additionally carry a coefficient that is written into the
-- parameter block's @wolfe@ field.
--
-- NOTE: 'coeff' is a partial record selector: applying it to a constructor
-- that carries no coefficient (e.g. 'MoreThuente') fails at runtime.  Prefer
-- pattern matching over the selector.
data LineSearchAlgorithm
  = DefaultLineSearch
  | MoreThuente
  | BacktrackingArmijo
  | Backtracking
  | BacktrackingWolfe {coeff :: Double}
  | BacktrackingStrongWolfe {coeff :: Double}
  deriving (Eq, Show)
-- | Overwrite the line-search selection in a raw parameter block.
--
-- Only @linesearch@ is touched for the coefficient-free strategies.  The
-- Wolfe strategies additionally set @wolfe@ to their coefficient, converted
-- with 'realToFrac'.
mergeLineSearchAlgorithm :: CLBFGSParameter -> LineSearchAlgorithm ->
                            CLBFGSParameter
mergeLineSearchAlgorithm p algo = case algo of
  DefaultLineSearch         -> p { R.linesearch = R.defaultLineSearch }
  MoreThuente               -> p { R.linesearch = R.moreThuente }
  BacktrackingArmijo        -> p { R.linesearch = R.backtrackingArmijo }
  Backtracking              -> p { R.linesearch = R.backtracking }
  BacktrackingWolfe c       -> withWolfe R.backtrackingWolfe c
  BacktrackingStrongWolfe c -> withWolfe R.backtrackingStrongWolfe c
  where
    -- Shared update for both Wolfe flavours: pick the algorithm and store
    -- the coefficient.
    withWolfe which c = p { R.linesearch = which, R.wolfe = realToFrac c }
-- | Build a complete raw parameter block: libLBFGS defaults
-- ('defaultCParam') with the requested line search applied on top.
withParam :: LineSearchAlgorithm -> CLBFGSParameter
withParam = mergeLineSearchAlgorithm defaultCParam
-- | Outcome of an 'lbfgs' run.
--
-- There is one constructor per libLBFGS status code.  'deriveResult'
-- decodes the raw C value into this type.
data LBFGSResult
  -- Codes from the non-error (@lbfgs*@) family.
  = Success
  | Stop
  | AlreadyMinimized
  -- Codes from the error (@lbfgserr*@) family: generic failures.
  | UnknownError
  | LogicError
  | OutOfMemory
  | Canceled
  -- Codes from the @lbfgserrInvalid*@ family: rejected inputs/parameters.
  | InvalidN
  | InvalidNSSE
  | InvalidXSSE
  | InvalidEpsilon
  | InvalidTestPeriod
  | InvalidDelta
  | InvalidLineSearch
  | InvalidMinStep
  | InvalidMaxStep
  | InvalidFtol
  | InvalidWolfe
  | InvalidGtol
  | InvalidXtol
  | InvalidMaxLineSearch
  | InvalidOrthantwise
  | InvalidOrthantwiseStart
  | InvalidOrthantwiseEnd
  -- Remaining @lbfgserr*@ codes.
  | OutOfInterval
  | IncorrectTMinMax
  | RoundingError
  | MinimumStep
  | MaximumStep
  | MaximumLineSearch
  | MaximumIteration
  | WidthTooSmall
  | InvalidParameters
  | IncreaseGradient
  deriving (Show, Eq)
-- | Decode a raw libLBFGS status code into an 'LBFGSResult'.
--
-- Each known code is matched explicitly.  Any value not in the table (for
-- example, a code added by a newer libLBFGS) maps to 'UnknownError'.  This
-- avoids a non-exhaustive-guard crash inside the caller.
deriveResult :: CLBFGSResult -> LBFGSResult
deriveResult r
  | r == R.lbfgsSuccess = Success
  | r == R.lbfgsStop = Stop
  | r == R.lbfgsAlreadyMinimized = AlreadyMinimized
  | r == R.lbfgserrUnknownerror = UnknownError
  | r == R.lbfgserrLogicerror = LogicError
  | r == R.lbfgserrOutofmemory = OutOfMemory
  | r == R.lbfgserrCanceled = Canceled
  | r == R.lbfgserrInvalidN = InvalidN
  | r == R.lbfgserrInvalidNSse = InvalidNSSE
  | r == R.lbfgserrInvalidXSse = InvalidXSSE
  | r == R.lbfgserrInvalidEpsilon = InvalidEpsilon
  | r == R.lbfgserrInvalidTestperiod = InvalidTestPeriod
  | r == R.lbfgserrInvalidDelta = InvalidDelta
  | r == R.lbfgserrInvalidLinesearch = InvalidLineSearch
  | r == R.lbfgserrInvalidMinstep = InvalidMinStep
  | r == R.lbfgserrInvalidMaxstep = InvalidMaxStep
  | r == R.lbfgserrInvalidFtol = InvalidFtol
  | r == R.lbfgserrInvalidWolfe = InvalidWolfe
  | r == R.lbfgserrInvalidGtol = InvalidGtol
  | r == R.lbfgserrInvalidXtol = InvalidXtol
  | r == R.lbfgserrInvalidMaxlinesearch = InvalidMaxLineSearch
  | r == R.lbfgserrInvalidOrthantwise = InvalidOrthantwise
  | r == R.lbfgserrInvalidOrthantwiseStart = InvalidOrthantwiseStart
  | r == R.lbfgserrInvalidOrthantwiseEnd = InvalidOrthantwiseEnd
  | r == R.lbfgserrOutofinterval = OutOfInterval
  | r == R.lbfgserrIncorrectTminmax = IncorrectTMinMax
  | r == R.lbfgserrRoundingError = RoundingError
  | r == R.lbfgserrMinimumstep = MinimumStep
  | r == R.lbfgserrMaximumstep = MaximumStep
  | r == R.lbfgserrMaximumlinesearch = MaximumLineSearch
  | r == R.lbfgserrMaximumiteration = MaximumIteration
  | r == R.lbfgserrWidthtoosmall = WidthTooSmall
  | r == R.lbfgserrInvalidparameters = InvalidParameters
  | r == R.lbfgserrIncreasegradient = IncreaseGradient
  -- Unrecognised code: report it as an unknown error rather than crashing.
  | otherwise = UnknownError
-- | Advance a 'CDouble' pointer by a number of elements (not bytes).
-- Negative counts move backwards.
cDoublePlusPtr :: Ptr CDouble -> Int -> Ptr CDouble
cDoublePlusPtr base count = base `plusPtr` byteOffset
  where
    byteOffset = count * sizeOf (0 :: CDouble)
-- | Copy a Haskell list into a freshly allocated libLBFGS vector.
--
-- The buffer comes from 'c_lbfgs_malloc' and must be released with
-- 'freeVector'.  Returns the element count together with the buffer.
listToVector :: [Double] -> IO (CInt, Ptr CDouble)
listToVector xs = do
  let len = fromIntegral (length xs)
  buf <- c_lbfgs_malloc len
  copyList xs buf
  return (len, buf)
-- | Write the list elements, converted to 'CDouble', into consecutive slots
-- starting at the given pointer.
--
-- The caller must ensure the buffer holds at least @length xs@ elements.
-- An empty list writes nothing and never touches the pointer.
copyList :: [Double] -> Ptr CDouble -> IO ()
copyList xs p =
  -- Indexed writes replace the partial 'head'/'tail' recursion.
  sequence_ (zipWith (pokeElemOff p) [0 ..] (map realToFrac xs))
-- | Release a buffer obtained from 'listToVector' (i.e. from
-- 'c_lbfgs_malloc') by handing it to 'c_lbfgs_free'.
freeVector :: Ptr CDouble -> IO ()
freeVector ptr = c_lbfgs_free ptr
-- | Read @n@ consecutive 'CDouble's starting at the pointer into a Haskell
-- list, in order.
--
-- A count of zero (or a negative count) yields @[]@ without dereferencing
-- the pointer.  This matters because the buffer for an empty vector may be
-- null.  Reads are index-based: walking a pointer backwards past the start
-- of the buffer can wrap around at low addresses and never terminate.
vectorToList :: CInt -> Ptr CDouble -> IO ([Double])
vectorToList cn p = mapM readAt [0 .. fromIntegral cn - 1]
  where
    readAt :: Int -> IO Double
    readAt i = realToFrac <$> peekElemOff p i
-- | User-supplied objective callback, invoked by 'lbfgs' through
-- 'wrapEvaluateFun'.
--
-- Arguments, in order:
--   * the user instance value passed to 'lbfgs';
--   * @x@, the current variables (an @n@-element array aliasing C memory);
--   * @g@, an @n@-element array aliasing C memory.  Presumably the gradient
--     output buffer the callback must fill (per libLBFGS's lbfgs_evaluate_t);
--     TODO confirm;
--   * @n@, the number of variables;
--   * @step@, the current line-search step.
-- The result is returned to libLBFGS as the objective value.
type EvaluateFun a
  =  a
  -> StorableArray Int CDouble
  -> StorableArray Int CDouble
  -> CInt
  -> CDouble
  -> IO (CDouble)
-- | Adapt a high-level 'EvaluateFun' to the raw callback shape handed to
-- 'c_lbfgs_evaluate_t_wrap'.
--
-- The 'StablePtr' is dereferenced back to the user's instance value.  The
-- raw @x@ and @g@ buffers (@n@ doubles each) are exposed as 'StorableArray's
-- with inclusive index bounds @(0, n - 1)@.  'newForeignPtr_' attaches no
-- finalizer, so the arrays only alias memory owned by the C side.  They
-- presumably must not be retained after the callback returns; TODO confirm.
wrapEvaluateFun :: EvaluateFun a -> StablePtr a -> Ptr CDouble ->
                   Ptr CDouble -> CInt -> CDouble -> IO (CDouble)
wrapEvaluateFun fun inst x g n step = do
  instV <- deRefStablePtr inst
  xArr <- aliasArray x
  gArr <- aliasArray g
  fun instV xArr gArr n step
  where
    -- Array bounds are inclusive, so the last valid index is n - 1.
    bounds = (0, fromIntegral n - 1)
    -- View a C buffer as a StorableArray without copying or taking ownership.
    aliasArray ptr = newForeignPtr_ ptr >>= \fp ->
      unsafeForeignPtrToStorableArray fp bounds
-- | User-supplied progress callback, invoked by 'lbfgs' through
-- 'wrapProgressFun'.
--
-- Arguments, in order:
--   * the user instance value passed to 'lbfgs';
--   * @x@ and @g@, @n@-element arrays aliasing C memory (current variables
--     and, presumably, the current gradient);
--   * four 'CDouble's, forwarded as @fx@, @xn@, @gn@, @step@.  These are
--     presumably the objective value, norm of x, norm of g, and line-search
--     step (per libLBFGS's lbfgs_progress_t); TODO confirm;
--   * three 'CInt's, forwarded as @n@, @k@, @ls@.  These are presumably the
--     variable count, iteration number, and number of evaluations; TODO confirm.
-- The 'CInt' result is handed back to libLBFGS.  Per libLBFGS, a non-zero
-- value presumably cancels the optimisation; verify.
type ProgressFun a
  =  a
  -> StorableArray Int CDouble
  -> StorableArray Int CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CInt
  -> CInt
  -> CInt
  -> IO (CInt)
-- | Adapt a high-level 'ProgressFun' to the raw callback shape handed to
-- 'c_lbfgs_progress_t_wrap'.
--
-- The 'StablePtr' is dereferenced back to the user's instance value.  The
-- raw @x@ and @g@ buffers (@n@ doubles each) are exposed as 'StorableArray's
-- with inclusive bounds @(0, n - 1)@.  All scalar arguments and the callback's
-- 'CInt' result are passed through unchanged.  The arrays alias C memory
-- ('newForeignPtr_' adds no finalizer) and presumably must not outlive the
-- callback; TODO confirm.
wrapProgressFun :: ProgressFun a -> StablePtr a -> Ptr CDouble ->
                   Ptr CDouble -> CDouble -> CDouble -> CDouble -> CDouble ->
                   CInt -> CInt -> CInt -> IO (CInt)
wrapProgressFun fun inst x g fx xn gn step n k ls = do
  instV <- deRefStablePtr inst
  xArr <- aliasArray x
  gArr <- aliasArray g
  fun instV xArr gArr fx xn gn step n k ls
  where
    -- Array bounds are inclusive, so the last valid index is n - 1.
    bounds = (0, fromIntegral n - 1)
    -- View a C buffer as a StorableArray without copying or taking ownership.
    aliasArray ptr = newForeignPtr_ ptr >>= \fp ->
      unsafeForeignPtrToStorableArray fp bounds
-- | Minimise an objective with libLBFGS.
--
-- @lbfgs ls evalFun progressFun inst x0@ starts the optimiser at @x0@,
-- using the line-search strategy @ls@.
--
-- * @evalFun@ is called to evaluate the objective at a point.
-- * @progressFun@ is called to report progress.
-- * @inst@ is an arbitrary user value passed to both callbacks.
--
-- Returns the decoded status code together with the contents of the
-- variable vector after the C solver returns.
lbfgs :: LineSearchAlgorithm
      -> EvaluateFun a
      -> ProgressFun a
      -> a
      -> [Double]
      -> IO (LBFGSResult, [Double])
lbfgs ls evalFun progressFun =
  lbfgs_ ls (wrapEvaluateFun evalFun) (wrapProgressFun progressFun)
-- | Low-level driver: marshal inputs, run 'c_lbfgs', and release every
-- resource acquired for the call.
--
-- Resources acquired, and freed before returning:
--   * the variable vector (from 'listToVector', freed with 'freeVector');
--   * a 'StablePtr' to @inst@;
--   * a 'malloc'ed parameter block;
--   * the two 'FunPtr' wrappers around the callbacks.
--
-- The variable vector is read back into a Haskell list *before* it is freed.
-- Reading it afterwards would be a use-after-free.
lbfgs_ :: LineSearchAlgorithm -> CEvaluateFun a -> CProgressFun a -> a ->
          [Double] -> IO (LBFGSResult, [Double])
lbfgs_ ls evalFun progressFun inst p = do
  (n, pVec) <- listToVector p
  instP <- newStablePtr inst
  paramP <- malloc
  poke paramP (withParam ls)
  evalW <- c_lbfgs_evaluate_t_wrap evalFun
  progressW <- c_lbfgs_progress_t_wrap progressFun
  -- The third argument is passed as nullPtr.  Presumably it is libLBFGS's
  -- optional out-parameter for the final objective value (not collected
  -- here); TODO confirm.
  r <- c_lbfgs n pVec nullPtr evalW progressW instP paramP
  -- Copy the result out while pVec is still live.
  rl <- vectorToList n pVec
  freeHaskellFunPtr progressW
  freeHaskellFunPtr evalW
  free paramP
  freeStablePtr instP
  freeVector pVec
  return (deriveResult $ CLBFGSResult r, rl)