{-# LANGUAGE FlexibleInstances, FlexibleContexts #-}

-- | A Haskell implementation of Goodman & Weare (2010)'s /affine invariant ensemble MCMC/, a family of Markov
--   Chain Monte Carlo methods that can efficiently sample from highly skewed or anisotropic distributions. 
--
--   See 'runChain' for an overview of use, and <http://msp.berkeley.edu/camcos/2010/5-1/p04.xhtml> for details 
--   of the general sampling routine.
module Numeric.MCMC.AffineInvariantEnsemble ( 
           -- * Data structures
             Config(..), AffineTransform(..), Trace
           -- * Chain management
           , runChain, initializeEnsemble, defaultSeed, prune
           ) where

import Numeric.MCMC.Util 
import Data.List                        (foldl')
import Data.List.Split                  (splitEvery) 
import Data.IntMap.Strict               (IntMap)
import qualified Data.IntMap.Strict  as  IntMap
import qualified Data.Vector.Unboxed as  U
import Data.Vector.Generic              (Vector)
import System.Random.MWC 
import System.Random.MWC.Distributions  (standard)
import Control.Monad.ST                 (ST, runST)
import Data.STRef                       (STRef, newSTRef, readSTRef, writeSTRef)
import Control.Monad                    (forM_, replicateM, when)
import Data.Maybe                       (fromJust)
import Control.Monad.Primitive          (PrimMonad)
import Data.Word                        (Word32)

-- Module-name prefix that the error messages raised in this module are built from, so that a
-- failure can be traced back to the function that raised it.
libError :: String
libError = "Numeric.MCMC.AffineInvariantEnsemble."

-- | A data type holding the configuration of the Markov chain at any given epoch.  'ensemble' accesses
--   the 'IntMap' constituting the current ensemble (one list of coordinates per walker), while 'accepts' 
--   records the number of proposals that have been accepted up to the current epoch.
--
--   The chain routines in this module look walkers up under the keys @1..n@, which is also how
--   'initializeEnsemble' numbers them; a hand-built 'Config' should follow the same convention.
data Config   = Config { ensemble    :: !(IntMap [Double]) 
                       , accepts     :: {-# UNPACK #-} !Int
                       } 

-- | A data type representing the affine transformation to be used on particles in an ensemble.  The general-purpose
--   /stretch/ and /walk/ transformations described in Goodman and Weare (2010) are supported.
--
--   'Stretch' takes no parameter.  'Walk' carries one: when the transform is handed to 'runChain' (as an
--   @AffineTransform Int@) it is the number of walkers drawn into the sub-ensemble used to build each proposal,
--   and it must be strictly less than the size of the full ensemble.
data AffineTransform a = Stretch | Walk a deriving (Eq, Read)
 
-- | A data type holding a chain's trace: a list of rows, each row being one list of values.
newtype Trace a = Trace [[a]]

-- Renders the trace as text: one row per line, the values of a row separated by single spaces.
instance Show (Trace Double) where
    show (Trace rows) = unlines [unwords (map show row) | row <- rows]

-- The `stretch` affine transform.  The focused walker is moved along the line through itself and the
-- alternate walker: @proposal = z * xk + (1 - z) * xj@.
--
-- The log acceptance probability is @min 0 ((d - 1) * log z + target proposal - target xk)@, where /d/ is
-- the dimension of the walkers.  This is Goodman & Weare's Z^(n-1) factor, in which n is the dimension of
-- the sample space and not the size of the ensemble; the walker count is therefore accepted (so that call
-- sites keep compiling) but does not enter the formula.  The target is used on the log scale.
stretch :: [Double]             -- ^ Focused walker
        -> [Double]             -- ^ Alternate walker
        -> Int                  -- ^ Number of walkers in ensemble (ignored, see above)
        -> Double               -- ^ Random double drawn from appropriate distribution
        -> ([Double] -> Double) -- ^ Target function
        -> ([Double], Double)   -- ^ Tuple containing proposed move and its log acceptance prob
stretch xk xj _nw z target = (proposal, logAP)
    where proposal = zipWith (+) (map (* z) xk) (map (* (1 - z)) xj)
          -- Dimension of the sample space, taken from the walker being moved.
          dim      = length xk
          -- Log of min(1, z^(dim-1) * ratio of target densities).
          logAP    = min 0 (target proposal - target xk + fromIntegral (dim - 1) * log z)

-- The `walk` affine transform.  The proposal is the focused walker plus a random linear combination of the
-- sub-ensemble's deviations from its own mean: @proposal = xk + sum_j z_j * (xj - mean)@.
--
-- The log acceptance probability is @min 0 (target proposal - target xk)@; the target is used on the log scale.
--
-- All sums run over vectors of the walker's dimension (@length xk@), so the proposal always has as many
-- coordinates as the focused walker, whatever the size of the sub-ensemble.
walk :: (Fractional c, Num t, Ord t) 
     => [c]                     -- ^ Focused walker
     -> [[c]]                   -- ^ Sub-ensemble of n alternate walkers
     -> [c]                     -- ^ n random doubles drawn from a standard normal
     -> ([c] -> t)              -- ^ Target function
     -> ([c], t)                -- ^ Tuple containing proposed move and its log acceptance prob
walk xk xjs zs target = let val = target proposal - target xk in (proposal, if val > 0 then 0 else val)
    where dim          = length xk
          nxjs         = length xjs
          -- Coordinate-wise sum of a list of vectors.  The accumulator is sized by the dimension: `zipWith`
          -- truncates to the shorter argument, so any other size would silently drop coordinates.
          sumRows rows = foldl' (zipWith (+)) (replicate dim 0) rows
          xjsmean      = map (/ fromIntegral nxjs) (sumRows xjs)
          xjscentd     = map (\xj -> zipWith (-) xj xjsmean) xjs
          displacement = sumRows (zipWith (\z -> map (* z)) zs xjscentd)
          proposal     = zipWith (+) xk displacement

-- | Naively initialize an ensemble.  Creates a 'Config' containing /nw/ walkers, each of dimension /nd/,
--   stored under the keys @1..nw@, and initializes 'accepts' at 0.  Each coordinate is drawn with
--   @uniformR (0, 1)@ from a generator obtained with mwc-random's 'create' (not from 'defaultSeed').
--
--   If the unit hypercube is expected to be a region of low density, you'll probably want to specify
--   your own initial configuration.
--
--   Calls 'error' when /nw/ < 2, when /nd/ < 1, or when /nw/ < /nd/ (an ensemble with exactly as many
--   walkers as dimensions passes the check).
initializeEnsemble :: PrimMonad m => Int -> Int -> m Config 
initializeEnsemble nw nd
    | nw < 2    = error $ libError ++ "initializeEnsemble: Number of walkers must be >= 2."
    | nd < 1    = error $ libError ++ "initializeEnsemble: Number of dimensions must be >= 1."
    | nw < nd   = error $ libError ++ "initializeEnsemble: Number of walkers should be greater than number of dimensions."
    | otherwise = do
        gen <- create
        -- One walker at a time, /nd/ coordinates each; the generator is consumed in walker-major order.
        positions <- replicateM nw (replicateM nd (uniformR (0 :: Double, 1) gen))
        return Config { ensemble = IntMap.fromList (zip [1..] positions), accepts = 0 }

-- Move an ensemble forward one step.
--
-- The walkers are visited in key order 1..numWalkers.  For each one a proposal is built with the requested
-- affine transformation from the ensemble as it stands at that moment (earlier walkers may already have
-- moved), and it replaces the walker when a uniform draw is at most the exponential of the log acceptance
-- probability.  Every acceptance increments 'accepts'.  The ensemble after the sweep is returned.
--
-- All randomness is drawn from the one generator passed in, so that it advances past every variate that
-- is consumed and no part of the stream is used twice.  The only exception is the key sampling of the
-- `Walk` move, which needs a 'Seed': that seed is built from 256 fresh words taken from the generator.
--
-- The `Walk` sub-ensemble is drawn from the walkers other than the one being moved.
--
-- Calls 'error' when the `Walk` sub-ensemble size is not strictly less than the ensemble size, or when a
-- walker is missing from the key range 1..numWalkers.
moveEnsemble :: Int                          -- Number of walkers in the ensemble.
             -> STRef s Config               -- A STRef storing the ensemble configuration.
             -> Gen s                        -- Random number generator
             -> ([Double] -> Double)         -- Desired target 
             -> AffineTransform Int          -- Affine transformation to use
             -> ST s (IntMap [Double]) 
moveEnsemble numWalkers _        _   _      (Walk n) 
    | n >= numWalkers = error "Numeric.MCMC.AffineInvariantEnsemble moveEnsemble: size of `Walk` sub-ensemble must be strictly less than size of full ensemble."
moveEnsemble numWalkers stConfig gen target xform = do
    forM_ [1..numWalkers] $ \targetWalkerIndex -> do
        -- Inits
        config <- readSTRef stConfig
        let walkers      = ensemble config
            nacc         = accepts  config
            -- Total-looking lookup with a message that names the offending key.
            walkerAt k   = case IntMap.lookup k walkers of
                Just w  -> w
                Nothing -> error $ libError ++ "moveEnsemble: no walker stored under key " ++ show k ++ "."
            targetWalker = walkerAt targetWalkerIndex

        -- Generate proposal and log acceptance probability
        (proposal, logAcceptanceProb) <- case xform of
            Stretch -> do
                altWalkerIndex <- genDiffInt targetWalkerIndex (1, numWalkers) gen
                z0             <- uniformR (0, 1) gen
                -- Maps the uniform draw onto [1/2, 2], the stretch factor's range for scale parameter 2.
                let z = 0.5 * (z0 + 1) * (z0 + 1)
                return $ stretch targetWalker (walkerAt altWalkerIndex) numWalkers z target

            Walk n -> do
                zs        <- replicateM n (standard gen)
                -- `sample` consumes a Seed rather than the live generator.  Saving the generator's own
                -- state would make later draws repeat whatever `sample` used, so seed it independently.
                seedWords <- uniformVector gen 256
                subGen    <- initialize (seedWords :: U.Vector Word32)
                subSeed   <- save subGen
                let candidates = filter (/= targetWalkerIndex) (IntMap.keys walkers)
                    subMapKeys = sample n candidates subSeed
                return $ walk targetWalker (map walkerAt subMapKeys) zs target

        -- Compare and possibly accept proposal
        zc <- uniformR (0, 1) gen
        when (zc <= exp logAcceptanceProb) $
            writeSTRef stConfig Config { ensemble = IntMap.insert targetWalkerIndex proposal walkers
                                       , accepts  = nacc + 1 }

    -- Return end state
    fmap ensemble (readSTRef stConfig)

-- | Typical use:
--
--   @
--   runChain steps target initConfig seed xform
--   @
--
--   Run the Markov chain for /steps/ epochs, starting from /initConfig/ and with a generator initialized
--   from /seed/.  The chain wanders over /target/'s parameter space such that, after \"long enough\", the
--   points it visits will effectively be independent samples from the target distribution.  In each epoch
--   the chain proceeds by possibly applying the affine transformation /xform/ to each of the particles in
--   the 'ensemble', sequentially.
--
--   This function returns a tuple containing 1) the 'Config' corresponding to the final epoch of the chain, 
--   and 2) the chain's 'Trace'.  The 'Trace' holds one row per walker per epoch: within an epoch the rows
--   follow ascending walker key, and the epochs follow one another in order.  The initial configuration is
--   not part of the trace.  The 'Trace' can be used, for example, to approximate integrals over the target.
--
--   The /target/ must be a function with type @[Double] -> Double@.  It is used on the log scale: proposals
--   are accepted by comparing a uniform draw with the exponential of a difference of /target/ values, so it
--   should return the logarithm of the (possibly unnormalized) density.  Functions using more complicated
--   data structures internally can simply be partially applied down to this type.
--
--   Calls 'error' when /steps/ < 1.
--
--   Examples of use can be found at <http://github.com/jtobin/affine-invariant-ensemble-mcmc/Numeric/MCMC/Examples>.
runChain :: Vector v Word32 => Int -> ([Double] -> Double) -> Config -> v Word32 -> AffineTransform Int -> (Config, Trace Double)
runChain steps target initConfig seed xform 
    | steps < 1 = error $ libError ++ "runChain: `steps` must be >= 1."
    | otherwise = runST $ do
        gen      <- initialize seed
        stConfig <- newSTRef initConfig
        let walkerCount = IntMap.size (ensemble initConfig)
            -- One full sweep over the ensemble; yields the walker positions after the sweep.
            advance     = moveEnsemble walkerCount stConfig gen target xform

        epochs      <- replicateM steps advance
        finalConfig <- readSTRef stConfig
        return (finalConfig, Trace (concatMap IntMap.elems epochs))

-- | Prune an initial stretch (i.e. suspected burn-in) from a 'Trace' by dropping its first /n/ rows.
--
--   A 'Trace' produced by 'runChain' holds one row per walker per epoch, so discarding /k/ whole epochs
--   of an ensemble of /w/ walkers means passing @k * w@.
prune :: Int -> Trace Double -> Trace Double
prune n (Trace rows) = Trace (snd (splitAt n rows))

-- | The default seed provided by the library: a one-element vector holding 42.  'initializeEnsemble' does
--   not use it; it obtains its generator separately.
defaultSeed :: U.Vector Word32
defaultSeed = U.fromList [42]