module Control.Arrow.Mix (module Control.Arrow.Mix.Category, (:+), liftImpure, liftPure, unPlus, unPlus', plusTwist, plusAssoc, plusCommute) where
import Prelude hiding (id, (.))
import Control.Arrow
import Control.Category
import Control.Arrow.Mix.Category
import Control.Arrow.Mix.Utilities
-- | A bundle of three pieces parameterised by two arrow types @a@ and @b@.
--
-- The type variables @i@ and @o@ appear only in the constructor's fields and
-- not in the result type, so they are hidden (existential) from anyone who
-- holds a 'Mix' value.
--
-- NOTE(review): 'Along' and ':~>' are not defined in this module; they
-- presumably come from "Control.Arrow.Mix.Category" or
-- "Control.Arrow.Mix.Utilities" -- confirm their exact definitions there.
data Mix a b input output where
  Mix
    :: (i -> o)
       -- plain function between the two hidden types
    -> (Along a i o :~> a)
       -- transformation taking an @Along a i o@ arrow back to a plain @a@ arrow
    -> Along b input output i o
       -- the @b@-side arrow, indexed by both the visible and the hidden types
    -> Mix a b input output
-- | Newtype wrapper around @(Mix a :$~ b) input output@.
--
-- All listed instances are obtained through the @deriving@ clause rather
-- than being written by hand in this module.
newtype (a :+ b) input output
  = APlus ((Mix a :$~ b) input output)
  deriving (AlMonad, AlFunctor, Category, Arrow, ArrowLoop, ArrowChoice)

-- Left-associative at precedence 6, so @a :+ b :+ c@ reads as @(a :+ b) :+ c@.
infixl 6 :+
-- | Embed an arrow of the first (left-hand) arrow type into @a :+ b@.
--
-- The resulting 'Mix' is assembled from:
--
-- * 'Right' as the plain function field;
-- * a transformation that post-composes @second (effect ||| id)@ onto the
--   arrow it is given and closes the result with 'loop';
-- * a @b@-side arrow that is pure wiring built with 'arr': swap the pair,
--   then tag its new first component with 'Left'.
--
-- The transformation is kept as an inline lambda because the field it fills
-- is polymorphic (':~>').
liftImpure :: (ArrowChoice a, ArrowLoop a, Arrow b) => a :~> a :+ b
liftImpure effect =
  APlus
    ( Apply
        ( Mix
            Right
            (\wrapped -> loop (second (effect ||| id) <<< wrapped))
            (arr (swap >>> first Left))
        )
    )
-- | Embed an arrow of the second (right-hand) arrow type into @a :+ b@.
--
-- This is exactly 'alRet', used at the 'AlMonad' instance that the ':+'
-- newtype derives.
liftPure :: (Arrow a, Arrow b) => b :~> a :+ b
liftPure plain = alRet plain
-- | Collapse a mix whose right-hand component is plain functions into the
-- left-hand arrow.
--
-- The functions are first lifted into @a@ with 'arr' (via 'alMap'), and the
-- resulting @a :+ a@ is then collapsed by @unPlus'@.
unPlus :: Arrow a => a :+ (->) :~> a
unPlus mixed = unPlus' (alMap arr mixed)
-- | Collapse a mix of an arrow type with itself into that arrow type.
--
-- The stored transformation is applied to the stored arrow after passing it
-- through 'arrSwap'; the plain function field of the 'Mix' is not used.
unPlus' :: Arrow a => a :+ a :~> a
unPlus' mixed =
  case mixed of
    APlus (Apply (Mix _ collapse body)) -> collapse (arrSwap body)
-- | 'AlMonad' structure for @Mix a@.
--
-- * 'alRet' wraps an arrow using 'id' as the function field, 'arrCancelUnit'
--   as the transformation, and 'second' applied to the arrow as the body.
-- * 'alLift' feeds the function field and body of the given 'Mix' to the
--   supplied step, then combines the two layers: the function fields are
--   paired with '***' (new one first), the transformations are chained
--   (re-association, then the old one, then the new one), and the new body
--   is re-associated with 'arrAssocRtoL'.
instance Arrow a => AlMonad (Mix a) where
  alRet plain = Mix id arrCancelUnit (second plain)

  alLift step (Mix fn collapse body) =
    case step fn body of
      Mix fn' collapse' body' ->
        Mix
          (fn' *** fn)
          (\wrapped -> collapse' (collapse (arrAssocRtoL wrapped)))
          (arrAssocRtoL body')
-- | Move an 'AlFunctor' layer @f@ from inside the right-hand side of a mix
-- to the outside: @a :+ f c@ becomes @f (a :+ c)@.
--
-- The function field and the transformation of the incoming 'Mix' are reused
-- unchanged; only the body is rebuilt, by mapping over it with 'alMap' and
-- re-wrapping each inner arrow (after 'arrTwist') in a fresh 'Mix'.
plusTwist :: (Arrow a, AlFunctor f, Arrow c) => a :+ f c :~> f (a :+ c)
plusTwist (APlus (Apply (Mix fn collapse body))) =
  alMap (\inner -> APlus (Apply (Mix fn collapse (arrTwist inner)))) body
-- | Re-associate a nested mix: @a :+ (b :+ c)@ becomes @(a :+ b) :+ c@.
--
-- Both 'Mix' layers are unpacked at once. The result pairs the two function
-- fields with '***' (inner one first), combines the two transformations with
-- the local helper @mergeCollapses@, and re-associates the innermost body
-- with 'arrAssocRtoL'.
plusAssoc :: (Arrow a, Arrow b, Arrow c) => a :+ (b :+ c) :~> (a :+ b) :+ c
plusAssoc (APlus (Apply (Mix outerFn outerCollapse (APlus (Apply (Mix innerFn innerCollapse core)))))) =
  APlus $ Apply $
    Mix
      (innerFn *** outerFn)
      (mergeCollapses outerFn outerCollapse innerCollapse)
      (arrAssocRtoL core)
  where
    -- Pure tuple re-shuffling. The patterns are deliberately lazy (@~@) so
    -- that the shuffles do not force their argument before producing the
    -- outer tuple structure.
    rotate :: ((a, (b, c)), d) -> (b, (c, (d, a)))
    rotate ~((w, (x, y)), z) = (x, (y, (z, w)))

    -- Inverse of @rotate@.
    unrotate :: (b, (c, (d, a))) -> ((a, (b, c)), d)
    unrotate ~(x, (y, (z, w))) = ((w, (x, y)), z)

    -- Combine the outer function field and transformation with the inner
    -- transformation into one transformation for the mixed arrow @a :+ b@.
    -- The body handed to the inner transformation is sandwiched between
    -- @rotate@ and @unrotate@ (lifted with 'arr').
    mergeCollapses :: (Arrow a, Arrow b) => (i -> o) -> (Along a i o :~> a) -> (Along b i' o' :~> b) -> (Along (a :+ b) (i', i) (o', o) :~> a :+ b)
    mergeCollapses fn collapseA collapseB (APlus (Apply (Mix fn' collapseA' body'))) =
      APlus
        ( Apply
            ( Mix
                (fn *** fn')
                (\wrapped -> collapseA (collapseA' (arrAssocRtoL wrapped)))
                (arrAssocRtoL (collapseB (arr rotate >>> body' >>> arr unrotate)))
            )
        )
-- | Swap the two sides of a mix: @a :+ b@ becomes @b :+ a@.
--
-- Three steps, applied innermost first:
--
-- 1. lift every right-hand arrow with 'liftImpure' (via 'alMap');
-- 2. pull that new layer to the outside with 'plusTwist';
-- 3. collapse the inner layer with 'unPlus' (via 'alMap').
plusCommute :: (Arrow a, ArrowChoice b, ArrowLoop b) => a :+ b :~> b :+ a
plusCommute mixed = alMap unPlus (plusTwist (alMap liftImpure mixed))