{-# LANGUAGE CPP #-}
module Numeric.SGD.Grad
( Grad
, empty
, add
, addL
, fromList
, fromLogList
, toList
, parUnions
) where
import Data.List (foldl')
import Control.Applicative ((<$>), (<*>))
import Control.Monad.Par (Par, runPar, get)
#if MIN_VERSION_containers(0,4,2)
import Control.Monad.Par (spawn)
#else
import Control.DeepSeq (deepseq)
import Control.Monad.Par (spawn_)
#endif
#if MIN_VERSION_containers(0,5,0)
import qualified Data.IntMap.Strict as M
#else
import qualified Data.IntMap as M
#endif
import Numeric.SGD.LogSigned
-- | Sparse gradient representation: an 'M.IntMap' from component index
-- to a 'LogSigned' value.
type Grad = M.IntMap LogSigned
{-# INLINE insertWith #-}
-- | Insert @x@ under key @k@. If the key is already present, the stored
-- value @y@ is replaced by @f x y@ (new value first, old value second).
--
-- The implementation is selected at compile time from the available
-- @containers@ version. With @containers >= 0.5@ this is
-- @Data.IntMap.Strict.insertWith@, which forces values to WHNF. The manual
-- fallback for very old versions mirrors that: it forces both a freshly
-- inserted value and a combined value, so the map does not accumulate
-- thunks.
insertWith :: (a -> a -> a) -> M.Key -> a -> M.IntMap a -> M.IntMap a
#if MIN_VERSION_containers(0,5,0)
insertWith = M.insertWith
#elif MIN_VERSION_containers(0,4,1)
insertWith = M.insertWith'
#else
insertWith f k x m =
  M.alter g k m
  where
    g my = case my of
      -- Force the inserted value too, matching the strict IntMap API.
      Nothing -> x `seq` Just x
      Just y ->
        let z = f x y
        in z `seq` Just z
#endif
{-# INLINE add #-}
-- | Add a normal-domain 'Double' contribution to component @i@ of the
-- gradient. The value is converted with 'logSigned' and then handled
-- exactly like 'addL'.
add :: Grad -> Int -> Double -> Grad
add grad i = addL grad i . logSigned
{-# INLINE addL #-}
-- | Add a contribution that is already in 'LogSigned' form to component
-- @i@. If the component is present, the new and old values are combined
-- with '+'.
addL :: Grad -> Int -> LogSigned -> Grad
addL grad i y = insertWith (\new old -> new + old) i y grad
{-# INLINE fromList #-}
-- | Build a gradient from @(index, value)@ pairs whose values are in the
-- normal domain. Pairs are folded in strictly from left to right with
-- 'add', so pairs that share an index are summed.
fromList :: [(Int, Double)] -> Grad
fromList = foldl' step empty
  where
    step acc (k, v) = add acc k v
{-# INLINE fromLogList #-}
-- | Build a gradient from @(index, value)@ pairs whose values are already
-- 'LogSigned'. Pairs are folded in strictly from left to right with
-- 'addL', so pairs that share an index are summed.
fromLogList :: [(Int, LogSigned)] -> Grad
fromLogList = foldl' step empty
  where
    step acc (k, v) = addL acc k v
{-# INLINE toList #-}
-- | List the gradient's components in ascending index order. Each value
-- is converted back to the normal domain with 'toNorm'.
toList :: Grad -> [(Int, Double)]
toList grad = [ (i, toNorm x) | (i, x) <- M.assocs grad ]
{-# INLINE empty #-}
-- | The gradient with no stored components.
empty :: Grad
empty = M.empty
-- | Sum a list of gradients component-wise, combining values at shared
-- indices with '+'. The list is split recursively, and the two halves are
-- merged in parallel inside the 'Par' monad.
--
-- An empty list yields the empty gradient, the identity of the union,
-- instead of failing.
parUnions :: [Grad] -> Grad
parUnions [] = M.empty
parUnions xs = runPar (parUnionsP xs)

-- | 'Par' worker for 'parUnions'.
--
-- It deals the list into two interleaved halves, unions each half in a
-- spawned computation, and combines the two results with
-- @M.unionWith (+)@. The @[]@ equation keeps the helper total: without it,
-- @split []@ would produce two more empty halves and recurse forever.
parUnionsP :: [Grad] -> Par Grad
parUnionsP [] = return M.empty
parUnionsP [x] = return x
parUnionsP zs = do
  let (xs, ys) = split zs
#if MIN_VERSION_containers(0,4,2)
  xsP <- spawn (parUnionsP xs)
  ysP <- spawn (parUnionsP ys)
  M.unionWith (+) <$> get xsP <*> get ysP
#else
  xsP <- spawn_ (parUnionsP xs)
  ysP <- spawn_ (parUnionsP ys)
  x <- M.unionWith (+) <$> get xsP <*> get ysP
  -- spawn_ is only head-strict, so fully force the values here.
  M.elems x `deepseq` return x
#endif
  where
    -- Deal elements alternately into two lists. For two or more inputs,
    -- both halves are non-empty.
    split [] = ([], [])
    split (x:[]) = ([x], [])
    split (x:y:rest) =
      let (xs, ys) = split rest
      in (x:xs, y:ys)