module Numeric.Optimization.Algorithms.HagerZhang05
(
optimize
,Function(..)
,Gradient(..)
,Combined(..)
,PointMVector
,GradientMVector
,Simple
,Mutable
,Result(..)
,Statistics(..)
,defaultParameters
,Parameters(..)
,Verbose(..)
,LineSearch(..)
,StopRules(..)
,EstimateError(..)
,TechParameters(..)
) where
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as SM
import Control.Exception (bracket)
import Control.Monad.Primitive (PrimMonad(..))
import Foreign
import Foreign.C
import qualified System.IO.Unsafe as Unsafe
-- | Run the C routine @cg_descent@ starting from @initial@.
--
-- The starting point is copied into a fresh mutable storable vector, the
-- user callbacks are converted to their mutable form and wrapped as C
-- function pointers, and the C routine then works on a pointer to that
-- copy.  Afterwards the copy is frozen and returned together with the
-- decoded status and the statistics read back from the C side.
--
-- When no combined callback is given ('Nothing'), one is assembled from
-- the function and gradient callbacks with 'combine'.
--
-- The callback function pointers are released with 'freeHaskellFunPtr'
-- and the work buffer is released by 'allocateWorkSpace', in both cases
-- even if an exception escapes.
optimize :: (G.Vector v Double)
         => Parameters            -- ^ Settings poked into the C parameter struct.
         -> Double                -- ^ Handed to @cg_descent@ as @grad_tol@.
         -> v Double              -- ^ Starting point; its length is the dimension.
         -> Function t1           -- ^ Objective callback.
         -> Gradient t2           -- ^ Gradient callback.
         -> Maybe (Combined t3)   -- ^ Optional callback computing both at once.
         -> IO (S.Vector Double, Result, Statistics)
optimize params grad_tol initial f g c = do
    -- Working copy of the starting point; the C code receives a pointer to it.
    point <- GM.unstream $ G.stream initial
    (outcome, summary) <-
        SM.unsafeWith point $ \pointPtr ->
        alloca $ \summaryPtr ->
        alloca $ \settingsPtr ->
        bracket (mkCFunction valueCallback) freeHaskellFunPtr $ \valuePtr ->
        bracket (mkCGradient gradCallback) freeHaskellFunPtr $ \gradPtr ->
        bracket (mkCCombined bothCallback) freeHaskellFunPtr $ \bothPtr ->
        allocateWorkSpace dim $ \scratchPtr -> do
            poke settingsPtr params
            status <- cg_descent pointPtr (fromIntegral dim)
                                 summaryPtr settingsPtr grad_tol
                                 valuePtr gradPtr bothPtr scratchPtr
            collected <- peek summaryPtr
            return (intToResult status, collected)
    solution <- G.unsafeFreeze point
    -- Force the decoded status before handing back the triple.
    return $ outcome `seq` (solution, outcome, summary)
  where
    dim           = G.length initial
    valueFn       = mutableF f
    gradFn        = mutableG g
    bothFn        = maybe (combine valueFn gradFn) mutableC c
    valueCallback = prepareF valueFn
    gradCallback  = prepareG gradFn
    bothCallback  = prepareC bothFn
-- | Give an action a scratch buffer with room for @4 * n@ 'Double's.
--
-- Buffers below 4096 bytes are obtained with 'allocaBytes'; anything
-- larger is obtained with 'mallocBytes' and released with 'free' through
-- 'bracket', so it is freed even when the action throws.
allocateWorkSpace :: Int -> (Ptr Double -> IO a) -> IO a
allocateWorkSpace n =
    if byteCount < smallLimit
        then allocaBytes byteCount
        else bracket (mallocBytes byteCount) free
  where
    byteCount  = 4 * n * sizeOf (undefined :: Double)
    smallLimit = 4096
