module RL.Types where

import qualified Data.HashMap.Strict as HashMap
import qualified Data.HashSet as HashSet

import RL.Imports

-- | One row of the table: a value for every inner key @a@.
type Layer a num = HashMap a num

-- | The whole table: one 'Layer' per outer key @s@.
type Storage s a num = HashMap s (Layer a num)

-- | Base container used in most of the RL algorithms. @M x0 sto@ describes a
-- 2-dimensional array (a 'Storage' of 'Layer's) where each layer contains a
-- fixed number of elements: one per value in @[minBound .. maxBound]@ of the
-- inner key type. Entries that were never written read as the default @x0@.
data M s a num = M
  { x0 :: num
    -- ^ Default value reported for every entry that is not stored explicitly.
  , sto :: Storage s a num
    -- ^ Explicitly stored entries; layers kept here may be partial.
  } deriving (Show)

-- | Initialises a new, empty container whose default entry value is @x@.
initM :: num -> M s a num
initM x = M x HashMap.empty

-- | Applies @f@ to the underlying storage, leaving the default value as is.
mmod :: (Storage s a num -> Storage s a num) -> M s a num -> M s a num
mmod f m = m { sto = f (sto m) }

-- | Builds a complete layer that maps every inner key in
-- @[minBound .. maxBound]@ to @q0@.
aq0 :: (Eq a, Enum a, Hashable a, Bounded a) => num -> HashMap a num
aq0 q0 = HashMap.fromList [(a, q0) | a <- [minBound .. maxBound]]

-- | Returns the complete layer for outer key @s@. Stored values take
-- precedence; every inner key missing from the stored layer (or all of them,
-- when nothing is stored under @s@) is filled with the container default.
get_s :: (Eq a, Enum a, Hashable a, Bounded a, Eq s, Hashable s)
      => s -> M s a num -> Layer a num
get_s s (M dflt store) =
  case HashMap.lookup s store of
    Nothing    -> defaults
    -- 'HashMap.union' is left-biased, so stored values win over defaults.
    Just layer -> layer `HashMap.union` defaults
  where
    defaults = aq0 dflt

-- | Returns the @(key, value)@ pair of the layer with the greatest value.
--
-- NOTE(review): assuming the standard 'maximumBy', this is partial and fails
-- on an empty layer. Layers produced by 'get_s' are non-empty whenever the
-- inner key type has at least one value.
layer_s_max :: (Eq a, Enum a, Hashable a, Bounded a, Ord num)
            => Layer a num -> (a, num)
layer_s_max = maximumBy (compare `on` snd) . HashMap.toList

-- | Reads the single entry at @(s, a)@, falling back to the container default
-- when either the layer or the entry is not stored.
get_s_a :: (Eq a, Enum a, Hashable a, Bounded a, Eq s, Hashable s)
        => s -> a -> M s a num -> num
get_s_a s a (M dflt store) =
  case HashMap.lookup s store of
    Nothing    -> dflt
    Just layer -> maybe dflt id (HashMap.lookup a layer)

-- | Merges the (possibly partial) layer @x@ into the layer stored under @s@.
-- Values from @x@ replace stored values with the same inner key; all other
-- stored values of that layer are kept.
put_s :: (Eq s, Hashable s, Bounded a, Enum a, Eq a, Hashable a)
      => s -> HashMap a num -> M s a num -> M s a num
put_s s x = mmod (HashMap.insertWith HashMap.union s x)
  -- 'insertWith' calls its function as @f new old@ and 'HashMap.union' is
  -- left-biased, so values from @x@ win.

-- | Writes the single entry at @(s, a)@.
put_s_a :: (Eq s, Hashable s, Bounded a, Enum a, Eq a, Hashable a)
        => s -> a -> num -> M s a num -> M s a num
put_s_a s a x = put_s s (HashMap.singleton a x)

-- | Replaces the entry at @(s, a)@ with @f@ applied to its current value
-- (the container default if the entry is not stored yet).
modify_s_a :: (Eq s, Hashable s, Bounded a, Enum a, Eq a, Hashable a)
           => s -> a -> (num -> num) -> M s a num -> M s a num
modify_s_a s a f m = put_s_a s a updated m
  where
    updated = f (get_s_a s a m)

-- | Flattens the container into @(s, a, value)@ triples. Only explicitly
-- stored entries are listed; defaults are not materialised.
list :: M s a num -> [(s, a, num)]
list m =
  [ (s, a, v)
  | (s, layer) <- HashMap.toList (sto m)
  , (a, v) <- HashMap.toList layer
  ]

-- | Maps every stored layer to a monoid and combines the results. Each layer
-- is completed with the container default before it is passed to @f@. Outer
-- keys with no stored layer are not visited.
foldMap_s :: (Eq a, Bounded a, Enum a, Hashable a, Monoid acc)
          => ((s, Layer a num) -> acc) -> M s a num -> acc
foldMap_s f (M dflt store) = foldMap visit (HashMap.toList store)
  where
    -- Built once and shared by all layers.
    defaults = aq0 dflt
    visit (s, layer) = f (s, layer `HashMap.union` defaults)

-- | Strict left fold over the stored layers. Each layer is completed with the
-- container default before it is passed to @f@. Outer keys with no stored
-- layer are not visited.
fold_s :: (Eq a, Bounded a, Enum a, Hashable a, Monoid acc)
       => (acc -> (s, Layer a num) -> acc) -> acc -> M s a num -> acc
fold_s f acc0 (M dflt store) = foldl' step acc0 (HashMap.toList store)
  where
    -- Built once and shared by all layers.
    defaults = aq0 dflt
    step acc (s, layer) = f acc (s, layer `HashMap.union` defaults)