module Numeric.LBFGS.Vector
( LineSearchAlgorithm(..)
, EvaluateFun
, ProgressFun
, LBFGSParameters(..)
, LBFGSResult(..)
, lbfgs
) where
import Control.Exception (bracket)
import Data.Maybe
import Foreign.C.Types (CDouble, CInt)
import Foreign.ForeignPtr (newForeignPtr_)
import Foreign.Marshal.Alloc (malloc, free)
import Foreign.Ptr (Ptr, freeHaskellFunPtr, nullPtr, plusPtr)
import Foreign.StablePtr (StablePtr, deRefStablePtr, newStablePtr,
                          freeStablePtr)
import Foreign.Storable (Storable(..), peek, poke, sizeOf)

import Data.Vector.Storable.Mutable (IOVector)
import qualified Data.Vector.Storable.Mutable as M

import qualified Numeric.LBFGS.Raw as R
import Numeric.LBFGS.Raw (CEvaluateFun, CProgressFun, CLBFGSParameter(..),
                          defaultCParam, CLBFGSResult(..),
                          c_lbfgs_malloc, c_lbfgs_free,
                          c_lbfgs_evaluate_t_wrap, c_lbfgs_progress_t_wrap,
                          c_lbfgs
                         )
import Numeric.LBFGS.Types
import Numeric.LBFGS.Internal
-- | Build the C-level parameter record from the high-level parameters.
--
-- Starts from 'defaultCParam' and layers the settings on in this order:
-- past/delta, then the line-search algorithm, then the L1-norm coefficient.
-- The problem size @n@ is only consulted by the L1-norm step.
withParam :: LBFGSParameters -> CInt -> CLBFGSParameter
withParam (LBFGSParameters past delta lineSearch l1NormCoeff) n =
  applySettings defaultCParam
  where
    applySettings =
        mergeL1NormCoefficient l1NormCoeff n
      . mergeLineSearchAlgorithm lineSearch
      . mergePast past delta
-- | Default high-level parameters: no @past@ setting, @delta = 1e-5@,
-- the default line search, and no L1-norm coefficient.
defaultLBFGSParameters :: LBFGSParameters
defaultLBFGSParameters =
  LBFGSParameters
    Nothing            -- past
    1e-5               -- delta
    DefaultLineSearch  -- line-search algorithm
    Nothing            -- L1-norm coefficient
-- | Advance a 'CDouble' pointer by @n@ elements (not bytes).
-- A negative @n@ moves the pointer backwards.
cDoublePlusPtr :: Ptr CDouble -> Int -> Ptr CDouble
cDoublePlusPtr ptr n = ptr `plusPtr` (n * elemSize)
  where
    elemSize = sizeOf (0 :: CDouble)