-- | C-side shape of the objective callback: pointer to the point and the
-- number of elements, returning the function value (see 'prepareF').
type CFunction = Ptr Double -> CInt -> IO Double
-- | C-side shape of the gradient callback: pointer to the output buffer,
-- pointer to the point, and the number of elements (see 'prepareG').
type CGradient = Ptr Double -> Ptr Double -> CInt -> IO ()
-- | C-side shape of the combined callback: same arguments as 'CGradient',
-- additionally returning the function value (see 'prepareC').
type CCombined = Ptr Double -> Ptr Double -> CInt -> IO Double
-- | Binding to the C entry point @cg_descent@ (header @cg_user.h@).
--
-- Imported @safe@; 'optimize' passes it pointers to Haskell callbacks.
-- The argument roles below follow the call made in 'optimize'.
foreign import ccall safe "cg_user.h"
    cg_descent :: Ptr Double         -- point buffer (start in, solution out)
               -> CInt               -- number of elements in the point
               -> Ptr Statistics     -- statistics struct, read back after the call
               -> Ptr Parameters     -- parameter struct, filled before the call
               -> Double             -- gradient tolerance
               -> FunPtr CFunction   -- objective callback
               -> FunPtr CGradient   -- gradient callback
               -> FunPtr CCombined   -- combined callback
               -> Ptr Double         -- work buffer (4 * n doubles in 'optimize')
               -> IO CInt            -- status code, decoded by 'intToResult'
-- | Wrap a Haskell objective callback as a C function pointer.
-- The result must be released with 'freeHaskellFunPtr' (done in 'optimize').
foreign import ccall "wrapper" mkCFunction :: CFunction -> IO (FunPtr CFunction)
-- | Wrap a Haskell gradient callback as a C function pointer.
-- The result must be released with 'freeHaskellFunPtr' (done in 'optimize').
foreign import ccall "wrapper" mkCGradient :: CGradient -> IO (FunPtr CGradient)
-- | Wrap a Haskell combined callback as a C function pointer.
-- The result must be released with 'freeHaskellFunPtr' (done in 'optimize').
foreign import ccall "wrapper" mkCCombined :: CCombined -> IO (FunPtr CCombined)
-- | Uninhabited tag for callbacks written as pure functions over immutable
-- vectors ('VFunction', 'VGradient', 'VCombined').
data Simple
-- | Uninhabited tag for callbacks written as monadic actions over mutable
-- vectors ('MFunction', 'MGradient', 'MCombined').
data Mutable
-- | Mutable storable vector of 'Double's used for the point passed to a
-- mutable callback, in the state thread of monad @m@.
type PointMVector m = SM.MVector (PrimState m) Double
-- | Mutable storable vector of 'Double's into which a mutable callback
-- writes the gradient; same representation as 'PointMVector'.
type GradientMVector m = SM.MVector (PrimState m) Double
-- | Objective callback.  The type index records which of the two
-- representations is used.
data Function t where
    -- | Pure function from an immutable vector (any 'G.Vector' instance
    -- holding 'Double's) to its value.
    VFunction :: G.Vector v Double
              => (v Double -> Double)
              -> Function Simple
    -- | Monadic function over a mutable storable vector; it must work in
    -- every 'PrimMonad'.
    MFunction :: (forall m. PrimMonad m
                  => PointMVector m
                  -> m Double)
              -> Function Mutable
-- | Convert any 'Function' to the mutable representation.
--
-- A 'VFunction' is wrapped so that on every call the mutable point is
-- copied, element by element, into a fresh vector of the type the pure
-- function expects; that copy is frozen and handed to the function.
-- An 'MFunction' is returned unchanged.
mutableF :: Function t -> Function Mutable
mutableF (VFunction f) = MFunction f'
  where
    -- Explicit signature: the argument of 'MFunction' has to remain
    -- polymorphic in the monad, which a signature-less local binding does
    -- not guarantee.
    f' :: PrimMonad m => PointMVector m -> m Double
    f' mx = do
        let s = GM.length mx
        mz <- GM.new s
        -- Valid indices are 0 .. s-1.  The accesses are unchecked, so the
        -- loop must stop as soon as @i@ reaches @s@.
        let go i | i >= s    = return ()
                 | otherwise = GM.unsafeRead mx i >>=
                               GM.unsafeWrite mz i >> go (i+1)
        go 0
        -- Safe to freeze without copying: @mz@ is never touched again.
        z <- G.unsafeFreeze mz
        return (f z)
mutableF (MFunction f) = MFunction f
-- | Turn a mutable objective callback into the raw form expected by
-- 'mkCFunction'.
--
-- The C pointer is viewed, without copying, as a mutable storable vector
-- of the given length.  The foreign pointer is created with
-- 'newForeignPtr_', so it carries no finalizer.
prepareF :: Function Mutable -> CFunction
prepareF fun = case fun of
    MFunction f -> \pointPtr len ->
        newForeignPtr_ pointPtr >>= \pointFP ->
            f (SM.unsafeFromForeignPtr pointFP 0 (fromIntegral len))
    _ -> error "HagerZhang05.prepareF: never here"
