module RL.DP where
import qualified Data.List as List
import qualified Data.Map.Strict as Map
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Set as Set
import Prelude hiding(break)
import RL.Imports
-- | Probabilities are kept as exact rationals.
type Probability = Rational
-- | A stochastic policy: for every state, a set of @(action, probability)@ pairs.
type P s a = HashMap s (Set (a,Probability))
-- | A value table: one numeric value per state.
type V s num = HashMap s num
-- | Sum of the absolute differences between two value tables.
--
-- Only states present in /both/ tables contribute; a state that appears in
-- just one of them is ignored (this is what 'HashMap.intersectionWith' does).
diffV :: (Eq s, Hashable s, Num num) => V s num -> V s num -> num
diffV tgt src = sum (HashMap.intersectionWith (\a b -> abs (a - b)) tgt src)
-- | Description of a finite decision problem.  The problem type @pr@
-- determines the state type @s@, the action type @a@ and the numeric type
-- @num@ used for rewards and values.
class (Ord s, Ord a, Fractional num, Ord num, Hashable s) =>
    DP_Problem pr s a num | pr -> s a num where
  -- | All states of the problem.
  dp_states :: pr -> Set s
  -- | Actions available in the given state.
  dp_actions :: pr -> s -> Set a
  -- | @(successor, probability)@ pairs for taking an action in a state.
  dp_transitions :: pr -> s -> a -> Set (s, Probability)
  -- | Reward for going from the first state, via the action, to the second state.
  dp_reward :: pr -> s -> a -> s -> num
  -- | States that are declared terminal.
  dp_terminal_states :: pr -> Set s
-- | Look up the @(action, probability)@ pairs a policy assigns to a state.
--
-- The lookup is partial: it fails if the policy has no entry for the state.
-- The problem argument is not inspected; it only fixes the types.
action :: (DP_Problem pr s a num) => pr -> P s a -> s -> Set (a,Probability)
action _ policy state = (HashMap.!) policy state
-- | Build a value table that assigns the same initial value to every state
-- of the problem.
initV :: (DP_Problem pr s a num)
  => pr -> num -> V s num
initV pr num =
  HashMap.fromList [ (s, num) | s <- Set.toList (dp_states pr) ]
-- | For every state and every action available in it, the transition
-- probabilities must sum to exactly 1.
--
-- Returns 'True' when the check passes and raises an 'error' naming the
-- offending state and action otherwise.
invariant_probable_actions :: (DP_Problem pr s a num, Show s, Show a) => pr -> Bool
invariant_probable_actions pr =
  and
    [ normalised s a
    | s <- Set.toList (dp_states pr)
    , a <- Set.toList (dp_actions pr s)
    ]
  where
    normalised s a
      | total == 1 = True
      | otherwise =
          error $ "Total probability of state " ++ show s ++ " action " ++ show a ++ " sum up to " ++ show total
      where
        total = sum [ prob | (_, prob) <- Set.toList (dp_transitions pr s a) ]
