module RL.MC where
import qualified Data.HashMap.Strict as HashMap
import qualified Prelude
import RL.Imports
import RL.Types
-- | Tunable parameters read by 'mc_es_learn'.
data MC_Opts = MC_Opts
  { o_alpha :: MC_Number
    -- ^ Step size applied when a Q entry is moved towards an observed return.
  , o_maxlen :: Int
    -- ^ Threshold on the number of recorded steps; once the episode holds
    -- more steps than this, the next step is the last one.
  , o_maxlen_reward :: MC_Number
    -- ^ Reward recorded for that last step when the episode is cut off
    -- because of 'o_maxlen'.
  } deriving (Show)
-- | Baseline learner settings: step size @0.1@, episodes cut off after
-- @1000@ recorded steps, and a reward of @100.0@ recorded for the cut-off step.
defaultOpts :: MC_Opts
defaultOpts = MC_Opts
  { o_alpha = 0.1
  , o_maxlen = 1000
  , o_maxlen_reward = 100.0
  }
-- | Numeric type used for rewards, returns and value estimates in this module.
type MC_Number = Double
-- | Table of 'MC_Number' values stored in the project's @M@ container
-- (not defined in this module); presumably indexed by state @s@ and then
-- action @a@ -- confirm against the definition of @M@.
type Q s a = M s a MC_Number
-- | One 'MC_Number' per state @s@.
type V s = HashMap s MC_Number
-- | Build a fresh Q table by handing the given number to 'initM'
-- (presumably the value every entry starts from -- confirm against 'initM').
emptyQ :: MC_Number -> Q s a
emptyQ initial = initM initial
-- | Collapse a Q table into a per-state map. Each state is paired with the
-- second component of 'layer_s_max' applied to that state's layer
-- (presumably the largest action value -- confirm against 'layer_s_max').
q2v :: (Bounded a, Enum a, Eq a, Hashable a, Eq s, Hashable s) => Q s a -> V s
q2v = foldMap_s stateEntry
  where
    -- One-element map for a single (state, layer) pair.
    stateEntry (s, layer) = HashMap.singleton s (snd (layer_s_max layer))
-- | Total absolute difference between two value maps.
--
-- Only states present in both maps contribute; for each of them the
-- absolute value of @target - source@ is added to the result.
diffV :: (Eq s, Hashable s) => V s -> V s -> MC_Number
diffV tgt src = sum (HashMap.intersectionWith (\a b -> abs (a - b)) tgt src)
-- | Same conversion as 'q2v': each state of the Q table is mapped to the
-- second component of 'layer_s_max' for that state's layer.
toV :: (Bounded a, Enum a, Eq a, Hashable a, Eq s, Hashable s) => Q s a -> V s
toV = q2v
-- | Interface a problem description @pr@ has to provide to the learner.
-- The problem type determines its state type @s@, action type @a@ and
-- reward type @num@ (functional dependencies).
class (Fractional num, Ord s, Ord a, Show s, Show a, Bounded a, Enum a)
    => MC_Problem pr s a num | pr -> s, pr -> a, pr -> num where
  -- | Whether the given state ends an episode ('mc_es_learn' stops its
  -- rollout as soon as this holds for the state just reached).
  mc_is_terminal :: pr -> s -> Bool
  -- | Reward for a step; 'mc_es_learn' calls it with the state the step
  -- started in, the action taken and the state reached, in that order.
  mc_reward :: pr -> s -> a -> s -> num
-- | Read the table held in the surrounding state monad and return, as a
-- list of key/value pairs, whatever 'get_s' yields for state @s@.
-- 'mc_es_learn' treats the keys as actions and the values as their scores.
queryQ s = entries <$> get
  where
    entries table = HashMap.toList (get_s s table)
-- | Update the table held in the surrounding state monad by applying
-- 'modify_s_a' with state @s@, action @a@ and the supplied update function.
modifyQ s a = modify . modify_s_a s a
-- | A problem description bundled with the means to step it inside monad @m@.
data MC pr m s a = MC
  { mc_pr :: pr
    -- ^ The problem description (queried via 'MC_Problem' by 'mc_es_learn').
  , mc_transition :: s -> a -> m s
    -- ^ Effectful step: given a state and an action, produce the next state.
  }
-- | Run a single episode and return the Q table updated from it.
--
-- Arguments, in order: learner options, the Q table to start from, the
-- initial state, the initial action, and the problem together with its
-- transition function.
--
-- The episode starts with the given state/action pair. After every
-- transition the next action is the one with the highest score among the
-- entries 'queryQ' reports for the new state. The rollout ends when a
-- terminal state is reached, or when more than 'o_maxlen' steps have already
-- been recorded (that final step is recorded with 'o_maxlen_reward' instead
-- of the problem's reward).
--
-- Afterwards the undiscounted return following each recorded step is
-- computed, and every visited @(state, action)@ entry is moved towards its
-- return by @q + o_alpha * (g - q)@. If a pair occurs several times, the
-- return following its earliest occurrence is the one used.
mc_es_learn :: (Monad m, Hashable s, Hashable a, MC_Problem pr s a MC_Number)
  => MC_Opts -> Q s a -> s -> a -> MC pr m s a -> m (Q s a)
mc_es_learn MC_Opts{..} q0 s0 a0 (MC pr transition) =
  flip execStateT q0 $ do
    -- Roll out the episode. Loop state: (state, action, steps recorded so
    -- far with the newest first, continue flag); the flag is what 'loopM'
    -- inspects through @view _4@ (presumably looping while it is True).
    final <- loopM (s0, a0, [], True) (view _4) $ \(s, a, steps, _) -> do
      s' <- lift $ transition s a
      -- NOTE(review): 'maximumBy' fails on an empty list; this assumes
      -- 'queryQ' always reports at least one action -- confirm.
      a' <- fst . maximumBy (compare `on` snd) <$> queryQ s'
      -- NOTE(review): 'length' walks the whole step list on every iteration.
      if length steps > o_maxlen
        then return (s', a', (s, a, s', o_maxlen_reward) : steps, False)
        else do
          let r = mc_reward pr s a s'
          if mc_is_terminal pr s'
            then return (s', a', (s, a, s', r) : steps, False)
            else return (s', a', (s, a, s', r) : steps, True)
    let ep = view _3 final
    -- Walk the steps from the last one back to the first, accumulating the
    -- return. Later inserts overwrite earlier ones, so each (state, action)
    -- key ends up with the return following its earliest occurrence.
    rm <- fst <$> do
      flip execStateT (mempty, 0) $
        forM_ ep $ \(s, a, _, r) ->
          modify $ \(m, g) -> (HashMap.insert (s, a) (g + r) m, g + r)
    -- Move every visited entry towards its observed return.
    forM_ (HashMap.toList rm) $ \((s, a), g) ->
      modifyQ s a $ \q -> q + o_alpha * (g - q)