module Math.Optimization.SPSA.Types (
defaultSPSA, checkSPSA,
StateSPSA,
getLoss, getConstraint, getStop, peelAll, getIterations,
setLoss, setConstraint, pushStopCrit, setGainA, setGainC, setPerturbation,
incrementIteration,
StoppingCriteria(..), shouldStop,
LossFn, ConstraintFn
) where
import Numeric.LinearAlgebra (Vector, norm2)
import Control.Monad.State (State, get, put)
-- | Complete state of an SPSA run: configuration plus the remaining
-- (lazily consumed) gain and perturbation sequences.
data SPSA = SPSA
  { iterations    :: Int                 -- ^ iteration counter, bumped by 'incrementIteration'
  , lossFn        :: LossFn              -- ^ loss evaluated on a parameter vector
  , constraintFn  :: ConstraintFn        -- ^ constraint map on parameter vectors
  , gainA         :: [Double]            -- ^ gain sequence a_k, one element taken per 'peelAll'
  , gainC         :: [Double]            -- ^ gain sequence c_k, one element taken per 'peelAll'
  , stoppingCrits :: [StoppingCriteria]  -- ^ stopping criteria, most recently pushed first
  , perturbation  :: [Vector Double]     -- ^ perturbation vectors, one taken per 'peelAll'
  }
-- | Map applied to parameter vectors to keep them feasible;
-- 'defaultSPSA' uses 'id'.
type ConstraintFn = Vector Double -> Vector Double
-- | Scalar loss of a parameter vector.
type LossFn = Vector Double -> Double
-- | Starting configuration with every sequence empty and no loss function.
-- The placeholder loss raises an error when evaluated, and the constraint
-- is the identity. 'checkSPSA' reports whatever is still unset.
defaultSPSA :: SPSA
defaultSPSA = SPSA
  { iterations    = 0
  , lossFn        = const (error "No loss function implemented")
  , constraintFn  = id
  , gainA         = []
  , gainC         = []
  , stoppingCrits = []
  , perturbation  = []
  }
-- | Validate that the SPSA configuration is complete; the state is unchanged.
--
-- The loss function is evaluated once on @t@. With the placeholder from
-- 'defaultSPSA', this raises "No loss function implemented". The gain
-- sequences a_k and c_k, the perturbation sequence and the list of stopping
-- criteria must each be non-empty. Otherwise 'error' is raised with a
-- message naming the missing piece.
--
-- NOTE(review): 'State' from "Control.Monad.State" is the lazy state monad,
-- so these errors presumably surface only once the resulting state (or a
-- later read of it) is demanded. Confirm against callers.
checkSPSA :: Vector Double -> StateSPSA ()
checkSPSA t = do
  spsa <- get
  -- Force the loss to WHNF. A bare 'return' would discard the value without
  -- ever evaluating it, so an unset loss function would go undetected.
  lossFn spsa t `seq` return ()
  require (gainA spsa) "gain sequence a_k must be specified"
  require (gainC spsa) "gain sequence c_k must be specified"
  require (perturbation spsa) "perturbation vector sequence must be specified"
  require (stoppingCrits spsa) "a stopping criteria must be specified"
  where
    -- Fail with @msg@ when the list is empty. 'null' inspects only the first
    -- constructor and needs no 'Eq' instance on the elements.
    require :: [a] -> String -> StateSPSA ()
    require xs msg = if null xs then error msg else return ()
-- | A condition tested by 'shouldStop'.
data StoppingCriteria
  = Iterations Int   -- ^ met once the iteration count reaches this bound
  | NormDiff Double  -- ^ met once 'norm2' of the step between the two
                     --   parameter vectors given to 'shouldStop' is below
                     --   this threshold
  deriving (Eq)
-- | Decide whether a single stopping criterion is met.
--
-- The arguments are the criterion, the current iteration count, then @lst@
-- and @cur@ (presumably the previous and current parameter vectors; confirm
-- with the caller).
--
-- * @'Iterations' n@ holds once the iteration count is at least @n@.
-- * @'NormDiff' d@ holds once the 2-norm of @cur - lst@ is below @d@.
shouldStop :: StoppingCriteria -> Int -> Vector Double -> Vector Double -> Bool
shouldStop (Iterations n) i _ _ = i >= n
-- Element-wise difference via hmatrix's 'Num' instance for 'Vector'. Writing
-- @cur lst@ would apply a vector as a function and does not type-check.
shouldStop (NormDiff diff) _ lst cur = norm2 (cur - lst) < diff
-- | State monad over 'SPSA' in which every accessor of this module runs.
-- The synonym is left eta-reduced so it may also be used unsaturated.
type StateSPSA = State SPSA
-- | Read one component of the state through a projection.
getSPSA :: (SPSA -> a) -> StateSPSA a
getSPSA extractor = fmap extractor get
-- | Replace the state with @updater val@ applied to the current state.
setSPSA :: (a -> SPSA -> SPSA) -> a -> StateSPSA ()
setSPSA updater val = do
  current <- get
  put (updater val current)