-- | Allocate a C array with 'c_lbfgs_malloc' and copy the list into it.
--
-- Returns the element count together with the pointer. In this module,
-- 'lbfgs_' later releases the pointer with 'freeVector'.
--
-- Throws an 'IOError' if the allocation returns 'nullPtr' for a non-empty
-- list, rather than writing through a null pointer. (Assumes the binding
-- reports failure with NULL, as liblbfgs's @lbfgs_malloc@ does.)
listToVector :: [Double] -> IO (CInt, Ptr CDouble)
listToVector l = do
  v <- c_lbfgs_malloc n
  -- An empty request may legitimately yield NULL; only fail when we would
  -- actually write through the pointer.
  if v == nullPtr && n > 0
    then ioError (userError "Numeric.LBFGS.Vector.listToVector: c_lbfgs_malloc returned NULL")
    else do
      copyList l v
      return (n, v)
  where
    n = fromIntegral (length l)
-- | Write the list element-by-element into consecutive 'CDouble' slots
-- starting at the given pointer. The caller must ensure the buffer holds at
-- least @length xs@ elements. An empty list writes nothing.
copyList :: [Double] -> Ptr CDouble -> IO ()
copyList xs p = go 0 xs
  where
    -- Pattern-match on the list instead of using the partial head/tail.
    go :: Int -> [Double] -> IO ()
    go _ [] = return ()
    go i (y:ys) = do
      pokeElemOff p i (realToFrac y)
      go (i + 1) ys
-- | Release a C array previously obtained from 'c_lbfgs_malloc'
-- (e.g. via 'listToVector').
freeVector :: Ptr CDouble -> IO ()
freeVector ptr = c_lbfgs_free ptr
-- | Read @cn@ consecutive 'CDouble's starting at @p@ into a list, in order.
--
-- Reads by index, so it never forms a pointer before @p@. A count of zero
-- (or less) yields @[]@ without touching memory, even when @p@ is 'nullPtr'.
vectorToList :: CInt -> Ptr CDouble -> IO [Double]
vectorToList cn p = map realToFrac <$> mapM (peekElemOff p) [0 .. n - 1]
  where
    n = fromIntegral cn :: Int

-- | Walk backwards from @pCur@ down to @pStart@ (both inclusive), prepending
-- each value to the accumulator. This yields the elements in ascending
-- address order, followed by the initial accumulator.
-- Stops as soon as @pCur@ drops below @pStart@.
vectorToList_ :: Ptr CDouble -> Ptr CDouble -> [Double] -> IO [Double]
vectorToList_ pStart pCur acc
  | pCur >= pStart = do
      cval <- peek pCur
      -- Step back exactly one element. The original stepped forward,
      -- which never terminated.
      vectorToList_ pStart (pCur `plusPtr` negate (sizeOf cval))
                    (realToFrac cval : acc)
  | otherwise = return acc
-- | High-level evaluation callback, as adapted by 'wrapEvaluateFun'.
-- Argument names follow the raw callback's parameters; meanings noted as
-- "presumably" follow liblbfgs's @lbfgs_evaluate_t@ convention and should be
-- confirmed against the C library.
type EvaluateFun a
  =  a                  -- user instance, recovered from the StablePtr
  -> IOVector CDouble   -- x: view over the C variables buffer
  -> IOVector CDouble   -- g: view over the C gradient buffer (presumably to be filled)
  -> CInt               -- n: length of both buffers
  -> CDouble            -- step: presumably the current line-search step
  -> IO CDouble         -- presumably the objective value at x
-- | Adapt an 'EvaluateFun' to the raw callback shape.
--
-- Dereferences the StablePtr to recover the user instance. Wraps the raw
-- @x@ and @g@ pointers as length-@n@ 'IOVector's without copying and without
-- finalizers ('newForeignPtr_'), so the memory is presumably owned by the
-- C side.
--
-- NOTE(review): the vectors alias C memory. Callers presumably must not
-- retain them past the callback — confirm.
wrapEvaluateFun :: EvaluateFun a -> StablePtr a -> Ptr CDouble ->
                   Ptr CDouble -> CInt -> CDouble -> IO (CDouble)
wrapEvaluateFun fun inst x g n step = do
  userInstance <- deRefStablePtr inst
  xs <- asIOVector x
  grad <- asIOVector g
  fun userInstance xs grad n step
  where
    len = fromIntegral n
    asIOVector ptr = do
      fptr <- newForeignPtr_ ptr
      return (M.unsafeFromForeignPtr fptr 0 len)
-- | High-level progress callback, as adapted by 'wrapProgressFun'.
-- Argument names follow the raw callback's parameters; meanings noted as
-- "presumably" follow liblbfgs's @lbfgs_progress_t@ convention and should be
-- confirmed against the C library.
type ProgressFun a
  =  a                  -- user instance, recovered from the StablePtr
  -> IOVector CDouble   -- x: view over the C variables buffer
  -> IOVector CDouble   -- g: view over the C gradient buffer
  -> CDouble            -- fx: presumably the objective value at x
  -> CDouble            -- xn: presumably the norm of x
  -> CDouble            -- gn: presumably the norm of g
  -> CDouble            -- step: presumably the line-search step used
  -> CInt               -- n: length of both buffers
  -> CInt               -- k: presumably the iteration count
  -> CInt               -- ls: presumably the number of line-search evaluations
  -> IO CInt            -- presumably nonzero cancels the optimisation
-- | Adapt a 'ProgressFun' to the raw callback shape.
--
-- Dereferences the StablePtr to recover the user instance. Wraps the raw
-- @x@ and @g@ pointers as length-@n@ 'IOVector's without copying and without
-- finalizers ('newForeignPtr_'). Forwards the remaining scalars unchanged.
--
-- Fix: the gradient view was previously built from @x@'s ForeignPtr, so the
-- callback received the variables twice and never saw the gradient. Both
-- views now go through one helper, applied to the correct pointer.
--
-- NOTE(review): the vectors alias C memory. Callers presumably must not
-- retain them past the callback — confirm.
wrapProgressFun :: ProgressFun a -> StablePtr a -> Ptr CDouble ->
                   Ptr CDouble -> CDouble -> CDouble -> CDouble -> CDouble ->
                   CInt -> CInt -> CInt -> IO (CInt)
wrapProgressFun fun inst x g fx xn gn step n k ls = do
  instV <- deRefStablePtr inst
  xVec <- asIOVector x
  gVec <- asIOVector g
  fun instV xVec gVec fx xn gn step n k ls
  where
    nInt = fromIntegral n
    asIOVector ptr = do
      fp <- newForeignPtr_ ptr
      return (M.unsafeFromForeignPtr fp 0 nInt)
-- | Run L-BFGS starting from the given point.
--
-- The high-level callbacks are adapted with 'wrapEvaluateFun' and
-- 'wrapProgressFun'. The instance value is passed through to both of them.
-- Returns the decoded result code together with the contents of the
-- variable buffer after the C call (presumably the final point).
lbfgs :: LBFGSParameters
      -> EvaluateFun a
      -> ProgressFun a
      -> a
      -> [Double]
      -> IO(LBFGSResult, [Double])
lbfgs lbfgsParams evalFun progressFun =
  lbfgs_ lbfgsParams (wrapEvaluateFun evalFun) (wrapProgressFun progressFun)
-- | Low-level driver around 'c_lbfgs'.
--
-- Copies the start point into a C vector and pins the instance in a
-- StablePtr. Marshals the parameters into a malloc'd struct and wraps both
-- callbacks as FunPtrs. Then it calls 'c_lbfgs' (passing 'nullPtr' for the
-- final-objective out-pointer) and reads the vector back.
--
-- Every resource is acquired with 'bracket', so each one is released even
-- if a later step throws. Previously an exception anywhere after allocation
-- leaked the C vector, StablePtr, parameter struct and both FunPtrs.
lbfgs_ :: LBFGSParameters -> CEvaluateFun a -> CProgressFun a -> a ->
          [Double] -> IO(LBFGSResult, [Double])
lbfgs_ lbfgsParams evalFun progressFun inst p =
  bracket (listToVector p) (freeVector . snd) $ \(n, pVec) ->
  bracket (newStablePtr inst) freeStablePtr $ \instP ->
  bracket malloc free $ \paramP ->
  bracket (c_lbfgs_evaluate_t_wrap evalFun) freeHaskellFunPtr $ \evalW ->
  bracket (c_lbfgs_progress_t_wrap progressFun) freeHaskellFunPtr $ \progressW -> do
    poke paramP (withParam lbfgsParams n)
    r <- c_lbfgs n pVec nullPtr evalW progressW instP paramP
    -- Read the vector back while pVec is still alive (before its release).
    rl <- vectorToList n pVec
    return (deriveResult (CLBFGSResult r), rl)