module Numeric.MCMC.AffineInvariantEnsemble (
Config(..), AffineTransform(..), Trace
, runChain, initializeEnsemble, defaultSeed, prune
) where
import Numeric.MCMC.Util
import Data.List (foldl')
import Data.List.Split (splitEvery)
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Vector.Unboxed as U
import Data.Vector.Generic (Vector)
import System.Random.MWC
import System.Random.MWC.Distributions (standard)
import Control.Monad.ST (ST, runST)
import Data.STRef (STRef, newSTRef, readSTRef, writeSTRef)
import Control.Monad (forM_, replicateM, when)
import Data.Maybe (fromJust)
import Control.Monad.Primitive (PrimMonad)
import Data.Word (Word32)
-- | Prefix for every error message raised by this module: the fully
-- qualified module name followed by a dot.  Callers append
-- @"functionName: message"@ to it.
libError :: String
libError = moduleName ++ "."
  where
    moduleName = "Numeric.MCMC.AffineInvariantEnsemble"
-- | State of an ensemble chain.
data Config = Config
  { ensemble :: !(IntMap [Double])
    -- ^ position of every walker, keyed by walker index
  , accepts  :: !Int
    -- ^ number of proposals accepted so far
  }
-- | Which affine-invariant proposal 'runChain' uses to move each walker.
data AffineTransform a
  = Stretch  -- ^ stretch move: each walker is paired with one other walker
  | Walk a   -- ^ walk move: uses a sub-ensemble whose size is the argument
  deriving (Read, Eq)
-- | A chain trace: a list of rows, each row being the coordinates of one
-- recorded point.  'runChain' fills it with one row per walker per step.
newtype Trace a = Trace [[a]]

-- | Renders one row per line, with the values of a row separated by single
-- spaces.  Works for any 'Show' element type; the instance head needs no
-- FlexibleInstances, unlike a 'Trace Double'-only instance.
instance Show a => Show (Trace a) where
  show (Trace rows) = unlines (map (unwords . map show) rows)
-- | Goodman–Weare stretch move for a single walker.
--
-- Given the walker @xk@, another walker @xj@ and a stretch factor @z@, the
-- proposal is @xj + z * (xk - xj)@, computed as @z * xk + (1 - z) * xj@.
-- The second component of the result is the log acceptance probability
--
-- > min 0 (target proposal - target xk + (nd - 1) * log z)
--
-- where @target@ is the log target density.  In Goodman & Weare (2010) the
-- exponent of @z@ comes from the dimension of the parameter space.  Callers
-- should therefore pass the number of coordinates of a walker as @nd@, not
-- the number of walkers.
stretch :: [Double]              -- ^ @xk@: walker being moved
        -> [Double]              -- ^ @xj@: the other walker
        -> Int                   -- ^ @nd@: dimension of the parameter space
        -> Double                -- ^ @z@: stretch factor
        -> ([Double] -> Double)  -- ^ log target density
        -> ([Double], Double)    -- ^ proposal and log acceptance probability
stretch xk xj nd z target = (proposal, min 0 logRatio)
  where
    proposal = zipWith (\a b -> z * a + (1 - z) * b) xk xj
    logRatio = target proposal - target xk + (fromIntegral nd - 1) * log z
-- | Goodman–Weare walk move for a single walker.
--
-- The proposal is @xk + sum_j z_j * (x_j - m)@.  Here the @x_j@ are the
-- walkers of the sub-ensemble @xjs@, @m@ is their coordinate-wise mean and
-- the @z_j@ are the weights @zs@ ('moveEnsemble' draws these as standard
-- normals).  The second component is the log acceptance probability
-- @min 0 (target proposal - target xk)@, where @target@ is the log density.
walk :: (Fractional c, Num t, Ord t)
     => [c]          -- ^ @xk@: walker being moved
     -> [[c]]        -- ^ @xjs@: sub-ensemble of other walkers
     -> [c]          -- ^ @zs@: one weight per sub-ensemble walker
     -> ([c] -> t)   -- ^ log target density
     -> ([c], t)     -- ^ proposal and log acceptance probability
walk xk xjs zs target = (proposal, min 0 (target proposal - target xk))
  where
    -- Running vector sums start from one zero per coordinate of @xk@.
    -- Seeding them with one zero per sub-ensemble walker instead would make
    -- 'zipWith' silently truncate the proposal whenever the sub-ensemble is
    -- smaller than the dimension.
    vectorSum = foldl' (zipWith (+)) (replicate (length xk) 0)
    centre = map (/ fromIntegral (length xjs)) (vectorSum xjs)
    centred = map (\xj -> zipWith (-) xj centre) xjs
    step = vectorSum (zipWith (\z -> map (* z)) zs centred)
    proposal = zipWith (+) xk step
-- | Create an initial ensemble of @nw@ walkers in @nd@ dimensions.  Every
-- coordinate is drawn uniformly from [0, 1], walkers are keyed @1 .. nw@,
-- and the acceptance counter starts at zero.
--
-- The draws come from mwc-random's 'create', which always starts from the
-- same fixed seed, so a given @(nw, nd)@ always yields the same ensemble.
--
-- Calls 'error' unless @nw >= 2@, @nd >= 1@ and @nw > nd@.  The last
-- condition matters because the stretch and walk moves only add multiples
-- of differences between walkers.  The ensemble therefore never leaves the
-- affine hull of its starting points, and @nd@ or fewer points span at most
-- @nd - 1@ dimensions.
initializeEnsemble :: PrimMonad m => Int -> Int -> m Config
initializeEnsemble nw nd
  | nw < 2 = error $ libError ++ "initializeEnsemble: Number of walkers must be >= 2."
  | nd < 1 = error $ libError ++ "initializeEnsemble: Number of dimensions must be >= 1."
  -- Strict comparison: nw == nd would give a degenerate ensemble (see above).
  | nw <= nd = error $ libError ++ "initializeEnsemble: Number of walkers should be greater than number of dimensions."
  | otherwise = do
      gen <- create
      inits <- replicateM (nw * nd) (uniformR (0 :: Double, 1) gen)
      return Config { ensemble = IntMap.fromList (zip [1 ..] (splitEvery nd inits))
                    , accepts  = 0
                    }
-- | One sweep of the ensemble.  Each walker, in key order @1 .. numWalkers@,
-- gets a proposal from the chosen 'AffineTransform'.  The proposal replaces
-- the walker, and the acceptance counter is incremented, when a uniform
-- draw @zc@ satisfies @zc <= exp logAcceptanceProb@.  Returns the ensemble
-- positions after the sweep.
--
-- Calls 'error' if a @Walk n@ sub-ensemble is not strictly smaller than the
-- ensemble.  It also calls 'error' if a key in @1 .. numWalkers@ is missing
-- from the ensemble; keys are assumed to be those produced by
-- 'initializeEnsemble'.
moveEnsemble :: Int                   -- ^ number of walkers
             -> STRef s Config        -- ^ mutable chain state
             -> Gen s                 -- ^ generator; every variate is drawn from it
             -> ([Double] -> Double)  -- ^ log target density
             -> AffineTransform Int   -- ^ proposal type
             -> ST s (IntMap [Double])
moveEnsemble numWalkers _ _ _ (Walk n)
  | n >= numWalkers = error "Numeric.MCMC.AffineInvariantEnsemble moveEnsemble: size of `Walk` sub-ensemble must be strictly less than size of full ensemble."
moveEnsemble numWalkers stConfig gen target xform = do
  forM_ [1 .. numWalkers] $ \targetWalkerIndex -> do
    config <- readSTRef stConfig
    let walkers = ensemble config
        nacc = accepts config
        targetWalker = lookupWalker targetWalkerIndex walkers
    zc <- uniformR (0, 1) gen
    -- All variates below are drawn from 'gen' itself, so each walker's draws
    -- advance the shared stream.  Drawing from a restored copy of 'gen'
    -- without advancing it would let the next walker replay the same numbers.
    (proposal, logAcceptanceProb) <- case xform of
      Stretch -> do
        altWalkerIndex <- genDiffInt targetWalkerIndex (1, numWalkers) gen
        z0 <- uniformR (0, 1) gen
        let z = 0.5 * (z0 + 1) * (z0 + 1)
            altWalker = lookupWalker altWalkerIndex walkers
        -- The z^(n - 1) factor of the stretch move uses the dimension n of
        -- the parameter space, not the number of walkers.
        return $ stretch targetWalker altWalker (length targetWalker) z target
      Walk n -> do
        zs <- replicateM n (standard gen)
        subSeed <- freshSeed gen
        -- The walk move draws its sub-ensemble from the complementary
        -- ensemble, i.e. never including the walker being moved.
        let others = IntMap.keys (IntMap.delete targetWalkerIndex walkers)
            altWalkerEnsemble = map (`lookupWalker` walkers) (sample n others subSeed)
        return $ walk targetWalker altWalkerEnsemble zs target
    when (zc <= exp logAcceptanceProb) $
      writeSTRef stConfig Config { ensemble = IntMap.insert targetWalkerIndex proposal walkers
                                 , accepts  = nacc + 1
                                 }
  ensemble <$> readSTRef stConfig
  where
    -- Look up a walker, failing with a descriptive message if it is absent.
    lookupWalker :: Int -> IntMap [Double] -> [Double]
    lookupWalker key walkerMap = case IntMap.lookup key walkerMap of
      Just walker -> walker
      Nothing -> error $ libError ++ "moveEnsemble: no walker with key " ++ show key ++ "."
    -- Seed a new generator from 256 words drawn from 'g' (advancing 'g') and
    -- return its state, for helpers that consume a 'Seed' rather than a 'Gen'.
    freshSeed :: Gen s -> ST s Seed
    freshSeed g = do
      ws <- U.replicateM 256 (uniform g)
      save =<< initialize (ws :: U.Vector Word32)
-- | Run the sampler for @steps@ sweeps of the ensemble.  The chain starts
-- from @initConfig@, and the generator is initialised from @seed@.
--
-- Returns the final configuration together with a trace.  For every sweep
-- in order, the trace holds the position of every walker in ascending key
-- order.  Calls 'error' when @steps < 1@.
runChain :: Vector v Word32
         => Int                     -- ^ number of sweeps
         -> ([Double] -> Double)    -- ^ log target density
         -> Config                  -- ^ starting configuration
         -> v Word32                -- ^ generator seed
         -> AffineTransform Int     -- ^ proposal type
         -> (Config, Trace Double)
runChain steps target initConfig seed xform
  | steps >= 1 = runST $ do
      gen <- initialize seed
      stateRef <- newSTRef initConfig
      sweeps <- replicateM steps $
        moveEnsemble walkerCount stateRef gen target xform
      finalConfig <- readSTRef stateRef
      return ( finalConfig
             , Trace [position | sweep <- sweeps, position <- IntMap.elems sweep]
             )
  | otherwise = error $ libError ++ "runChain: `steps` must be >= 1."
  where
    walkerCount = IntMap.size (ensemble initConfig)
-- | Drop the first @n@ rows of a trace.  In a trace built by 'runChain' each
-- row is one walker position, so a full sweep of the ensemble occupies as
-- many rows as there are walkers.
prune :: Int -> Trace Double -> Trace Double
prune n trace = case trace of
  Trace rows -> Trace $ drop n rows
-- | A fixed one-word seed (42) for 'runChain'.  Using it makes runs
-- reproducible.
defaultSeed :: U.Vector Word32
defaultSeed = U.fromList [42]