-- | Gradient callback.  The type index records which of the two
-- representations is used.
data Gradient t where
    -- | Pure function from a point to a gradient vector of the same
    -- vector type.
    VGradient :: G.Vector v Double
              => (v Double -> v Double)
              -> Gradient Simple
    -- | Monadic function given the point and a second mutable vector that
    -- receives the gradient; it must work in every 'PrimMonad'.
    MGradient :: (forall m. PrimMonad m
                  => PointMVector m
                  -> GradientMVector m
                  -> m ())
              -> Gradient Mutable
-- | Convert any 'Gradient' to the mutable representation.
--
-- A 'VGradient' is wrapped so that on every call the mutable point is
-- copied into a fresh immutable vector, the pure gradient is evaluated on
-- it, and the result is copied into the output vector.  At most
-- @min (length point) (length result)@ elements are written.
-- An 'MGradient' is returned unchanged.
mutableG :: Gradient t -> Gradient Mutable
mutableG (VGradient f) = MGradient f'
  where
    -- Explicit signature: the argument of 'MGradient' has to remain
    -- polymorphic in the monad.
    f' :: PrimMonad m => PointMVector m -> GradientMVector m -> m ()
    f' mx mret = do
        -- Copy the point.  Valid indices are 0 .. s-1 and the accesses
        -- are unchecked, so stop as soon as @i@ reaches @s@.
        let s = GM.length mx
        mz <- GM.new s
        let go i | i >= s    = return ()
                 | otherwise = GM.unsafeRead mx i >>=
                               GM.unsafeWrite mz i >> go (i+1)
        go 0
        z <- G.unsafeFreeze mz
        -- Evaluate the user's gradient before touching the output.
        let !r = f z
        -- Copy the result.  The loop has to recurse into itself (go'),
        -- not into the input-copy loop, or only element 0 is written and
        -- the frozen copy of the point gets mutated.
        let s' = min s (G.length r)
            go' i | i >= s'   = return ()
                  | otherwise = let !x = G.unsafeIndex r i
                                in GM.unsafeWrite mret i x >> go' (i+1)
        go' 0
mutableG (MGradient f) = MGradient f
-- | Turn a mutable gradient callback into the raw form expected by
-- 'mkCGradient'.
--
-- Both C pointers are viewed, without copying, as mutable storable
-- vectors of the given length.  Note the C argument order: output buffer
-- first, point second, whereas the Haskell callback takes the point first.
prepareG :: Gradient Mutable -> CGradient
prepareG grad = case grad of
    MGradient f -> \outPtr pointPtr len -> do
        let size = fromIntegral len
        pointFP <- newForeignPtr_ pointPtr
        outFP <- newForeignPtr_ outPtr
        f (SM.unsafeFromForeignPtr pointFP 0 size)
          (SM.unsafeFromForeignPtr outFP 0 size)
    _ -> error "HagerZhang05.prepareG: never here"
-- | Callback computing the function value and the gradient in one go.
-- The type index records which of the two representations is used.
data Combined t where
    -- | Pure function from a point to the pair of value and gradient.
    VCombined :: G.Vector v Double
              => (v Double -> (Double, v Double))
              -> Combined Simple
    -- | Monadic function given the point and a second mutable vector that
    -- receives the gradient, returning the value; it must work in every
    -- 'PrimMonad'.
    MCombined :: (forall m. PrimMonad m
                  => PointMVector m
                  -> GradientMVector m
                  -> m Double)
              -> Combined Mutable
-- | Convert any 'Combined' to the mutable representation.
--
-- A 'VCombined' is wrapped so that on every call the mutable point is
-- copied into a fresh immutable vector, the pure function is evaluated on
-- it, the gradient part is copied into the output vector and the value
-- part is returned.  At most @min (length point) (length gradient)@
-- elements are written.  An 'MCombined' is returned unchanged.
mutableC :: Combined t -> Combined Mutable
mutableC (VCombined f) = MCombined f'
  where
    -- Explicit signature: the argument of 'MCombined' has to remain
    -- polymorphic in the monad.
    f' :: PrimMonad m => PointMVector m -> GradientMVector m -> m Double
    f' mx mret = do
        -- Copy the point.  Valid indices are 0 .. s-1 and the accesses
        -- are unchecked, so stop as soon as @i@ reaches @s@.
        let s = GM.length mx
        mz <- GM.new s
        let go i | i >= s    = return ()
                 | otherwise = GM.unsafeRead mx i >>=
                               GM.unsafeWrite mz i >> go (i+1)
        go 0
        z <- G.unsafeFreeze mz
        -- Evaluate the user's function before touching the output.
        let !(v,r) = f z
        -- Copy the gradient.  The loop has to recurse into itself (go'),
        -- not into the input-copy loop, or only element 0 is written and
        -- the frozen copy of the point gets mutated.
        let s' = min s (G.length r)
            go' i | i >= s'   = return ()
                  | otherwise = let !x = G.unsafeIndex r i
                                in GM.unsafeWrite mret i x >> go' (i+1)
        go' 0
        return v
mutableC (MCombined f) = MCombined f
-- | Turn a mutable combined callback into the raw form expected by
-- 'mkCCombined'.
--
-- Both C pointers are viewed, without copying, as mutable storable
-- vectors of the given length.  Note the C argument order: output buffer
-- first, point second, whereas the Haskell callback takes the point first.
prepareC :: Combined Mutable -> CCombined
prepareC both = case both of
    MCombined f -> \outPtr pointPtr len -> do
        let size = fromIntegral len
        pointFP <- newForeignPtr_ pointPtr
        outFP <- newForeignPtr_ outPtr
        f (SM.unsafeFromForeignPtr pointFP 0 size)
          (SM.unsafeFromForeignPtr outFP 0 size)
    _ -> error "HagerZhang05.prepareC: never here"
-- | Build a combined callback out of separate function and gradient
-- callbacks: the gradient is written first, then the function value is
-- computed and returned.
combine :: Function Mutable -> Gradient Mutable -> Combined Mutable
combine fun grad = case (fun, grad) of
    (MFunction f, MGradient g) ->
        MCombined (\point out -> g point out >> f point)
    _ -> error "HagerZhang05.combine: never here"
-- | Outcome reported by the optimizer.  Each constructor corresponds to
-- one integer status code of @cg_descent@; the mapping is implemented by
-- 'intToResult'.
data Result =
      ToleranceStatisfied        -- ^ Status code @0@.
    | FunctionChange             -- ^ Status code @1@.
    | MaxTotalIter               -- ^ Status code @2@.
    | NegativeSlope              -- ^ Status code @3@.
    | MaxSecantIter              -- ^ Status code @4@.
    | NotDescent                 -- ^ Status code @5@.
    | LineSearchFailsInitial     -- ^ Status code @6@.
    | LineSearchFailsBisection   -- ^ Status code @7@.
    | LineSearchFailsUpdate      -- ^ Status code @8@.
    | DebugTol                   -- ^ Status code @9@.
    | FunctionValueNaN           -- ^ Status code @-2@.
    | StartFunctionValueNaN      -- ^ Status code @-1@.
    deriving (Eq, Ord, Show, Read, Enum)

-- | Decode the integer status returned by @cg_descent@.
--
-- The two NaN conditions are reported with /negative/ codes (@-2@ and
-- @-1@); the codes @0@ to @9@ map onto the remaining constructors in
-- order.  Code @10@ (out of memory) and every other value raise an
-- 'error'.
intToResult :: CInt -> Result
intToResult (-2) = FunctionValueNaN
intToResult (-1) = StartFunctionValueNaN
intToResult 0 = ToleranceStatisfied
intToResult 1 = FunctionChange
intToResult 2 = MaxTotalIter
intToResult 3 = NegativeSlope
intToResult 4 = MaxSecantIter
intToResult 5 = NotDescent
intToResult 6 = LineSearchFailsInitial
intToResult 7 = LineSearchFailsBisection
intToResult 8 = LineSearchFailsUpdate
intToResult 9 = DebugTol
intToResult 10 = error $ "HagerZhang05.intToResult: out of memory?! how?!"
intToResult x = error $ "HagerZhang05.intToResult: unknown value " ++ show x
-- | Summary filled in by @cg_descent@ and read back by 'optimize'.
--
-- Field meanings are inferred from their names - TODO confirm against the
-- C header.
data Statistics = Statistics {
     finalValue :: Double   -- ^ Presumably the function value at the returned point.
    ,gradNorm   :: Double   -- ^ Presumably a norm of the gradient at the returned point.
    ,totalIters :: CInt     -- ^ Presumably the number of iterations performed.
    ,funcEvals  :: CInt     -- ^ Presumably the number of function evaluations.
    ,gradEvals  :: CInt     -- ^ Presumably the number of gradient evaluations.
    } deriving (Eq, Ord, Show, Read)

-- | Marshalling for the C statistics struct: two doubles at byte offsets
-- 0 and 8, followed by three C ints at offsets 16, 20 and 24.
--
-- NOTE(review): size and offsets are hard-coded numbers (they look like
-- generated output for one particular ABI) - confirm they match the
-- target platform's struct layout.
instance Storable Statistics where
    sizeOf _ = (28)
    alignment _ = alignment (undefined :: Double)
    peek ptr = do
        value  <- peekByteOff ptr 0
        norm   <- peekByteOff ptr 8
        iters  <- peekByteOff ptr 16
        fEvals <- peekByteOff ptr 20
        gEvals <- peekByteOff ptr 24
        return Statistics { finalValue = value
                          , gradNorm   = norm
                          , totalIters = iters
                          , funcEvals  = fEvals
                          , gradEvals  = gEvals }
    poke ptr s = do
        pokeByteOff ptr 0  (finalValue s)
        pokeByteOff ptr 8  (gradNorm s)
        pokeByteOff ptr 16 (totalIters s)
        pokeByteOff ptr 20 (funcEvals s)
        pokeByteOff ptr 24 (gradEvals s)
-- | Default parameter set, as produced by the C routine @cg_default@.
--
-- A temporary parameter struct is allocated, filled in by @cg_default@
-- and decoded with 'peek'.  This is done behind 'Unsafe.unsafePerformIO'
-- because the value is exposed as a pure constant.
--
-- The qualified "System.IO.Unsafe" function is used because "Foreign"
-- no longer re-exports @unsafePerformIO@ in current versions of @base@.
-- The NOINLINE pragma keeps this a single shared constant instead of
-- letting the embedded IO action be duplicated at use sites.
{-# NOINLINE defaultParameters #-}
defaultParameters :: Parameters
defaultParameters =
    Unsafe.unsafePerformIO $ do
        alloca $ \ptr -> do
            cg_default ptr
            peek ptr
-- | Binding to the C function @cg_default@ (header @cg_user.h@), which
-- receives a pointer to a parameter struct; 'defaultParameters' reads the
-- struct back after the call.  Imported @unsafe@.
foreign import ccall unsafe "cg_user.h"
    cg_default :: Ptr Parameters -> IO ()
-- | Settings handed to @cg_descent@.  The value is marshalled to and from
-- the C parameter struct by its 'Storable' instance; the notes below only
-- describe what that instance does with each field.
data Parameters = Parameters {
     printFinal :: Bool
     -- ^ Stored as a C int flag (1 for 'True', 0 for 'False').
    ,printParams :: Bool
     -- ^ Stored as a C int flag (1 for 'True', 0 for 'False').
    ,verbose :: Verbose
     -- ^ Stored as a C int level; see 'Verbose'.
    ,lineSearch :: LineSearch
     -- ^ Stored as a flag plus a factor; see 'LineSearch'.
    ,qdecay :: Double
     -- ^ Stored unchanged.
    ,stopRules :: StopRules
     -- ^ Stored as a flag plus a factor; see 'StopRules'.
    ,estimateError :: EstimateError
     -- ^ Stored as a flag plus an epsilon; see 'EstimateError'.
    ,quadraticStep :: Maybe Double
     -- ^ 'Nothing' stores flag 0 and value 0; @'Just' x@ stores flag 1 and @x@.
    ,debugTol :: Maybe Double
     -- ^ 'Nothing' stores flag 0 and value 0; @'Just' x@ stores flag 1 and @x@.
    ,initialStep :: Maybe Double
     -- ^ 'Nothing' is stored as 0; a stored 0 is read back as 'Nothing'.
    ,maxItersFac :: Double
     -- ^ Stored unchanged.
    ,nexpand :: CInt
     -- ^ Stored unchanged.
    ,nsecant :: CInt
     -- ^ Stored unchanged.
    ,restartFac :: Double
     -- ^ Stored unchanged.
    ,funcEpsilon :: Double
     -- ^ Stored unchanged.
    ,nanRho :: Double
     -- ^ Stored unchanged.
    ,techParameters :: TechParameters
     -- ^ Eight further doubles stored at the end of the struct.
    } deriving (Eq, Ord, Show, Read)
-- | Marshalling between 'Parameters' and the C parameter struct.
--
-- Boolean and 'Maybe' fields are represented on the C side by a C int
-- flag, where applicable followed by a double payload.
--
-- NOTE(review): size and byte offsets are hard-coded numbers (they look
-- like generated output for one particular ABI) - confirm they match the
-- target platform's struct layout.
instance Storable Parameters where
    sizeOf _ = (192)
    alignment _ = alignment (undefined :: Double)

    -- Read every slot first, then rebuild the Haskell record.
    peek ptr = do
        rawPrintFinal  <- peekByteOff ptr 0
        rawPrintParams <- peekByteOff ptr 8
        rawVerbose     <- peekByteOff ptr 4
        rawAWolfe      <- peekByteOff ptr 12
        awolfeFactor   <- peekByteOff ptr 16
        decay          <- peekByteOff ptr 24
        rawStopRule    <- peekByteOff ptr 32
        stopFactor     <- peekByteOff ptr 36
        rawPertRule    <- peekByteOff ptr 44
        epsilon        <- peekByteOff ptr 48
        rawQuadStep    <- peekByteOff ptr 56
        quadCut        <- peekByteOff ptr 60
        rawDebug       <- peekByteOff ptr 68
        debugTolerance <- peekByteOff ptr 72
        firstStep      <- peekByteOff ptr 80
        itersFactor    <- peekByteOff ptr 88
        expandCount    <- peekByteOff ptr 96
        secantCount    <- peekByteOff ptr 100
        restartFactor  <- peekByteOff ptr 104
        valueEpsilon   <- peekByteOff ptr 112
        nanFactor      <- peekByteOff ptr 120
        delta          <- peekByteOff ptr 128
        sigma          <- peekByteOff ptr 136
        gamma          <- peekByteOff ptr 144
        rho            <- peekByteOff ptr 152
        eta            <- peekByteOff ptr 160
        psi0           <- peekByteOff ptr 168
        psi1           <- peekByteOff ptr 176
        psi2           <- peekByteOff ptr 184
        return Parameters
            { printFinal     = isSet rawPrintFinal
            , printParams    = isSet rawPrintParams
            , verbose        = case rawVerbose :: CInt of
                                 0 -> Quiet
                                 1 -> Verbose
                                 _ -> VeryVerbose
            , lineSearch     = if isSet rawAWolfe
                                 then ApproximateWolfe
                                 else AutoSwitch awolfeFactor
            , qdecay         = decay
            , stopRules      = if isSet rawStopRule
                                 then DefaultStopRule stopFactor
                                 else AlternativeStopRule
            , estimateError  = if isSet rawPertRule
                                 then RelativeEpsilon epsilon
                                 else AbsoluteEpsilon epsilon
            , quadraticStep  = if isSet rawQuadStep
                                 then Just quadCut
                                 else Nothing
            , debugTol       = if isSet rawDebug
                                 then Just debugTolerance
                                 else Nothing
              -- A stored step of zero means "not set".
            , initialStep    = if firstStep == 0
                                 then Nothing
                                 else Just firstStep
            , maxItersFac    = itersFactor
            , nexpand        = expandCount
            , nsecant        = secantCount
            , restartFac     = restartFactor
            , funcEpsilon    = valueEpsilon
            , nanRho         = nanFactor
            , techParameters = TechParameters
                { techDelta = delta
                , techSigma = sigma
                , techGamma = gamma
                , techRho   = rho
                , techEta   = eta
                , techPsi0  = psi0
                , techPsi1  = psi1
                , techPsi2  = psi2
                }
            }
      where
        -- A C int flag counts as set when it is non-zero.
        isSet :: CInt -> Bool
        isSet = (/= 0)

    -- Write every slot; flags are written as 0 or 1.
    poke ptr p = do
        pokeByteOff ptr 0   (boolFlag (printFinal p))
        pokeByteOff ptr 8   (boolFlag (printParams p))
        pokeByteOff ptr 4   verbosity
        pokeByteOff ptr 12  awolfeFlag
        pokeByteOff ptr 16  awolfeFactor
        pokeByteOff ptr 24  (qdecay p)
        pokeByteOff ptr 32  stopFlag
        pokeByteOff ptr 36  stopFactor
        pokeByteOff ptr 44  pertFlag
        pokeByteOff ptr 48  epsilon
        pokeByteOff ptr 56  (maybeFlag (quadraticStep p))
        pokeByteOff ptr 60  (maybe 0 id (quadraticStep p))
        pokeByteOff ptr 68  (maybeFlag (debugTol p))
        pokeByteOff ptr 72  (maybe 0 id (debugTol p))
        pokeByteOff ptr 80  (maybe 0 id (initialStep p))
        pokeByteOff ptr 88  (maxItersFac p)
        pokeByteOff ptr 96  (nexpand p)
        pokeByteOff ptr 100 (nsecant p)
        pokeByteOff ptr 104 (restartFac p)
        pokeByteOff ptr 112 (funcEpsilon p)
        pokeByteOff ptr 120 (nanRho p)
        pokeByteOff ptr 128 (techDelta tech)
        pokeByteOff ptr 136 (techSigma tech)
        pokeByteOff ptr 144 (techGamma tech)
        pokeByteOff ptr 152 (techRho tech)
        pokeByteOff ptr 160 (techEta tech)
        pokeByteOff ptr 168 (techPsi0 tech)
        pokeByteOff ptr 176 (techPsi1 tech)
        pokeByteOff ptr 184 (techPsi2 tech)
      where
        tech = techParameters p

        boolFlag :: Bool -> CInt
        boolFlag on = if on then 1 else 0

        maybeFlag :: Maybe Double -> CInt
        maybeFlag = maybe 0 (const 1)

        verbosity :: CInt
        verbosity = case verbose p of
            Quiet       -> 0
            Verbose     -> 1
            VeryVerbose -> 3

        (awolfeFlag, awolfeFactor) = case lineSearch p of
            ApproximateWolfe -> (1 :: CInt, 0 :: Double)
            AutoSwitch x     -> (0, x)

        (stopFlag, stopFactor) = case stopRules p of
            DefaultStopRule x   -> (1 :: CInt, x)
            AlternativeStopRule -> (0, 0 :: Double)

        (pertFlag, epsilon) = case estimateError p of
            RelativeEpsilon x -> (1 :: CInt, x)
            AbsoluteEpsilon x -> (0, x)
-- | Eight additional 'Double' settings, marshalled as the last eight
-- doubles of the C parameter struct by the 'Storable' instance of
-- 'Parameters'.  Their numerical meaning is not visible in this module -
-- TODO confirm against the C library's documentation.
data TechParameters = TechParameters {
     techDelta :: Double
    ,techSigma :: Double
    ,techGamma :: Double
    ,techRho :: Double
    ,techEta :: Double
    ,techPsi0 :: Double
    ,techPsi1 :: Double
    ,techPsi2 :: Double
    } deriving (Eq, Ord, Show, Read)
-- | Output level passed to the C code as an integer by the 'Storable'
-- instance of 'Parameters'.
data Verbose =
      Quiet         -- ^ Written as @0@.
    | Verbose       -- ^ Written as @1@.
    | VeryVerbose   -- ^ Written as @3@; any value other than 0 or 1 is read back as this.
    deriving (Eq, Ord, Show, Read, Enum)
-- | Line search selection, marshalled as a C int flag plus a double
-- factor by the 'Storable' instance of 'Parameters'.
data LineSearch =
      ApproximateWolfe    -- ^ Written as flag @1@ with factor @0@.
    | AutoSwitch Double   -- ^ Written as flag @0@ together with the given factor.
    deriving (Eq, Ord, Show, Read)
-- | Stopping rule selection, marshalled as a C int flag plus a double
-- factor by the 'Storable' instance of 'Parameters'.
data StopRules =
      DefaultStopRule Double   -- ^ Written as flag @1@ together with the given factor.
    | AlternativeStopRule      -- ^ Written as flag @0@ with factor @0@.
    deriving (Eq, Ord, Show, Read)
-- | Error estimation mode, marshalled as a C int flag plus a double
-- epsilon by the 'Storable' instance of 'Parameters'.
data EstimateError =
      AbsoluteEpsilon Double   -- ^ Written as flag @0@ together with the given epsilon.
    | RelativeEpsilon Double   -- ^ Written as flag @1@ together with the given epsilon.
    deriving (Eq, Ord, Show, Read)