module RL.DP where

import qualified Data.List as List
import qualified Data.Map.Strict as Map
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Set as Set
import Prelude hiding (break)

import RL.Imports

-- | Probability [0..1]
type Probability = Rational

-- | Policy
type P s a = HashMap s (Set (a, Probability))

type V s num = HashMap s num

-- FIXME: handle missing states case
diffV :: (Eq s, Hashable s, Num num) => V s num -> V s num -> num
diffV tgt src = sum (HashMap.intersectionWith (\a b -> abs (a - b)) tgt src)

-- FIXME: Convert to fold-like style eventually
-- | Dynamic Programming Problem. Parameters have the following meaning:
-- @num@ - type of numbers; @pr@ - the problem; @s@ - state; @a@ - action
class (Ord s, Ord a, Fractional num, Ord num, Hashable s) =>
    DP_Problem pr s a num | pr -> s, pr -> a, pr -> num where
  dp_states :: pr -> Set s
  dp_actions :: pr -> s -> Set a
  dp_transitions :: pr -> s -> a -> Set (s, Probability)
  dp_reward :: pr -> s -> a -> s -> num
  -- FIXME: think about splitting terminal and non-terminal states
  dp_terminal_states :: pr -> Set s

action :: (DP_Problem pr s a num) => pr -> P s a -> s -> Set (a, Probability)
action pr p s = p HashMap.! s

initV :: (DP_Problem pr s a num) => pr -> num -> V s num
initV pr num = HashMap.fromList $ map (\s -> (s, num)) (Set.toList $ dp_states pr)
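
-- A minimal, hypothetical example instance (the 'Chain3' name and the problem
-- itself are illustrative additions, not part of the original API). It sketches
-- how a 'DP_Problem' could look: a three-state chain where state 1 is the only
-- non-terminal state and its single action moves to either neighbour with equal
-- probability; entering state 2 yields a reward of 1.
data Chain3 = Chain3
  deriving (Show)

instance DP_Problem Chain3 Int Int Double where
  dp_states _ = Set.fromList [0, 1, 2]
  dp_actions _ 1 = Set.singleton 0              -- one action in the middle state
  dp_actions _ _ = Set.empty                    -- terminal states have no actions
  dp_transitions _ 1 0 = Set.fromList [(0, 1%2), (2, 1%2)]
  dp_transitions _ _ _ = Set.empty
  dp_reward _ _ _ s' = if s' == 2 then 1 else 0 -- reward only when entering state 2
  dp_terminal_states _ = Set.fromList [0, 2]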

-- | For a given state and action, the transition probabilities should sum up to 1
invariant_probable_actions :: (DP_Problem pr s a num, Show s, Show a) => pr -> Bool
invariant_probable_actions pr =
  flip all (dp_states pr) $ \s ->
    flip all (dp_actions pr s) $ \a ->
      case sum (map snd (Set.toList (dp_transitions pr s a))) of
        1 -> True
        x -> error $ "Total probability of state " ++ show s ++ " action " ++ show a ++ " sums up to " ++ show x

-- | No transition leads to an unlisted state
invariant_closed_transition :: (DP_Problem pr s a num, Show s, Show a) => pr -> Bool
invariant_closed_transition pr =
  flip all (dp_states pr) $ \s ->
    flip all (dp_actions pr s) $ \a ->
      flip all (dp_transitions pr s a) $ \(s', p) ->
        case Set.member s' (dp_states pr) of
          True -> True
          False -> error $ "State " ++ show s ++ ", action " ++ show a ++ " leads to invalid state " ++ show s'

-- | Terminal states are dead ends and non-terminal states are not
invariant_no_dead_states :: (DP_Problem pr s a num, Show s, Show a) => pr -> Bool
invariant_no_dead_states pr =
  flip all (dp_states pr) $ \s ->
    case (member s (dp_terminal_states pr), Set.null (dp_actions pr s)) of
      (True, True) -> True
      (True, False) -> error $ "Terminal state " ++ show s ++ " is not a dead end"
      (False, False) -> True
      (False, True) -> error $ "State " ++ show s ++ " is a dead end"

-- | Terminals are valid states
invariant_terminal :: (DP_Problem pr s a num, Show s, Show a) => pr -> Bool
invariant_terminal pr =
  flip all (dp_terminal_states pr) $ \st ->
    case Set.member st (dp_states pr) of
      True -> True
      False -> error $ "State " ++ show st ++ " is not a valid state"

-- | Policy returns valid actions
invariant_policy_actions :: (DP_Problem pr s a num, Ord a, Show s, Show a) => P s a -> pr -> Bool
invariant_policy_actions p pr =
  flip all (dp_states pr) $ \s ->
    flip all (action pr p s) $ \(a, prob) ->
      case Set.member a (dp_actions pr s) of
        True -> True
        False -> error $ "Policy from state " ++ show s ++ " leads to invalid action " ++ show a

-- | Policy returns valid probabilities
invariant_policy_prob :: (DP_Problem pr s a num, Ord a, Show s, Show a) => P s a -> pr -> Bool
invariant_policy_prob p pr =
  flip all (dp_states pr) $ \s ->
    let as = Set.toList (action pr p s) in
    case sum $ map snd as of
      1 -> True
      0 | null as -> True
      x -> error $ "Policy state " ++ show s ++ " probabilities sum up to " ++ show x

invariant :: (DP_Problem pr s a num, Show s, Show a, Ord a) => pr -> Bool
invariant pr = all ($ pr) [
    invariant_probable_actions
  , invariant_closed_transition
  , invariant_terminal
  , invariant_policy_actions (uniformPolicy pr)
  , invariant_policy_prob (uniformPolicy pr)
  , invariant_no_dead_states
  ]

policy_eq :: (Eq a, DP_Problem pr s a num) => pr -> P s a -> P s a -> Bool
policy_eq pr p1 p2 = all (\s -> (action pr p1 s) == (action pr p2 s)) (dp_states pr)

uniformPolicy :: (Ord a, DP_Problem pr s a num) => pr -> P s a
uniformPolicy pr = HashMap.fromList $ flip map (Set.toList (dp_states pr)) $ \s ->
  let as = dp_actions pr s
  in (s, Set.map (\a -> (a, 1 % (toInteger $ length as))) as)

data Opts num s a = Opts {
    eo_gamma :: num      -- ^ Discount factor
  , eo_etha :: num       -- ^ Policy evaluation precision
  , eo_max_iter :: Int   -- ^ Policy evaluation iteration limit, [1..maxBound]
  } deriving (Show)

defaultOpts :: (Fractional num) => Opts num s a
defaultOpts = Opts {
    eo_gamma = 0.9
  , eo_etha = 0.1
  , eo_max_iter = 10^3
  }

data EvalState num s = EvalState {
    _es_delta :: num
  , _es_v :: V s num
  , _es_v' :: V s num
  , _es_iter :: Int
  } deriving (Show)

makeLenses ''EvalState

initEvalState :: (Fractional num) => V s num -> EvalState num s
initEvalState v = EvalState 0 v v 0

-- | Iterative policy evaluation algorithm. Figure 4.1, pg. 86.
policy_eval :: (Monad m, DP_Problem pr s a num) => Opts num s a -> P s a -> V s num -> (DP pr m s a num) -> m (V s num)
policy_eval Opts{..} p v (DP pr _) = do
  let sum l f = List.sum <$> forM (Set.toList l) f
  view es_v <$> do
    flip execStateT (initEvalState v) $ loop $ do
      i <- use es_iter
      when (i > eo_max_iter - 1) $ do
        break ()
      es_delta %= const 0
      forM_ (dp_states pr) $ \s -> do
        v_s <- (HashMap.! s) <$> use es_v
        v's <- do
          sum (action pr p s) $ \(a, fromRational -> pa) -> do
            (pa *) <$> do
              sum (dp_transitions pr s a) $ \(s', fromRational -> prob) -> do
                v_s' <- (HashMap.! s') <$> use es_v
                pure $ prob * ((dp_reward pr s a s') + eo_gamma * v_s')
        es_v' %= (HashMap.insert s v's)
        es_delta %= (`max` (abs (v's - v_s)))
      d <- use es_delta
      when (d < eo_etha) $ do
        break ()
      v' <- use es_v'
      es_v %= const v'
      es_iter %= (+1)
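
-- Note: each sweep of the evaluation loop above applies the Bellman expectation
-- backup
--
--   v_{k+1}(s) = sum_a pi(a|s) * sum_{s'} p(s'|s,a) * (r(s,a,s') + gamma * v_k(s'))
--
-- where pi(a|s) is taken from the policy argument and p(s'|s,a) from
-- 'dp_transitions'. Evaluation stops once the largest per-state change falls
-- below 'eo_etha' or after 'eo_max_iter' sweeps.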

policy_action_value :: (DP_Problem pr s a num) => Opts num s a -> s -> a -> V s num -> pr -> num
policy_action_value Opts{..} s a v pr =
  List.sum $ flip map (Set.toList $ dp_transitions pr s a) $ \(s', fromRational -> p) ->
    p * ((dp_reward pr s a s') + eo_gamma * (v HashMap.! s'))

policy_improve :: (Monad m, DP_Problem pr s a num) => Opts num s a -> V s num -> DP pr m s a num -> m (P s a)
policy_improve o v (DP pr _) = do
  let sum l f = List.sum <$> forM (Set.toList l) f
  flip execStateT mempty $ do
    forM_ (dp_states pr) $ \s -> do
      (maxv, maxa) <- do
        foldlM (\(val, maxa) a -> do
                  pi_s <- pure $ policy_action_value o s a v pr
                  return $
                    if Set.null maxa then
                      (pi_s, Set.singleton a)
                    else if pi_s > val then   -- GT
                      (pi_s, Set.singleton a)
                    else if pi_s < val then   -- LT
                      (val, maxa)
                    else                      -- EQ
                      (val, Set.insert a maxa)
               ) (0, Set.empty) (dp_actions pr s)
      let nmax = toInteger (Set.size maxa)
      modify $ HashMap.insert s (Set.map (\a -> (a, 1 % nmax)) maxa)

data DP pr m s a num = DP {
    dp_pr :: pr
  , dp_trace :: V s num -> P s a -> m ()
  }

policy_iteration :: (Monad m, DP_Problem pr s a num, Ord a) => Opts num s a -> P s a -> V s num -> (DP pr m s a num) -> m (V s num, P s a)
policy_iteration o p v dpr@(DP pr trace) = do
  let up = lift . lift
  (v', p') <- flip execStateT (v, p) $ do
    loop $ do
      (v, p) <- get
      v' <- up $ policy_eval o p v dpr
      p' <- up $ policy_improve o v' dpr
      up $ trace v' p'
      put (v', p')
      when (policy_eq pr p p') $ do
        break ()
  return (v', p')
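
-- A hypothetical end-to-end sketch (the name 'chain3_demo' is an illustrative
-- addition): start from a uniform policy and a zero value function on the
-- 'Chain3' example above, then run policy iteration with the default options
-- and a no-op trace.
chain3_demo :: IO (V Int Double, P Int Int)
chain3_demo =
  policy_iteration defaultOpts (uniformPolicy Chain3) (initV Chain3 0)
                   (DP Chain3 (\_ _ -> return ()))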