{-# LANGUAGE RecordWildCards #-}
module HABQTlib.Data.Particle
( Particles(..)
, genParticles
, updateParticles
, ParticleHierarchy
, initialiseParticleHierarchy
, updateParticleHierarchy
, getMixedEstimate
, foldOverPts
, reduceParticlesToMean
, effectiveSize
, ResampleArgs(..)
, resampleMultinom
, resample
, ecdf
, icdf
, nudgeParticle
) where
import Control.Monad (when)
import Control.Newtype.Generics (over)
import Data.Bool.HT (if', select)
import qualified Data.Vector as V
import HABQTlib.Data
import HABQTlib.RandomStates
import Numeric.LinearAlgebra (Complex(..))
import qualified Numeric.LinearAlgebra as LA
import qualified System.Random.MWC as MWC
import Text.Printf (printf)
-- | A set of weighted density-matrix particles that all share one rank.
data Particles = Particles
  { ptsRank :: Rank
    -- ^ Rank shared by every particle in the set (the value handed to
    -- 'genDM' on creation and to 'truncateRank' when reducing/moving).
  , ptsWeight :: Weight
    -- ^ Weight of the set as a whole. 'genParticles' sets it to 1,
    -- 'updateParticles' multiplies it by the total updated particle weight,
    -- and 'updateParticleHierarchy' divides it by the sum over all sets.
  , ptsNumber :: NumberOfParticles
    -- ^ Number of particles; presumably equal to the length of
    -- 'ptsParticles' (NOTE(review): not enforced by the type).
  , ptsParticles :: V.Vector WeighedDensityMatrix
    -- ^ The individual weighted particles.
  } deriving (Show)
-- | One 'Particles' set per rank; 'initialiseParticleHierarchy' stores the
-- sets for ranks 1 .. d at indices 0 .. d-1.
type ParticleHierarchy = V.Vector Particles
-- | Draw @n@ density matrices with @'genDM' d r@ and package them as a
-- rank-@r@ particle set. Each particle receives weight @1 / n@ (through
-- 'mkWDM') and the set as a whole receives weight 1.
genParticles :: Dim -> Rank -> NumberOfParticles -> IO Particles
genParticles d r n = do
  samples <- V.replicateM n (genDM d r)
  let each = 1 / fromIntegral n
  return $
    Particles
      { ptsRank = r
      , ptsWeight = 1
      , ptsNumber = n
      , ptsParticles = V.map (mkWDM each) samples
      }
-- | Collapse a particle set into one weighted density matrix.
--
-- The particles are combined left to right with '<+>'. The weight of that
-- combination is multiplied by the set weight 'ptsWeight', and the result is
-- passed to 'truncateRank' with the set's rank.
--
-- Partial: 'V.foldl1'' fails on an empty 'ptsParticles'.
reduceParticlesToMean :: Particles -> WeighedDensityMatrix
reduceParticlesToMean Particles {..} =
  truncateRank ptsRank (over WeighedDensityMatrix rescale summed)
  where
    summed = V.foldl1' (<+>) ptsParticles
    rescale (w, dm) = (w * ptsWeight, dm)
-- | Strict left fold over the particles of a set, in particle order.
--
-- Each particle @(w, dm)@ contributes @weigh w (measure dm)@. The
-- contributions are folded into the accumulator with @combine@, starting
-- from @seed@.
foldOverPts ::
     (DensityMatrix -> a)
  -> (Weight -> a -> b)
  -> (c -> b -> c)
  -> c
  -> Particles
  -> c
foldOverPts measure weigh combine seed Particles {ptsParticles = wdms} =
  V.foldl' step seed wdms
  where
    step acc wdm = combine acc (contribution wdm)
    -- Kept separate so a particle is only destructured when 'combine'
    -- actually demands its contribution.
    contribution (WeighedDensityMatrix (w, dm)) = weigh w (measure dm)
-- | Sum of @log (pureStateLikelihood v dm)@ over every vector @v@ in the list.
-- The result is -Infinity as soon as any single likelihood is zero.
fullDataLogLikelihood :: [PureStateVector] -> DensityMatrix -> Double
fullDataLogLikelihood vs dm = sum [log (pureStateLikelihood v dm) | v <- vs]
-- | Reweight a particle set after observing the pure state @sv@.
--
-- Each particle weight is multiplied by @pureStateLikelihood sv dm@. The new
-- weights are then divided by their total, and that total is multiplied into
-- the set weight 'ptsWeight'.
--
-- NOTE(review): if every updated weight is zero the division is by zero
-- (non-finite weights, assuming 'Weight' is floating point).
updateParticles :: PureStateVector -> Particles -> Particles
updateParticles sv pts =
  pts
    { ptsWeight = ptsWeight pts * total
    , ptsParticles = V.map normalise reweighted
    }
  where
    reweight (WeighedDensityMatrix (w, dm)) =
      WeighedDensityMatrix (w * pureStateLikelihood sv dm, dm)
    reweighted = V.map reweight (ptsParticles pts)
    total = V.sum (V.map (fst . getWDM) reweighted)
    normalise = over WeighedDensityMatrix (\(w, dm) -> (w / total, dm))
-- | Build one particle set per rank 1 .. @d@, in increasing rank order. Each
-- set holds @n@ particles generated by 'genParticles' with dimension @d@.
initialiseParticleHierarchy :: Dim -> NumberOfParticles -> IO ParticleHierarchy
initialiseParticleHierarchy d n =
  V.forM (V.enumFromN 1 d) $ \rank -> genParticles d rank n
-- | Update every rank's particle set with 'updateParticles' @sv@. Each set
-- weight is then divided by the sum of all set weights in the hierarchy.
updateParticleHierarchy ::
     PureStateVector -> ParticleHierarchy -> ParticleHierarchy
updateParticleHierarchy sv ph = V.map rescale updated
  where
    updated = V.map (updateParticles sv) ph
    total = V.sum (V.map ptsWeight updated)
    rescale p = p {ptsWeight = ptsWeight p / total}
-- | Combine a whole hierarchy into a single density matrix.
--
-- Each set is reduced with 'reduceParticlesToMean', and the reductions are
-- folded left to right with '<+>'. The weight of the final combination is
-- discarded.
--
-- Partial: 'V.foldl1'' fails on an empty hierarchy.
getMixedEstimate :: ParticleHierarchy -> DensityMatrix
getMixedEstimate =
  snd . getWDM . V.foldl1' (<+>) . V.map reduceParticlesToMean
-- | Effective sample size of a particle set: the square of the summed
-- particle weights divided by the sum of the squared weights.
effectiveSize :: Particles -> Double
effectiveSize Particles {ptsParticles = wdms} =
  square (V.sum weights) / V.sum (V.map square weights)
  where
    weights = V.map (fst . getWDM) wdms
    square x = x * x
-- | Blend a particle's density matrix with a freshly drawn random pure state.
--
-- A pure state of dimension @dim@ comes from 'genPureSV' and is converted
-- with 'svToDM'. The new matrix is
-- @(1 - weightFraction) * dm + weightFraction * nudge@. The particle's weight
-- is carried over unchanged.
nudgeParticle ::
     Dim
  -> Weight
  -> WeighedDensityMatrix
  -> IO WeighedDensityMatrix
nudgeParticle dim weightFraction (WeighedDensityMatrix (w, dm)) = do
  DensityMatrix randomPart <- svToDM <$> genPureSV dim
  let frac = weightFraction :+ 0
      kept = LA.scale (1 - frac) (getDensityMatrix dm)
      added = LA.scale frac randomPart
  pure . WeighedDensityMatrix $ (w, DensityMatrix (kept + added))
-- | Running sums of the particle weights in particle order. Entry @i@ is the
-- sum of the first @i + 1@ weights.
ecdf :: V.Vector WeighedDensityMatrix -> V.Vector Double
ecdf wdms = V.postscanl' (+) 0 weights
  where
    weights = V.map (fst . getWDM) wdms
-- | Invert a cumulative table such as the one built by 'ecdf'. Returns the
-- smallest index @i@ with @x <= cdf ! i@, found by bisection.
--
-- Values at or below the first entry map to 0. Values above the last entry
-- are clamped to the last index. Assumes @cdf@ is non-decreasing.
--
-- Partial: 'V.head' fails on an empty vector.
icdf :: V.Vector Double -> Double -> Int
icdf cdf x
  | x <= V.head cdf = 0
  | x > V.last cdf = lastIdx
  | otherwise = bisect 0 lastIdx
  where
    lastIdx = V.length cdf - 1
    -- For a non-decreasing cdf: cdf ! lo < x <= cdf ! hi throughout.
    bisect lo hi
      | lo == hi = lo
      | lo + 1 == hi = if cdf V.! lo > x then lo else hi
      | cdf V.! mid < x = bisect mid hi
      | otherwise = bisect lo mid
      where
        mid = (lo + hi) `quot` 2
-- | Multinomial resampling.
--
-- Draws 'ptsNumber' uniform variates from @gen@ and maps each through 'icdf'
-- of the particles' 'ecdf' to pick a particle. Every pick receives weight
-- @1 / ptsNumber@; all other fields of the set are unchanged.
--
-- NOTE(review): the draws are not rescaled by the total weight, so this
-- presumably expects particle weights summing to 1, as 'updateParticles'
-- produces.
resampleMultinom :: MWC.GenIO -> Particles -> IO Particles
resampleMultinom gen pts@Particles {ptsNumber = n, ptsParticles = src} = do
  draws <- MWC.uniformVector gen n
  let cdf = ecdf src
      flat = 1 / fromIntegral n
      pick u =
        over WeighedDensityMatrix (\(_, dm) -> (flat, dm)) (src V.! icdf cdf u)
  return pts {ptsParticles = V.map pick draws}
-- | One Metropolis–Hastings sweep over a particle set.
--
-- Every particle is nudged with 'nudgeParticle' (mixing weight @rw@) and
-- passed through 'truncateRank'. The proposal replaces the current particle
-- when a uniform draw from @gen@ is @<=@ @exp (LL proposal - LL current)@,
-- where @LL@ is 'fullDataLogLikelihood' over the state list @ms@.
--
-- Returns the number of accepted proposals divided by 'ptsNumber', together
-- with the updated set.
mhmcStep ::
     MWC.GenIO
  -> Dim
  -> Double
  -> [PureStateVector]
  -> Particles
  -> IO (Double, Particles)
mhmcStep gen dim rw ms pts@Particles {ptsRank = rank, ptsNumber = n, ptsParticles = current} = do
  -- All proposals are generated before any acceptance draw is taken.
  proposals <- V.forM current $ \p -> truncateRank rank <$> nudgeParticle dim rw p
  draws <- V.replicateM n (MWC.uniform gen :: IO Double)
  let ratios = V.zipWith acceptance current proposals
      accepted = V.zipWith (<=) draws ratios
      kept =
        V.zipWith3 (\ok new old -> if ok then new else old) accepted proposals current
      nAccepted = V.length (V.filter id accepted)
  return (fromIntegral nAccepted / fromIntegral n, pts {ptsParticles = kept})
  where
    logLik = fullDataLogLikelihood ms . snd . getWDM
    acceptance old new = exp (logLik new - logLik old)
-- | Repeat 'mhmcStep' on a particle set, adapting the nudge weight @wr@ from
-- the acceptance rate of each step.
--
-- @iter@ counts consecutive steps with an acceptable acceptance rate. After
-- each step, the counter and weight are updated by the first matching rule:
--
-- * acceptance rate < 0.01: reset the counter to 0, quarter @wr@;
-- * acceptance rate < 0.1: reset the counter to 0, halve @wr@;
-- * counter below 'argMinIter': increment the counter;
-- * acceptance rate < 0.33: reset the counter to 0, halve @wr@;
-- * otherwise: increment the counter.
--
-- The loop returns the latest particles once the updated counter exceeds
-- 'argMinIter'. For a non-negative 'argMinIter', that is exactly right after
-- a step that fell through to the last rule.
--
-- With 'FullOutput', the current @wr@ and acceptance rate are printed after
-- every step. The @estimate@ argument is not used.
resampleMHMC ::
     ResampleArgs
  -> DensityMatrix
  -> Double
  -> Int
  -> [PureStateVector]
  -> Particles
  -> IO Particles
resampleMHMC ra@ResampleArgs {..} estimate wr iter mts pts = do
  (accRate, resampled) <- mhmcStep argGen argDim wr mts pts
  when (argOut == FullOutput) $
    printf
      "(Weight of new particle: %8.3g, MHMC acceptance rate: %8.3g)\n"
      wr
      accRate
  let (iter', wr') =
        select
          (iter + 1, wr)
          [ (accRate < 1e-2, (0, wr * 0.25))
          , (accRate < 1e-1, (0, wr * 0.5))
          , (iter < argMinIter, (iter + 1, wr))
          , (accRate < 0.33, (0, wr * 0.5))
          ]
  -- Test the updated counter. Testing the old one ran one extra step whose
  -- acceptance rate and adaptation were computed but never acted upon.
  if iter' > argMinIter
    then return resampled
    else resampleMHMC ra estimate wr' iter' mts resampled
-- | Settings shared by the resampling routines ('resample', 'resampleMHMC').
data ResampleArgs = ResampleArgs
  { argOut :: OutputVerb
    -- ^ Verbosity: progress messages are printed only when this equals
    -- 'FullOutput'.
  , argGen :: MWC.GenIO
    -- ^ Generator used for the multinomial draws and for the
    -- Metropolis–Hastings acceptance draws.
  , argDim :: Dim
    -- ^ Dimension passed to 'nudgeParticle' (and from there to 'genPureSV').
  , argMinIter :: MHMCiter
    -- ^ Threshold on the consecutive-acceptable-steps counter that decides
    -- when 'resampleMHMC' stops.
  }
-- | Resample a particle set in two stages.
--
-- First a multinomial resample is drawn with 'resampleMultinom'. It is then
-- refined with 'resampleMHMC', starting from a nudge weight of 0.95 and an
-- iteration counter of 0. With 'FullOutput', the rank being resampled is
-- announced first.
resample :: ResampleArgs -> [PureStateVector] -> Particles -> IO Particles
resample ra@ResampleArgs {..} mts pts@Particles {..} = do
  when (argOut == FullOutput) $
    mapM_ putStrLn ["", "resampling rank " ++ show ptsRank, ""]
  multinomial <- resampleMultinom argGen pts
  resampleMHMC ra mixed initialNudge 0 mts multinomial
  where
    mixed = getMixedEstimate (V.singleton pts)
    initialNudge = 0.95