{-# LANGUAGE FlexibleInstances #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Arrow.Transformer.Reader
-- Copyright   :  (c) Ross Paterson 2003
-- License     :  BSD-style (see the LICENSE file in the distribution)
--
-- Maintainer  :  ross@soi.city.ac.uk
-- Stability   :  experimental
-- Portability :  non-portable (multi-parameter type classes)
--
-- Arrow transformer that adds a read-only state (i.e. an environment).

module Control.Arrow.Transformer.Reader(
		ReaderArrow(ReaderArrow),
		runReader,
		ArrowAddReader(..),
	) where

import Control.Arrow.Internals
import Control.Arrow.Operations
import Control.Arrow.Transformer

import Control.Applicative
import Control.Arrow
import Control.Category
import Data.Monoid
import Data.Semigroup (Semigroup(..))

import Prelude hiding (id,(.))

-- | An arrow type that augments an existing arrow with a read-only state
-- (or environment).  The 'ArrowReader' class contains the operations
-- on this state.
--
-- A @ReaderArrow r a b c@ is represented by an underlying arrow
-- @a (b, r) c@: every input of type @b@ arrives paired with the current
-- environment of type @r@, and the result is just a @c@ -- the
-- environment is never handed back, so it cannot be modified.

newtype ReaderArrow r a b c = ReaderArrow (a (b, r) c)

-- | Encapsulation of a state-reading computation: strip the reader layer,
-- leaving the underlying arrow, which expects the ordinary input paired
-- with an explicit value for the state.
--
-- Typical usage in arrow notation:
--
-- >	proc p -> ...
-- >		(|runReader cmd|) env

runReader :: Arrow a => ReaderArrow r a e b -> a (e,r) b
runReader readerArrow = case readerArrow of
	ReaderArrow underlying -> underlying

-- arrow transformer

instance Arrow a => ArrowTransformer (ReaderArrow r) a where
	-- A lifted arrow never consults the environment: drop it from the
	-- (input, environment) pair and feed the plain input to the arrow.
	lift f = ReaderArrow (f . arr fst)

-- liftings of standard classes

instance Arrow a => Category (ReaderArrow r a) where
	-- The identity discards the environment and returns the input as is.
	id = ReaderArrow (arr fst)
	-- Sequential composition: both stages see the same environment.  The
	-- incoming pair is widened to carry a second copy of the environment;
	-- the first stage ('inner') consumes one copy, and the second stage
	-- ('outer') receives the other next to the first stage's result.
	ReaderArrow outer . ReaderArrow inner =
		ReaderArrow ((arr shareEnv >>> first inner) >>> outer)
		where	shareEnv (x, env) = ((x, env), env)

instance Arrow a => Arrow (ReaderArrow r a) where
	-- A pure function ignores the environment; the pair is only taken
	-- apart when the function demands its argument.
	arr f = ReaderArrow (arr (\input -> f (fst input)))
	-- Reassociate with 'swapsnd' so the environment stays outermost, then
	-- let the wrapped arrow act on the first component.
	first (ReaderArrow f) = ReaderArrow (first f . arr swapsnd)

-- | Reassociate a nested pair so that the environment @r@ ends up
-- outermost: @((a, r), b)@ becomes @((a, b), r)@.
--
-- Completely lazy: the result's pair constructors are built without
-- inspecting the argument, and each component is only extracted when it
-- is demanded.  (Presumably this matters for the 'loop' lifting, where
-- the fed-back component must not be forced early.)
swapsnd :: ((a, r), b) -> ((a, b), r)
swapsnd outer = ((fst inner, snd outer), snd inner)
	where	inner = fst outer

instance ArrowChoice a => ArrowChoice (ReaderArrow r a) where
	-- Attach the environment to 'Left' values, the only ones the wrapped
	-- arrow processes; 'Right' values bypass it and drop the environment.
	left (ReaderArrow f) = ReaderArrow (left f . arr distribute)
		where	distribute :: (Either l o, env) -> Either (l, env) o
			distribute (choice, env) = either (\l -> Left (l, env)) Right choice

instance ArrowApply a => ArrowApply (ReaderArrow r a) where
	-- Unwrap the supplied reader arrow and hand it to the underlying 'app',
	-- together with its argument paired with the current environment.
	app = ReaderArrow (app . arr unpack)
		where	unpack ((ReaderArrow g, x), env) = (g, (x, env))

instance ArrowZero a => ArrowZero (ReaderArrow r a) where
	-- Failure does not depend on the environment: discard it and use the
	-- underlying zero arrow (this is exactly 'lift' applied to it).
	zeroArrow = ReaderArrow (zeroArrow . arr fst)

instance ArrowPlus a => ArrowPlus (ReaderArrow r a) where
	-- Both alternatives take the same (input, environment) pair, so the
	-- underlying '<+>' combines the wrapped arrows directly.
	(<+>) (ReaderArrow this) (ReaderArrow that) = ReaderArrow (this <+> that)

instance ArrowLoop a => ArrowLoop (ReaderArrow r a) where
	-- Feedback goes through the underlying 'loop'; 'swapsnd' moves the
	-- fed-back value next to the input, keeping the environment outermost.
	loop (ReaderArrow body) = ReaderArrow (loop (body . arr swapsnd))

-- new instances

instance Arrow a => ArrowReader r (ReaderArrow r a) where
	-- The environment is the second component of every input pair.
	readState = ReaderArrow (arr (\(_, env) -> env))
	-- The incoming (value, newEnv) pair becomes the body's input as is;
	-- the outer environment is dropped.
	newReader (ReaderArrow body) = ReaderArrow (body . arr fst)

instance Arrow a => ArrowAddReader r (ReaderArrow r a) a where
	-- Adding the reader layer is ordinary lifting.
	liftReader f = lift f
	-- Removing it exposes the underlying arrow, which takes the
	-- environment explicitly as the second component of its input.
	elimReader (ReaderArrow f) = f

-- liftings of other classes

instance ArrowCircuit a => ArrowCircuit (ReaderArrow r a) where
	-- The delay element ignores the environment: lift the underlying one.
	delay = lift . delay

instance ArrowError ex a => ArrowError ex (ReaderArrow r a) where
	-- Raising does not consult the environment.
	raise = lift raise
	-- The handler's input is reordered with 'swapsnd' so that the
	-- environment stays the outermost component.
	handle (ReaderArrow body) (ReaderArrow handler) =
		ReaderArrow (handle body (handler . arr swapsnd))
	-- Same reordering for the success and the failure continuation.
	tryInUnless (ReaderArrow body) (ReaderArrow onSuccess) (ReaderArrow onFailure) =
		ReaderArrow (tryInUnless body (onSuccess . arr swapsnd) (onFailure . arr swapsnd))
	-- Capture the body's errors; the environment passes straight through.
	newError (ReaderArrow body) = ReaderArrow (newError body)

instance ArrowState s a => ArrowState s (ReaderArrow r a) where
	-- The mutable state belongs to the underlying arrow; both operations
	-- are lifted and ignore the reader environment.
	store = lift store
	fetch = lift fetch

instance ArrowWriter s a => ArrowWriter s (ReaderArrow r a) where
	-- Output is accumulated by the underlying arrow.
	write = lift write
	-- Collect a subcomputation's output; the subcomputation still receives
	-- the environment unchanged.
	newWriter (ReaderArrow body) = ReaderArrow (newWriter body)

-- Promotions of encapsulation operators.

instance ArrowAddError ex a a' =>
		ArrowAddError ex (ReaderArrow r a) (ReaderArrow r a') where
	-- Lifting acts on the underlying arrow; the environment passes through.
	liftError (ReaderArrow inner) = ReaderArrow (liftError inner)
	-- Eliminate the error layer underneath, reordering the handler's input
	-- with 'swapsnd' so the environment stays outermost.
	elimError (ReaderArrow body) (ReaderArrow handler) =
		ReaderArrow (elimError body (handler . arr swapsnd))

instance ArrowAddState s a a' =>
		ArrowAddState s (ReaderArrow r a) (ReaderArrow r a') where
	-- Lifting acts on the underlying arrow; the environment passes through.
	liftState (ReaderArrow inner) = ReaderArrow (liftState inner)
	-- The eliminated arrow receives ((input, state), env); 'swapsnd' turns
	-- that into ((input, env), state) before the underlying elimState.
	elimState (ReaderArrow body) = ReaderArrow (elimState body . arr swapsnd)

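-- NOTE: disabled.  Its head overlaps the ArrowAddReader instance for
-- ReaderArrow defined earlier in this module (both match
-- ArrowAddReader r (ReaderArrow r a) _), which is presumably why it is
-- commented out; it also defines no liftReader.
-- instance ArrowAddReader r a a' =>
-- 		ArrowAddReader r (ReaderArrow r a) (ReaderArrow r a') where
-- 	elimReader (ReaderArrow f) = ReaderArrow (arr swapsnd >>> elimReader f)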

instance ArrowAddWriter s a a' =>
		ArrowAddWriter s (ReaderArrow r a) (ReaderArrow r a') where
	-- Both directions act purely on the underlying arrow; the
	-- (input, environment) pair is passed through untouched.
	liftWriter (ReaderArrow inner) = ReaderArrow (liftWriter inner)
	elimWriter (ReaderArrow body) = ReaderArrow (elimWriter body)

-- Other instances

instance Arrow a => Functor (ReaderArrow r a b) where
	-- Post-compose the computation with the pure function.
	fmap h computation = arr h . computation

instance Arrow a => Applicative (ReaderArrow r a b) where
	-- A constant arrow: ignores both its input and the environment.
	pure = arr . const
	-- Run both arrows on the same input, then apply the produced
	-- function to the produced argument.
	fs <*> xs = arr (uncurry ($)) . (fs &&& xs)

instance ArrowPlus a => Alternative (ReaderArrow r a b) where
	-- Failure and choice are the arrow-level 'zeroArrow' and '<+>'.
	empty = zeroArrow
	(<|>) = (<+>)

-- | Reader arrows over an 'ArrowPlus' combine with '<+>'.  This instance
-- is required: since base-4.11 (GHC 8.4) 'Semigroup' is a superclass of
-- 'Monoid', so the 'Monoid' instance below does not compile without it.
instance ArrowPlus a => Semigroup (ReaderArrow r a b c) where
	(<>) = (<+>)

-- | 'zeroArrow' as the unit and '<+>' as the operation (an identity
-- element per the 'ArrowPlus' laws).
instance ArrowPlus a => Monoid (ReaderArrow r a b c) where
	mempty = zeroArrow
	-- Identical to '<>'; spelled with '<+>' so it stays unambiguous on
	-- base versions whose "Data.Monoid" exports a separate '(<>)'.
	mappend f g = f <+> g