{-# LANGUAGE DeriveFunctor #-}
-- | Monte-Carlo exploring-starts (MC-ES) control over a tabular Q function.
module RL.MC where

import qualified Data.HashMap.Strict as HashMap
import qualified Prelude
import RL.Imports
import RL.Types

-- | Tunable parameters of 'mc_es_learn'.
data MC_Opts = MC_Opts
  { o_alpha         :: MC_Number -- ^ learning rate used to move Q towards the observed return
  , o_maxlen        :: Int       -- ^ episode-length cap (see 'mc_es_learn' for the exact bound)
  , o_maxlen_reward :: MC_Number -- ^ reward recorded for the step that hits the length cap
  } deriving (Show)

-- | Default options: @alpha = 0.1@, a length cap of 1000 and a reward of
-- @-100@ for the step that hits the cap.
defaultOpts :: MC_Opts
defaultOpts = MC_Opts
  { o_alpha         = 0.1
  , o_maxlen        = 1000
  , o_maxlen_reward = -100.0
  }

-- | Numeric type used for rewards, returns and Q values.
type MC_Number = Double

-- | Tabular action-value function, stored in the 'M' container.
type Q s a = M s a MC_Number

-- | State-value function.
type V s = HashMap s MC_Number

-- | Initial Q table built by 'initM' from the given value
-- (presumably used as the value of every entry).
emptyQ :: MC_Number -> Q s a
emptyQ = initM

-- | Collapse Q into V: for every state, keep the value part of what
-- 'layer_s_max' selects from that state's layer (presumably the best action).
q2v :: (Bounded a, Enum a, Eq a, Hashable a, Eq s, Hashable s) => Q s a -> V s
q2v = foldMap_s (\(s, l) -> HashMap.singleton s (snd (layer_s_max l)))

-- | Sum of absolute differences between two value functions.
--
-- FIXME: handle missing states case. Only the key intersection is summed, so
-- a state present in just one of the two maps contributes nothing.
diffV :: (Eq s, Hashable s) => V s -> V s -> MC_Number
diffV tgt src = sum (HashMap.intersectionWith (\a b -> abs (a - b)) tgt src)

-- | Same as 'q2v'; kept as a separate name for existing callers.
toV :: (Bounded a, Enum a, Eq a, Hashable a, Eq s, Hashable s) => Q s a -> V s
toV = q2v

-- | An episodic problem. @pr@ determines the state type @s@, the action type
-- @a@ and the numeric type @num@.
class (Fractional num, Ord s, Ord a, Show s, Show a, Bounded a, Enum a)
    => MC_Problem pr s a num | pr -> s, pr -> a, pr -> num where
  -- | Whether an episode ends in the given state.
  mc_is_terminal :: pr -> s -> Bool
  -- | Reward for moving from @s@ to @s'@ by taking action @a@.
  mc_reward :: pr -> s -> a -> s -> num

-- | Read the Q entries of state @s@ from the state monad as a list of
-- (action, value) pairs, via 'get_s'.
queryQ s = HashMap.toList <$> get_s s <$> get

-- | Apply @f@ to the Q entry of (@s@, @a@) held in the state monad.
modifyQ s a f = modify (modify_s_a s a f)

-- | A problem paired with its (possibly effectful) environment dynamics.
data MC pr m s a = MC
  { mc_pr         :: pr            -- ^ problem definition
  , mc_transition :: s -> a -> m s -- ^ next state after taking an action in a state
  }

-- | MC-ES learning algorithm, pg 5.4. An alpha learning rate is used instead
-- of averaging all returns, and the episode length is capped so that every
-- episode is guaranteed to terminate.
--
-- One call runs a single episode. It starts from the exploring-start pair
-- (@s0@, @a0@) and from then on picks the greedy action according to the
-- current Q. Every first-visited (state, action) pair is then moved towards
-- its undiscounted return @g@ with @q <- q + alpha * (g - q)@. The updated Q
-- is returned.
mc_es_learn :: (Monad m, Hashable s, Hashable a, MC_Problem pr s a MC_Number)
  => MC_Opts -> Q s a -> s -> a -> MC pr m s a -> m (Q s a)
mc_es_learn MC_Opts{..} q0 s0 a0 (MC pr transition) =
  flip execStateT q0 $ do

    {- Build an episode. Transitions are prepended, so 'ep' is ordered
       most-recent-first. The number of recorded transitions is carried in the
       loop state instead of calling 'length' on every step, which made
       building an episode quadratic in its length. -}
    (_, _, ep, _, _) <- loopM (s0, a0, [], True, 0) (\(_, _, _, more, _) -> more) $ \(s, a, steps, _, n) -> do
      s' <- lift $ transition s a
      -- Greedy action for the next state. NOTE(review): 'maximumBy' is
      -- partial; this assumes 'queryQ' never yields an empty list for a
      -- reachable state -- confirm against 'get_s'.
      a' <- fst . maximumBy (compare `on` snd) <$> queryQ s'
      let n'   = n + 1
          r    = mc_reward pr s a s'
          step = (s, a, s', r)
      if n > o_maxlen then
        -- Cap hit. 'n' counts the transitions recorded before this one, so an
        -- episode holds at most o_maxlen + 2 transitions. The real reward is
        -- replaced by the penalty and the episode stops.
        return (s', a', (s, a, s', o_maxlen_reward) : steps, False, n')
      else if mc_is_terminal pr s' then
        return (s', a', step : steps, False, n')
      else
        return (s', a', step : steps, True, n')

    {- Build the first-visit reward map. Walking 'ep' most-recent-first
       accumulates the undiscounted return from each step to the end of the
       episode. Later inserts overwrite earlier ones, so each (state, action)
       pair keeps the return of its chronologically first visit. -}
    rm <- fst <$> execStateT
      (forM_ ep $ \(s, a, _, r) ->
         modify $ \(m, g) -> (HashMap.insert (s, a) (g + r) m, g + r))
      (mempty, 0)

    {- Update Q: move each first-visited entry towards its observed return. -}
    forM_ (HashMap.toList rm) $ \((s, a), g) ->
      modifyQ s a $ \q -> q + o_alpha * (g - q)