-- | Fetch the installed loss function.
getLoss :: StateSPSA LossFn
getLoss = fmap lossFn get
-- | Install a new loss function, replacing the previous one.
setLoss :: LossFn -> StateSPSA ()
setLoss = setSPSA withLoss
  where
    withLoss newLoss spsa = spsa { lossFn = newLoss }
-- | Fetch the installed constraint map.
getConstraint :: StateSPSA ConstraintFn
getConstraint = do
  spsa <- get
  return (constraintFn spsa)
-- | Install a new constraint map, replacing the previous one.
setConstraint :: ConstraintFn -> StateSPSA ()
setConstraint = setSPSA withConstraint
  where
    withConstraint newConstraint spsa = spsa { constraintFn = newConstraint }
-- | Fetch the stopping criteria, most recently pushed first.
getStop :: StateSPSA [StoppingCriteria]
getStop = fmap stoppingCrits get
-- | Add a stopping criterion to the front of the current list.
pushStopCrit :: StoppingCriteria -> StateSPSA ()
pushStopCrit = setSPSA prepend
  where
    prepend crit spsa = spsa { stoppingCrits = crit : stoppingCrits spsa }
-- | Fetch the number of iterations recorded so far via 'incrementIteration'.
getIterations :: StateSPSA Int
getIterations = getSPSA iterations
-- | Bump the iteration counter by one and return the new count.
incrementIteration :: StateSPSA Int
incrementIteration =
  get >>= \spsa ->
    let next = iterations spsa + 1
     in put spsa { iterations = next } >> return next
-- | Pop the head of one of the state's sequences.
--
-- @getter@ selects the sequence and @updater@ writes the remaining tail back
-- into the state. The tail is always stored; it is @[]@ when the sequence
-- was already empty. Only the returned element fails in that case, with a
-- descriptive 'error'. 'checkSPSA' reports empty sequences up front.
peel :: (SPSA -> [a]) -> ([a] -> SPSA -> SPSA) -> StateSPSA a
peel getter updater = do
  spsa <- get
  let sq = getter spsa
  -- 'drop 1' is exactly the second half of 'splitAt 1' and is total.
  put (updater (drop 1 sq) spsa)
  -- Matched lazily inside 'return', so an exhausted sequence errors only
  -- when the peeled value is used. This replaces an incomplete irrefutable
  -- pattern that failed with an opaque message.
  return $ case sq of
    (nxt : _) -> nxt
    []        -> error "Math.Optimization.SPSA.Types.peel: sequence exhausted"
-- | Take the next a_k from the gain sequence.
peelA :: StateSPSA Double
peelA = peel gainA replaceA
  where
    replaceA rest spsa = spsa { gainA = rest }
-- | Take the next c_k from the gain sequence.
peelC :: StateSPSA Double
peelC = peel gainC replaceC
  where
    replaceC rest spsa = spsa { gainC = rest }
-- | Take the next perturbation vector.
peelD :: StateSPSA (Vector Double)
peelD = peel perturbation replaceD
  where
    replaceD rest spsa = spsa { perturbation = rest }
-- | Take the next (a_k, c_k, perturbation) triple, consuming one element
-- from each of the three sequences in that order.
peelAll :: StateSPSA (Double, Double, Vector Double)
peelAll =
  peelA >>= \a ->
  peelC >>= \c ->
  peelD >>= \d ->
  return (a, c, d)
-- | Install the gain sequence a_k, replacing whatever remained.
setGainA :: [Double] -> StateSPSA ()
setGainA = setSPSA withGainA
  where
    withGainA seqA spsa = spsa { gainA = seqA }
-- | Install the gain sequence c_k, replacing whatever remained.
setGainC :: [Double] -> StateSPSA ()
setGainC = setSPSA withGainC
  where
    withGainC seqC spsa = spsa { gainC = seqC }
-- | Install the perturbation vector sequence, replacing whatever remained.
setPerturbation :: [Vector Double] -> StateSPSA ()
setPerturbation = setSPSA withPerturbation
  where
    withPerturbation deltas spsa = spsa { perturbation = deltas }