module Control.Arrow.Reader (module Control.Arrow.Reader.Class, ReaderT (..), inlineReaderT, outlineReaderT, withReaderTA, withReaderT) where

import Prelude hiding ((.), id);

import Control.Monad;
import Control.Category;
import Control.Arrow;
import Control.Arrow.Trans;
import Control.Arrow.Reader.Class;

-- | Reader arrow transformer: wraps a function from an environment of type
-- @r@ to an arrow of the underlying type @s@, so the arrow that actually
-- runs may depend on the environment.
newtype ReaderT r s a b = ReaderT { runReaderT :: r -> s a b };

-- | Flatten a 'ReaderT' into a single arrow of the underlying type that
-- receives the environment as the first component of its input pair.
--
-- 'ArrowApply' is required because the arrow to run is only known once the
-- environment value arrives, so it must be applied dynamically with 'app'.
inlineReaderT :: (ArrowApply s) => ReaderT r s a b -> s (r, a) b;
inlineReaderT m = arr (first (runReaderT m)) >>> app;

-- | Package an arrow that reads the environment from the first component of
-- its input as a 'ReaderT'; the converse direction of 'inlineReaderT'.
--
-- For each environment @r@, the input is paired with @r@ before being fed to
-- the given arrow.
outlineReaderT :: (Arrow s) => s (r, a) b -> ReaderT r s a b;
outlineReaderT k = ReaderT (\ r -> arr ((,) r) >>> k);

-- | 'lift' produces the same arrow for every environment. 'tmap' transforms
-- the arrow that is produced for each environment.
instance ArrowTrans (ReaderT r) where {
  lift a = ReaderT (\ _ -> a);
  tmap f (ReaderT g) = ReaderT (\ r -> f (g r));
};

-- | Identity ignores the environment. Composition hands the same environment
-- to both component arrows and composes the results in @s@.
instance (Category s) => Category (ReaderT r s) where {
  id = ReaderT (\ _ -> id);
  ReaderT f . ReaderT g = ReaderT (\ r -> f r . g r);
};

-- | Pure functions are lifted independently of the environment.
-- 'first' and 'second' are applied to the arrow selected by the environment.
instance (Arrow s) => Arrow (ReaderT r s) where {
  arr f              = ReaderT (\ _ -> arr f);
  first  (ReaderT f) = ReaderT (\ r -> first  (f r));
  second (ReaderT f) = ReaderT (\ r -> second (f r));
};

-- | Applies a 'ReaderT' arrow that arrives as part of the input. The packaged
-- arrow is run under the same environment as the surrounding computation,
-- then handed to the underlying 'app'.
instance (ArrowApply s) => ArrowApply (ReaderT r s) where {
  app = ReaderT (\ r -> arr (\ (m, x) -> (runReaderT m r, x)) >>> app);
};

-- | 'ask' outputs the current environment and ignores its input.
-- 'local' runs an arrow under an environment modified by a pure function
-- (via 'withReaderT').
instance (Arrow s) => ArrowReader r (ReaderT r s) where {
  ask       = ReaderT (\ r -> arr (\ _ -> r));
  local f m = withReaderT f m;
};

-- | Reader operations for an arrow transformer @xT@ stacked directly on top
-- of 'ReaderT'.
--
-- 'ask' is lifted from the inner 'ReaderT' layer.
--
-- 'local' modifies the environment seen by the inner 'ReaderT' layer. It maps
-- 'withReaderT' over the arrow wrapped by @xT@ using 'tmap', rather than
-- being left 'undefined' (which crashed on any use).
--
-- NOTE(review): this assumes 'tmap' accepts a transformation that is
-- polymorphic in the arrow's input/output types, as @withReaderT f@ is.
-- Confirm against the 'ArrowTrans' class.
--
-- NOTE(review): this instance head also unifies with the plain 'ReaderT'
-- instance when the inner arrow is itself @ReaderT r _@. That may need
-- overlap pragmas at such use sites.
instance (Arrow s, ArrowTrans xT, Arrow (xT (ReaderT r s))) => ArrowReader r (xT (ReaderT r s)) where {
  ask     = lift ask;
  local f = tmap (withReaderT f);
};

-- | Like 'withReaderT', but the new environment is computed from the outer
-- one by an arrow of the underlying type @s@ rather than a pure function.
--
-- The environment is transformed in the first component of the input pair,
-- and the flattened reader is then run on the result.
withReaderTA :: (ArrowApply s) => s q r -> ReaderT r s a b -> ReaderT q s a b;
withReaderTA a m = outlineReaderT ((a *** id) >>> inlineReaderT m);

-- | Run a 'ReaderT' under an environment derived from the outer one by a pure
-- function.
withReaderT :: (q -> r) -> ReaderT r s a b -> ReaderT q s a b;
withReaderT f m = ReaderT (\ q -> runReaderT m (f q));