{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
-- |Improved arrows, with a whole host of minor optimisations and instances.
module Control.Arrow.Improve(ImproveArrow, lowerImprove, getFunction) where

import Prelude hiding (id, (.))

import Control.Category
import Control.Arrow
import Control.Applicative
import Control.Monad
import Control.Monad.Zip

import Control.Arrow.Transformer
import Control.Arrow.Operations

import Data.Profunctor
import Data.Semigroupoid
import Data.Functor.Plus
import Data.Functor.Bind
import Data.Pointed

import Data.Monoid
import Data.String

-- |Basic improved arrow type.
--
-- A value is either a plain function, or an arrow of the underlying type @a@
-- together with a pure function to run before it and a pure function to run
-- after it. Keeping the pure parts separate lets the instances below combine
-- them with ordinary function composition instead of going through @a@.
data ImproveArrow a b c where
    -- A pure function; the underlying arrow is not involved at all.
    IArr   :: (b -> c) -> ImproveArrow a b c
    -- Pre-processing function, underlying arrow, post-processing function.
    IArrow :: (i -> b) -> a b c -> (c -> o) -> ImproveArrow a i o

-- |Lower an improved arrow to the original arrow type.
--
-- prop> lowerImprove . lift = id
-- prop> lift . lowerImprove = id
lowerImprove :: (Arrow a) => ImproveArrow a b c -> a b c
lowerImprove (IArrow f a g) = f ^>> a >>^ g
lowerImprove (IArr f)       = arr f

-- |Get a function representing the arrow, if it is possible to do so.
-- Guarantees are only made when arrows are constructed only from the
-- combinators of the Category, Arrow, ArrowChoice and ArrowLoop classes.
--
-- Returns 'Nothing' as soon as the value wraps an arrow of the underlying
-- type, even if that arrow happens to be pure.
getFunction :: ImproveArrow a b c -> Maybe (b -> c)
getFunction (IArr f)       = Just f
getFunction (IArrow _ _ _) = Nothing

-- |Composition folds pure functions into their neighbours; only composing
-- two wrapped arrows goes through the underlying arrow's composition.
instance (Arrow a) => Category (ImproveArrow a) where
    id = arr id
    {-# INLINE id #-}
    IArr f . IArr g = IArr (f . g)
    IArr h . IArrow f a g = IArrow f a (h . g)
    IArrow f a g . IArr h = IArrow (f . h) a g
    -- The right operand runs first: h, b, i, then f, a, g.
    IArrow f a g . IArrow h b i = IArrow h (b >>> arr (f . i) >>> a) g
    {-# INLINABLE (.) #-}

-- |Pure functions are stored directly with 'arr'; the product combinators
-- apply the matching combinator to each of the three components.
instance (Arrow a) => Arrow (ImproveArrow a) where
    arr = IArr
    {-# INLINE arr #-}
    first (IArr f) = IArr (first f)
    first (IArrow f a g) = IArrow (first f) (first a) (first g)
    {-# INLINABLE first #-}
    second (IArr f) = IArr (second f)
    second (IArrow f a g) = IArrow (second f) (second a) (second g)
    {-# INLINABLE second #-}
    IArr f *** IArr g = IArr (f *** g)
    -- When only one side wraps an arrow, the pure side is deferred to the
    -- post-processing function so the underlying arrow sees just first/second.
    IArr h *** IArrow f a g = IArrow (second f) (second a) (h *** g)
    IArrow f a g *** IArr h = IArrow (first f) (first a) (g *** h)
    IArrow f a g *** IArrow h b i = IArrow (f *** h) (a *** b) (g *** i)
    {-# INLINABLE (***) #-}
    IArr f &&& IArr g = IArr (f &&& g)
    IArrow f a g &&& IArr h = IArrow (f &&& h) (first a) (first g)
    IArr h &&& IArrow f a g = IArrow (h &&& f) (second a) (second g)
    -- TODO: use a rule to use a &&& b instead when f == h
    IArrow f a g &&& IArrow h b i = IArrow (f &&& h) (a *** b) (g *** i)
    {-# INLINABLE (&&&) #-}

-- |The zero arrow is the lifted zero arrow of the underlying type.
instance (ArrowZero a) => ArrowZero (ImproveArrow a) where
    zeroArrow = lift zeroArrow
    {-# INLINE zeroArrow #-}

-- |Alternatives are lowered, combined in the underlying arrow and re-lifted.
instance (ArrowPlus a) => ArrowPlus (ImproveArrow a) where
    f <+> g = lift (lowerImprove f <+> lowerImprove g)
    {-# INLINE (<+>) #-}

-- |Choice combinators mirror the product combinators of the 'Arrow' instance.
instance (ArrowChoice a) => ArrowChoice (ImproveArrow a) where
    left (IArr f) = IArr (left f)
    left (IArrow f a g) = IArrow (left f) (left a) (left g)
    {-# INLINE left #-}
    right (IArr f) = IArr (right f)
    right (IArrow f a g) = IArrow (right f) (right a) (right g)
    {-# INLINE right #-}
    IArr f +++ IArr g = IArr (f +++ g)
    IArrow f a g +++ IArr h = IArrow (left f) (left a) (g +++ h)
    IArr h +++ IArrow f a g = IArrow (right f) (right a) (h +++ g)
    IArrow f a g +++ IArrow h b i = IArrow (f +++ h) (a +++ b) (g +++ i)
    {-# INLINABLE (+++) #-}
    IArr f ||| IArr g = IArr (f ||| g)
    IArrow f a g ||| IArr h = IArrow (left f) (left a) (g ||| h)
    IArr h ||| IArrow f a g = IArrow (right f) (right a) (h ||| g)
    -- TODO: use rules to turn the +++ into a ||| on the arrow when g == i
    IArrow f a g ||| IArrow h b i = IArrow (f +++ h) (a +++ b) (g ||| i)
    {-# INLINABLE (|||) #-}

-- |Application lowers the improved arrow found in the input and then uses
-- the underlying arrow's 'app'.
instance (ArrowApply a) => ArrowApply (ImproveArrow a) where
    app = lift $ first lowerImprove ^>> app
    {-# INLINE app #-}

-- |Loops over pure functions are tied with a lazy recursive binding; loops
-- over wrapped arrows are delegated to the underlying arrow.
instance (ArrowLoop a) => ArrowLoop (ImproveArrow a) where
    loop (IArr f) = IArr f'
      where
        -- The feedback value k is defined in terms of itself, so this relies
        -- on the pattern binding being lazy.
        f' x = let (y, k) = f (x, k) in y
    loop (IArrow f a g) = lift (loop (f ^>> a >>^ g))
    {-# INLINE loop #-}

-- |'delay' is the lifted 'delay' of the underlying arrow.
instance (ArrowCircuit a) => ArrowCircuit (ImproveArrow a) where
    delay = lift . delay
    {-# INLINE delay #-}

-- |State operations are lifted unchanged from the underlying arrow.
instance (ArrowState s a) => ArrowState s (ImproveArrow a) where
    fetch = lift fetch
    {-# INLINE fetch #-}
    store = lift store
    {-# INLINE store #-}

-- |Reader operations; 'newReader' keeps the post-processing function outside
-- the underlying 'newReader' call.
instance (ArrowReader r a) => ArrowReader r (ImproveArrow a) where
    readState = lift readState
    {-# INLINE readState #-}
    newReader (IArr f) = lift $ newReader $ arr f
    newReader (IArrow f a g) = IArrow id (newReader (f ^>> a)) g
    {-# INLINE newReader #-}

-- |Writer operations; a pure function writes nothing, so 'newWriter' pairs
-- its result with 'mempty' without consulting the underlying arrow.
instance (Monoid w, ArrowWriter w a) => ArrowWriter w (ImproveArrow a) where
    write = lift write
    {-# INLINE write #-}
    newWriter (IArr f) = IArr ((\x -> (x, mempty)) . f)
    newWriter (IArrow f a g) = IArrow f (newWriter (a >>^ g)) id
    {-# INLINABLE newWriter #-}

-- |Error operations; a pure function is treated as never raising, so the
-- handlers are dropped for 'IArr' and the success path is taken directly.
instance (ArrowError ex a) => ArrowError ex (ImproveArrow a) where
    raise = lift raise
    {-# INLINE raise #-}
    handle (IArr f) _ = IArr f
    handle a@(IArrow _ _ _) e = lift (handle (lowerImprove a) (lowerImprove e))
    {-# INLINABLE handle #-}
    tryInUnless (IArr g) f _ = IArr (\x -> (x, g x)) >>> f
    tryInUnless a@(IArrow _ _ _) f e =
        lift (tryInUnless (lowerImprove a) (lowerImprove f) (lowerImprove e))
    {-# INLINABLE tryInUnless #-}
    newError (IArr f) = IArr (Right . f)
    newError a@(IArrow _ _ _) = lift (newError (lowerImprove a))
    {-# INLINABLE newError #-}

-- |Mapping over the result is post-composition with a pure function.
instance (Arrow a) => Functor (ImproveArrow a b) where
    fmap f = (>>^ f)
    {-# INLINE fmap #-}

-- |Both arrows receive the same input; the results are combined by applying
-- the first result to the second.
instance (Arrow a) => Applicative (ImproveArrow a b) where
    pure k = IArr (\_ -> k)
    {-# INLINE pure #-}
    f <*> x = (f &&& x) >>^ uncurry id
    {-# INLINE (<*>) #-}

-- |'Alternative' in terms of 'ArrowZero' and 'ArrowPlus'.
instance (ArrowPlus a) => Alternative (ImproveArrow a b) where
    empty = zeroArrow
    {-# INLINE empty #-}
    (<|>) = (<+>)
    {-# INLINE (<|>) #-}

-- |Bind computes the next arrow from the first result and runs it on the
-- original input via 'app'.
instance (ArrowApply a) => Monad (ImproveArrow a b) where
    return = pure
    {-# INLINE return #-}
    x >>= f = ((x >>^ f) &&& id) >>> app
    {-# INLINE (>>=) #-}

-- |'MonadPlus' in terms of 'ArrowZero' and 'ArrowPlus'.
instance (ArrowPlus a, ArrowApply a) => MonadPlus (ImproveArrow a b) where
    mzero = zeroArrow
    {-# INLINE mzero #-}
    mplus = (<+>)
    {-# INLINE mplus #-}

-- |Zipping is fan-out.
instance (ArrowApply a) => MonadZip (ImproveArrow a b) where
    mzip = (&&&)
    {-# INLINE mzip #-}

-- |Pre- and post-composition with pure functions.
instance (Arrow a) => Profunctor (ImproveArrow a) where
    dimap f g x = f ^>> x >>^ g
    {-# INLINE dimap #-}
    lmap f x = f ^>> x
    {-# INLINE lmap #-}
    rmap g x = x >>^ g
    {-# INLINE rmap #-}

-- |'Strong' in terms of the 'Arrow' instance.
instance (Arrow a) => Strong (ImproveArrow a) where
    first' = first
    {-# INLINE first' #-}
    second' = second
    {-# INLINE second' #-}

-- |'Choice' in terms of the 'ArrowChoice' instance.
instance (ArrowChoice a) => Choice (ImproveArrow a) where
    left' = left
    {-# INLINE left' #-}
    right' = right
    {-# INLINE right' #-}

-- |'point' is 'pure': the arrow that ignores its input.
instance (Arrow a) => Pointed (ImproveArrow a b) where
    point = pure
    {-# INLINE point #-}

-- |Semigroupoid composition is 'Category' composition.
instance (Arrow a) => Semigroupoid (ImproveArrow a) where
    o = (.)
    {-# INLINE o #-}

-- |'Alt' in terms of 'ArrowPlus'.
instance (ArrowPlus a) => Alt (ImproveArrow a b) where
    -- The class method is the operator (<!>); it must be named explicitly,
    -- otherwise the instance has no definition for it.
    (<!>) = (<+>)
    {-# INLINE (<!>) #-}

-- |'Apply' in terms of the 'Applicative' instance.
instance (Arrow a) => Apply (ImproveArrow a b) where
    (<.>) = (<*>)
    {-# INLINE (<.>) #-}

-- |'Bind' in terms of the 'Monad' instance.
instance (ArrowApply a) => Bind (ImproveArrow a b) where
    (>>-) = (>>=)
    {-# INLINE (>>-) #-}

-- |'Plus' in terms of 'ArrowZero'.
instance (ArrowPlus a) => Plus (ImproveArrow a b) where
    zero = zeroArrow
    {-# INLINE zero #-}

-- |Monoid over arrows using 'zeroArrow' and '<+>'.
--
-- NOTE(review): from base 4.11 onwards 'Monoid' has a @Semigroup@
-- superclass, so this instance would also need a matching @Semigroup@
-- instance there -- confirm which base versions this module targets.
instance (ArrowPlus a) => Monoid (ImproveArrow a b c) where
    mempty = zeroArrow
    {-# INLINE mempty #-}
    mappend = (<+>)
    {-# INLINE mappend #-}

-- |Pointwise numeric operations on the results of arrows that share an input.
instance (Arrow a, Num c) => Num (ImproveArrow a b c) where
    (+) = liftA2 (+)
    {-# INLINE (+) #-}
    (-) = liftA2 (-)
    {-# INLINE (-) #-}
    (*) = liftA2 (*)
    {-# INLINE (*) #-}
    negate = fmap negate
    {-# INLINE negate #-}
    abs = fmap abs
    {-# INLINE abs #-}
    signum = fmap signum
    {-# INLINE signum #-}
    fromInteger = pure . fromInteger
    {-# INLINE fromInteger #-}

-- |Pointwise fractional operations.
instance (Arrow a, Fractional c) => Fractional (ImproveArrow a b c) where
    (/) = liftA2 (/)
    {-# INLINE (/) #-}
    recip = fmap recip
    {-# INLINE recip #-}
    fromRational = pure . fromRational
    {-# INLINE fromRational #-}

-- |Pointwise floating-point operations.
instance (Arrow a, Floating c) => Floating (ImproveArrow a b c) where
    pi = pure pi
    {-# INLINE pi #-}
    exp = fmap exp
    {-# INLINE exp #-}
    log = fmap log
    {-# INLINE log #-}
    sqrt = fmap sqrt
    {-# INLINE sqrt #-}
    (**) = liftA2 (**)
    {-# INLINE (**) #-}
    logBase = liftA2 logBase
    {-# INLINE logBase #-}
    sin = fmap sin
    {-# INLINE sin #-}
    cos = fmap cos
    {-# INLINE cos #-}
    tan = fmap tan
    {-# INLINE tan #-}
    asin = fmap asin
    {-# INLINE asin #-}
    acos = fmap acos
    {-# INLINE acos #-}
    atan = fmap atan
    {-# INLINE atan #-}
    sinh = fmap sinh
    {-# INLINE sinh #-}
    cosh = fmap cosh
    {-# INLINE cosh #-}
    tanh = fmap tanh
    {-# INLINE tanh #-}
    asinh = fmap asinh
    {-# INLINE asinh #-}
    acosh = fmap acosh
    {-# INLINE acosh #-}
    atanh = fmap atanh
    {-# INLINE atanh #-}

-- |A string literal denotes the arrow that ignores its input and returns
-- the literal.
instance (Arrow a, IsString c) => IsString (ImproveArrow a b c) where
    fromString = pure . fromString
    {-# INLINE fromString #-}

-- |Lifting wraps the underlying arrow between two identity functions.
instance (Arrow a) => ArrowTransformer ImproveArrow a where
    lift x = IArrow id x id
    {-# INLINE lift #-}