module RL.TD where
import qualified Prelude
import qualified Data.HashMap.Strict as HashMap
import RL.Imports
import RL.Types
import RL.Utils (eps_greedy_action)
-- | Hyper-parameters shared by the TD algorithms in this module.
data Q_Opts = Q_Opts {
  o_alpha :: TD_Number  -- ^ Learning rate (step size of the Q update, see 'q_learn')
, o_gamma :: TD_Number  -- ^ Discount factor applied to the bootstrapped next-state value
, o_eps :: TD_Number    -- ^ Exploration rate passed to 'eps_greedy_action'
} deriving (Show)
-- | Default hyper-parameters: small learning rate, undiscounted returns,
-- moderate exploration.
defaultOpts :: Q_Opts
defaultOpts = Q_Opts {
  o_alpha = 0.1
, o_gamma = 1.0
, o_eps = 0.3
}
-- | Numeric type used for rewards, Q-values and hyper-parameters.
type TD_Number = Double
-- | Q-table: per-state, per-action table of values, backed by the
-- project's 'M' container (see RL.Types).
type Q s a = M s a TD_Number
-- | Build a fresh Q-table from a single number.
-- NOTE(review): presumably 'initM' fills every entry with this initial
-- value — confirm against RL.Types.
emptyQ :: TD_Number -> Q s a
emptyQ = initM
-- | Collapse a Q-table into a state-value map by taking, for each state,
-- the value of its best action (via 'layer_s_max').
toV :: (Bounded a, Enum a, Eq a, Hashable a, Eq s, Hashable s) => Q s a -> HashMap s TD_Number
toV = foldMap_s stateValue
  where
    stateValue (s, layer) = HashMap.singleton s (snd (layer_s_max layer))
-- | Interface a problem must implement to be driven by the TD algorithms
-- below. @pr@ determines the monad @m@, state type @s@ and action type @a@.
class (Monad m, Eq s, Hashable s, Show s, Eq a, Hashable a, Enum a, Bounded a, Show a) =>
  TD_Problem pr m s a | pr -> m, pr -> s , pr -> a where
  -- | True when @s@ ends the episode (used as the loop condition in
  -- 'q_learn' / 'q_exec').
  td_is_terminal :: pr -> s -> Bool
  -- | Action-selection hook handed to 'eps_greedy_action' in 'action'.
  -- NOTE(review): semantics of the Bool are not visible here — presumably
  -- it flags greedy vs. exploratory choice; confirm in RL.Utils.
  td_greedy :: pr -> Bool -> a -> a
  -- | Reward observed for the transition @s --a--> s'@.
  td_reward :: pr -> s -> a -> s -> TD_Number
  -- | Produce the next state for @s@ and @a@, given the current Q-table.
  td_transition :: pr -> s -> a -> Q s a -> m s
  -- | Hook called by 'modifyQ' after each Q-table update (e.g. for
  -- tracing/visualisation).
  td_modify :: pr -> s -> a -> Q s a -> m ()
queryQ s = HashMap.toList <$> get_s s <$> get
modifyQ pr s a f = modify (modify_s_a s a f) >> get >>= lift . td_modify pr s a
action pr s eps = queryQ s >>= eps_greedy_action eps (td_greedy pr)
transition pr s a = get >>= lift . td_transition pr s a
-- | Run one episode of tabular Q-learning starting from @s0@ with initial
-- table @q0@; returns the terminal state and the learned Q-table.
q_learn :: (MonadRnd g m, TD_Problem pr m s a) => Q_Opts -> Q s a -> s -> pr -> m (s, Q s a)
q_learn Q_Opts{..} q0 s0 pr =
  flip runStateT q0 $
    loopM s0 (not . td_is_terminal pr) $ \s -> do
      (a,_) <- action pr s o_eps
      s' <- transition pr s a
      let r = td_reward pr s a s'
      -- Value of the best next action. NOTE(review): 'maximumBy' is
      -- partial; this relies on every state having at least one action in
      -- the table (guaranteed when the table comes from 'emptyQ').
      max_qs' <- snd . maximumBy (compare`on`snd) <$> queryQ s'
      -- Watkins Q-learning update:
      --   Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
      -- BUGFIX: the TD error previously read "... * max_qs' q", applying a
      -- Double to q (ill-typed) and omitting the "- q" term.
      modifyQ pr s a $ \q -> q + o_alpha * (r + o_gamma * max_qs' - q)
      return s'
-- | Execute the purely greedy policy induced by @q0@ from @s0@ until a
-- terminal state is reached; the table itself is never modified.
-- (Q_Opts is accepted for signature symmetry with 'q_learn'; its fields
-- are unused here.)
q_exec :: (MonadRnd g m, TD_Problem pr m s a) => Q_Opts -> Q s a -> s -> pr -> m s
q_exec Q_Opts{..} q0 s0 pr =
  flip evalStateT q0 $
    loopM s0 (not . td_is_terminal pr) $ \s -> do
      -- Pick the action with the highest Q-value in state s.
      -- NOTE(review): 'maximumBy' is partial; relies on a non-empty
      -- action layer for every visited state.
      a <- fst . maximumBy (compare`on`snd) <$> queryQ s
      -- "s' <- transition ...; return s'" collapsed to the action itself
      -- (x <- m; return x  ==  m).
      transition pr s a