-- | Every transition must lead to a state that is listed in 'dp_states'.
--
-- Returns 'True' when the check passes and raises an 'error' naming the
-- source state, action and unknown successor otherwise.
invariant_closed_transition :: (DP_Problem pr s a num, Show s, Show a) => pr -> Bool
invariant_closed_transition pr =
  and
    [ known s a s'
    | s <- Set.toList (dp_states pr)
    , a <- Set.toList (dp_actions pr s)
    , (s', _) <- Set.toList (dp_transitions pr s a)
    ]
  where
    known s a s'
      | Set.member s' (dp_states pr) = True
      | otherwise =
          error $ "State " ++ show s ++ ", action " ++ show a ++ " lead to invalid state " ++ show s'
-- | Terminal states must have no actions, and non-terminal states must have
-- at least one.
--
-- Returns 'True' when the check passes and raises an 'error' naming the
-- offending state otherwise.
invariant_no_dead_states :: (DP_Problem pr s a num, Show s, Show a) => pr -> Bool
invariant_no_dead_states pr = all consistent (dp_states pr)
  where
    -- Qualified 'Set.member', as in the other invariants, so the check does
    -- not depend on an unqualified re-export being in scope.
    consistent s =
      case (Set.member s (dp_terminal_states pr), Set.null (dp_actions pr s)) of
        (True, True) -> True
        (True, False) -> error $ "Terminal state " ++ show s ++ " is not dead end"
        (False, False) -> True
        (False, True) -> error $ "State " ++ show s ++ " is dead end"
-- | Every terminal state must also be listed in 'dp_states'.
--
-- Returns 'True' when the check passes and raises an 'error' naming the
-- unknown state otherwise.
invariant_terminal :: (DP_Problem pr s a num, Show s, Show a) => pr -> Bool
invariant_terminal pr = all known (dp_terminal_states pr)
  where
    known st
      | st `Set.member` dp_states pr = True
      | otherwise = error $ "State " ++ show st ++ " is not a valid state"
-- | Every action the policy proposes for a state must be one of the actions
-- the problem allows in that state.
--
-- Returns 'True' when the check passes and raises an 'error' naming the
-- state and the disallowed action otherwise.
invariant_policy_actions :: (DP_Problem pr s a num, Ord a, Show s, Show a) => P s a -> pr -> Bool
invariant_policy_actions p pr =
  and
    [ allowed s a
    | s <- Set.toList (dp_states pr)
    , (a, _) <- Set.toList (action pr p s)
    ]
  where
    allowed s a
      | Set.member a (dp_actions pr s) = True
      | otherwise =
          error $ "Policy from state " ++ show s ++ " leads to invalid action " ++ show a
-- | For every state, the policy's action probabilities must sum to 1; a state
-- for which the policy lists no actions at all is also accepted.
--
-- Returns 'True' when the check passes and raises an 'error' naming the
-- state and the actual sum otherwise.
invariant_policy_prob :: (DP_Problem pr s a num, Ord a, Show s, Show a) => P s a -> pr -> Bool
invariant_policy_prob p pr = all wellFormed (dp_states pr)
  where
    wellFormed s
      | total == 1 = True
      | total == 0 && null choices = True
      | otherwise =
          error $ "Policy state " ++ show s ++ " probabilities sum up to " ++ show total
      where
        choices = Set.toList (action pr p s)
        total = sum (map snd choices)
-- | Run all problem-level sanity checks, in order.  The two policy checks
-- are run against the uniform policy derived from the problem.
--
-- Each individual check either returns 'True' or raises an 'error', so this
-- returns 'True' only if every check passes.
invariant :: (DP_Problem pr s a num, Show s, Show a, Ord a) => pr -> Bool
invariant pr =
  invariant_probable_actions pr
    && invariant_closed_transition pr
    && invariant_terminal pr
    && invariant_policy_actions (uniformPolicy pr) pr
    && invariant_policy_prob (uniformPolicy pr) pr
    && invariant_no_dead_states pr
-- | Two policies are equal when they assign the same set of
-- @(action, probability)@ pairs to every state of the problem.
policy_eq :: (Eq a, DP_Problem pr s a num) => pr -> P s a -> P s a -> Bool
policy_eq pr p1 p2 = all agreeOn (dp_states pr)
  where
    agreeOn s = action pr p1 s == action pr p2 s
-- | The policy that, in every state, picks each available action with equal
-- probability.  A state without actions gets an empty set.
uniformPolicy :: (Ord a, DP_Problem pr s a num) => pr -> P s a
uniformPolicy pr =
  HashMap.fromList
    [ (s, evenly (dp_actions pr s))
    | s <- Set.toList (dp_states pr)
    ]
  where
    -- The share is only evaluated per element, so an empty action set never
    -- divides by zero.
    evenly acts = Set.map (\a -> (a, 1 % toInteger (Set.size acts))) acts
-- | Tuning knobs for the evaluation / improvement routines.  The @s@ and @a@
-- parameters are not used by any field.
data Opts num s a = Opts
  { eo_gamma :: num
    -- ^ Factor applied to the successor state's value when computing returns.
  , eo_etha :: num
    -- ^ Policy evaluation stops once the largest change in a sweep is below this.
  , eo_max_iter :: Int
    -- ^ Bound on the number of policy-evaluation sweeps.
  } deriving (Show)
-- | Default options: @eo_gamma = 0.9@, @eo_etha = 0.1@, @eo_max_iter = 1000@.
defaultOpts :: (Fractional num) => Opts num s a
defaultOpts = Opts gamma etha maxIter
  where
    gamma = 0.9
    etha = 0.1
    maxIter = 1000
-- | Working state of the policy-evaluation loop.
data EvalState num s = EvalState
  { _es_delta :: num
    -- ^ Largest absolute value change seen in the current sweep.
  , _es_v :: V s num
    -- ^ Value table the current sweep reads from.
  , _es_v' :: V s num
    -- ^ Value table the current sweep writes to.
  , _es_iter :: Int
    -- ^ Number of completed sweeps.
  } deriving (Show)

makeLenses ''EvalState
-- | Initial evaluation state: both value tables start as the given table,
-- the delta and the iteration counter start at zero.
initEvalState :: (Fractional num) => V s num -> EvalState num s
initEvalState v = EvalState
  { _es_delta = 0
  , _es_v = v
  , _es_v' = v
  , _es_iter = 0
  }
-- | Iterative policy evaluation.
--
-- Repeatedly sweeps over all states, replacing each state's value by the
-- expected one-step return under policy @p@: for every action of the policy
-- and every transition of that action, the reward plus @eo_gamma@ times the
-- successor's value, weighted by the action and transition probabilities.
--
-- A sweep reads values from 'es_v' and writes them to 'es_v'', so updates
-- made during a sweep are not visible within that same sweep.  The loop
-- stops when
--
-- * @eo_max_iter@ sweeps have been completed, or
-- * the largest absolute change in a sweep is below @eo_etha@.
--
-- The table held in 'es_v' at that point is returned.  Note that on the
-- convergence exit the last sweep's results are not copied into 'es_v'
-- first, so the returned table is the one that sweep was computed from.
--
-- The trace callback of the 'DP' record is not used here.
policy_eval :: (Monad m, DP_Problem pr s a num)
  => Opts num s a -> P s a -> V s num -> (DP pr m s a num) -> m (V s num)
policy_eval Opts{..} p v (DP pr _) = do
  -- Monadic sum over the elements of a set.
  let sumOver xs f = List.sum <$> forM (Set.toList xs) f
  view es_v <$> do
    flip execStateT (initEvalState v) $ loop $ do
      i <- use es_iter
      -- Same bound as @i > eo_max_iter - 1@.
      when (i >= eo_max_iter) $ do
        break ()
      es_delta %= const 0
      forM_ (dp_states pr) $ \s -> do
        v_s <- (HashMap.! s) <$> use es_v
        v's <- do
          sumOver (action pr p s) $ \(a, fromRational -> pa) -> do
            (pa *) <$> do
              sumOver (dp_transitions pr s a) $ \(s', fromRational -> ps') -> do
                v_s' <- (HashMap.! s') <$> use es_v
                pure $ ps' * (dp_reward pr s a s' + eo_gamma * v_s')
        es_v' %= HashMap.insert s v's
        -- Track the largest change of this sweep.
        es_delta %= (`max` abs (v's - v_s))
      d <- use es_delta
      when (d < eo_etha) $ do
        break ()
      v' <- use es_v'
      es_v %= const v'
      es_iter %= (+ 1)
-- | Expected one-step return of taking action @a@ in state @s@: for each
-- transition, the reward plus @eo_gamma@ times the successor's value from
-- @v@, weighted by the transition probability.
--
-- The successor lookup in @v@ is partial and fails for a state missing from
-- the table.
policy_action_value :: (DP_Problem pr s a num) => Opts num s a -> s -> a -> V s num -> pr -> num
policy_action_value opts s a v pr =
  List.sum
    [ fromRational prob * (dp_reward pr s a s' + eo_gamma opts * (v HashMap.! s'))
    | (s', prob) <- Set.toList (dp_transitions pr s a)
    ]
-- | Greedy policy improvement.
--
-- For every state, computes the value of each available action with
-- 'policy_action_value' and keeps the action(s) with the largest value.
-- Ties are all kept and share the probability equally.  A state without
-- actions gets an empty set.
--
-- The trace callback of the 'DP' record is not used here.
policy_improve :: (Monad m, DP_Problem pr s a num)
  => Opts num s a -> V s num -> DP pr m s a num -> m (P s a)
policy_improve o v (DP pr _) =
  pure $ HashMap.fromList [ (s, greedy s) | s <- Set.toList (dp_states pr) ]
  where
    greedy s =
      let (_, winners) =
            List.foldl' (pick s) (0, Set.empty) (Set.toList (dp_actions pr s))
          share = 1 % toInteger (Set.size winners)
      in Set.map (\a -> (a, share)) winners

    -- Fold step: the accumulator holds the best value so far and the set of
    -- actions achieving it.  An empty set means nothing has been seen yet,
    -- so the initial value is never compared against.
    pick s (best, winners) a
      | Set.null winners = (q, Set.singleton a)
      | q > best = (q, Set.singleton a)
      | q < best = (best, winners)
      | otherwise = (best, Set.insert a winners)
      where
        q = policy_action_value o s a v pr
-- | A problem bundled with a tracing callback that runs in monad @m@.
data DP pr m s a num = DP
  { dp_pr :: pr
    -- ^ The problem description.
  , dp_trace :: V s num -> P s a -> m ()
    -- ^ Callback receiving a value table and a policy.
  }
-- | Policy iteration: alternate 'policy_eval' and 'policy_improve' until the
-- improved policy equals the policy it was derived from (per 'policy_eq').
--
-- After every improvement step the trace callback is invoked with the new
-- value table and the new policy.  Returns the last value table and policy.
-- There is no iteration limit on this outer loop.
policy_iteration :: (Monad m, DP_Problem pr s a num, Ord a)
  => Opts num s a -> P s a -> V s num -> (DP pr m s a num) -> m (V s num, P s a)
policy_iteration o p v dpr@(DP pr report) = go v p
  where
    go vCur pCur = do
      vNext <- policy_eval o pCur vCur dpr
      pNext <- policy_improve o vNext dpr
      report vNext pNext
      if policy_eq pr pCur pNext
        then return (vNext, pNext)
        else go vNext pNext