{-# LANGUAGE DeriveFunctor #-}
module RL.TD where

import qualified Prelude
import qualified Data.HashMap.Strict as HashMap

import RL.Imports
import RL.Types
import RL.Utils (eps_greedy_action)

-- | Hyper-parameters consumed by 'q_learn'.
data Q_Opts = Q_Opts
  { o_alpha :: TD_Number
    -- ^ Step size: scales the temporal-difference error added to a Q-value.
  , o_gamma :: TD_Number
    -- ^ Discount factor applied to the best Q-value of the successor state.
  , o_eps :: TD_Number
    -- ^ Exploration parameter forwarded to @eps_greedy_action@.
  } deriving (Show)

-- | Default hyper-parameters: step size @0.1@, no discounting (@gamma = 1.0@)
-- and exploration parameter @0.3@.
defaultOpts :: Q_Opts
defaultOpts = Q_Opts {
    o_alpha = 0.1
  , o_gamma = 1.0
  , o_eps = 0.3
  }

-- | Numeric type shared by rewards ('td_reward'), Q-table entries ('Q') and the
-- hyper-parameters in 'Q_Opts'.
type TD_Number = Double

-- | Action-value table: the project's @M@ container instantiated with
-- 'TD_Number' entries. In this module it is read per state (@get_s@) and
-- updated per state/action pair (@modify_s_a@).
-- NOTE(review): @M@ is defined outside this file (presumably in "RL.Types") —
-- confirm its exact shape there.
type Q s a = M s a TD_Number

-- | Build a fresh Q-table from a single 'TD_Number' by delegating to @initM@.
-- NOTE(review): presumably the number becomes the initial value of every
-- entry — confirm against @initM@.
emptyQ :: TD_Number -> Q s a
emptyQ q_init = initM q_init

-- | Collapse a Q-table into a per-state map.
--
-- Every @(state, layer)@ pair visited by @foldMap_s@ contributes a singleton
-- map from the state to the second component of @layer_s_max layer@; the
-- singletons are merged by the 'HashMap' monoid.
-- NOTE(review): going by its name, @layer_s_max@ presumably returns the best
-- @(action, value)@ pair of the layer, which would make the result a
-- state-value function — confirm against its definition.
toV :: (Bounded a, Enum a, Eq a, Hashable a, Eq s, Hashable s) => Q s a -> HashMap s TD_Number
toV = foldMap_s stateValue
  where
    stateValue (st, layer) = HashMap.singleton st (snd (layer_s_max layer))

-- | Interface a problem @pr@ must implement to be driven by 'q_learn' and
-- 'q_exec'. The functional dependencies make the problem type determine the
-- base monad @m@, the state type @s@ and the action type @a@.
class (Monad m, Eq s, Hashable s, Show s, Eq a, Hashable a, Enum a, Bounded a, Show a) =>
    TD_Problem pr m s a | pr -> m, pr -> s , pr -> a where
  -- | Terminal-state predicate; both 'q_learn' and 'q_exec' pass its negation
  -- to @loopM@ as the loop condition.
  td_is_terminal :: pr -> s -> Bool
  -- | Callback handed (partially applied to @pr@) to @eps_greedy_action@.
  -- NOTE(review): the meaning of the 'Bool' argument is defined by
  -- @eps_greedy_action@ in "RL.Utils" — confirm there.
  td_greedy :: pr -> Bool -> a -> a
  -- | Reward for a transition; 'q_learn' calls it as
  -- @td_reward pr s a s'@ with the state, the action taken and the successor.
  td_reward :: pr -> s -> a -> s -> TD_Number
  -- | Produce the successor state for a state/action pair in the base monad.
  -- It also receives the Q-table held in the state monad at the time of the call.
  td_transition :: pr -> s -> a -> Q s a -> m s
  -- | Hook run in the base monad after each Q-table update made through
  -- @modifyQ@; it receives the updated state, action and the resulting table.
  td_modify :: pr -> s -> a -> Q s a  -> m ()

-- | Read the Q-table from the state monad and return the entries stored for
-- state @s@ (as selected by @get_s@) as an association list.
queryQ s = (HashMap.toList . get_s s) <$> get
-- | Apply @f@ to the table entry for @(s, a)@ via @modify_s_a@, then pass the
-- updated table to the problem's 'td_modify' hook in the underlying monad.
modifyQ pr s a f = do
  modify (modify_s_a s a f)
  q <- get
  lift (td_modify pr s a q)
-- | Choose an action for state @s@: fetch the entries for @s@ with 'queryQ'
-- and hand them to @eps_greedy_action@ together with @eps@ and the problem's
-- 'td_greedy' callback. The result is whatever @eps_greedy_action@ returns
-- ('q_learn' uses its first component as the action).
action pr s eps = do
  qs <- queryQ s
  eps_greedy_action eps (td_greedy pr) qs
-- | Run the problem's 'td_transition' for @(s, a)@ in the underlying monad,
-- supplying the Q-table currently held in the state monad.
transition pr s a = do
  q <- get
  lift (td_transition pr s a q)

-- | Q-Learning algorithm.
--
-- Starting from state @s0@ with table @q0@, repeatedly (via @loopM@, with the
-- condition "state is not terminal"):
--
--   1. picks an action with 'action' using @o_eps@,
--   2. obtains the successor state @s'@ from the problem,
--   3. updates the entry for @(s, a)@ with
--      @q + o_alpha * (r + o_gamma * max_a' Q(s', a') - q)@.
--
-- Returns the state produced by @loopM@ together with the final Q-table.
--
-- NOTE(review): 'maximumBy' fails on an empty list, so this assumes 'queryQ'
-- always yields at least one entry for @s'@ — confirm against @get_s@.
q_learn :: (MonadRnd g m, TD_Problem pr m s a) => Q_Opts -> Q s a -> s -> pr -> m (s, Q s a)
q_learn Q_Opts{..} q0 s0 pr =
  -- The loop is nested explicitly under 'runStateT' so the layout is valid
  -- without relying on the NondecreasingIndentation extension.
  flip runStateT q0 $
    loopM s0 (not . td_is_terminal pr) $ \s -> do
      (a,_) <- action pr s o_eps
      s' <- transition pr s a
      let r = td_reward pr s a s'
      -- Best value available from the successor state (off-policy target).
      max_qs' <- snd . maximumBy (compare`on`snd) <$> queryQ s'
      modifyQ pr s a $ \q -> q + o_alpha * (r + o_gamma * max_qs' - q)
      return s'

-- | Q-Executive algorithm. Actions are taken greedily, no learning is performed.
--
-- Starting from @s0@, repeatedly (via @loopM@, with the condition "state is
-- not terminal") picks the action with the largest Q-value for the current
-- state and asks the problem for the successor state. The table @q0@ is only
-- read; the 'Q_Opts' argument is accepted for interface symmetry with
-- 'q_learn' but none of its fields are used.
--
-- NOTE(review): 'maximumBy' fails on an empty list, so this assumes 'queryQ'
-- always yields at least one entry — confirm against @get_s@.
q_exec :: (MonadRnd g m, TD_Problem pr m s a) => Q_Opts -> Q s a -> s -> pr -> m s
q_exec Q_Opts{} q0 s0 pr =
  -- The loop is nested explicitly under 'evalStateT' so the layout is valid
  -- without relying on the NondecreasingIndentation extension.
  flip evalStateT q0 $
    loopM s0 (not . td_is_terminal pr) $ \s -> do
      a <- fst . maximumBy (compare`on`snd) <$> queryQ s
      transition pr s a