{-# LANGUAGE Arrows #-}

module Control.Arrow.Utils (
    SameInputArrow(..)
  , traverseArr
  , traverseArr_
  , sequenceArr_
  , sequenceArr
  , zipSequenceArrVec
  , zipSequenceArrList
  , whenArr
  , unlessArr
  , constantly
) where

import Control.Arrow
    ( returnA, (>>>), Arrow((***), arr), ArrowChoice )
import Data.Foldable (traverse_)
import Data.Maybe ( fromJust )
import Data.Vector.Sized ( fromList, toList, Vector )
import qualified Data.Vector.Sized as Vec
import GHC.TypeLits ( KnownNat )

-- | Newtype wrapper around an arrow @a b c@ whose input type @b@ is held
--   fixed, so that class instances (such as 'Functor' and 'Applicative')
--   can be written over the output type @c@.
--   This generalises 'ArrowMonad', which is isomorphic to
--   @'SameInputArrow' a () c@.
newtype SameInputArrow a b c = SameInputArrow
  { unSameInputArrow :: a b c
    -- ^ Unwrap to the underlying arrow.
  }

-- | @'fmap' f@ post-composes the wrapped arrow with the pure function @f@;
--   the input is untouched.
instance (Arrow a) => Functor (SameInputArrow a b) where
  fmap f (SameInputArrow inner) = SameInputArrow (inner >>> arr f)

-- | Both operands see the same input.
--
--   * @f '<*>' x@ runs @f@ and then @x@ on that shared input, and applies
--     the function produced by @f@ to the value produced by @x@.
--   * @'pure' c@ ignores its input and always yields @c@.
instance (Arrow a) => Applicative (SameInputArrow a b) where
  pure = SameInputArrow . constantly

  SameInputArrow funArrow <*> SameInputArrow argArrow =
    SameInputArrow $ proc shared -> do
      g <- funArrow -< shared
      v <- argArrow -< shared
      returnA -< g v

-- | Builds one arrow per element of the 'Foldable' with @f@ and feeds each
--   of them the same input, in the order the 'Foldable' visits them.
--   Their outputs are thrown away.
traverseArr_ :: (Foldable t, Arrow a) => (x -> a b c) -> t x -> a b ()
traverseArr_ f = unSameInputArrow . traverse_ (SameInputArrow . f)

-- | Builds one arrow per element of the 'Traversable' with @f@, feeds each
--   of them the same input, and gathers their outputs in a container of the
--   same shape.
--
--   @traverseArr (+) [1,10] 1 == [2,11]@
traverseArr :: (Traversable t, Arrow a) => (x -> a b c) -> t x -> a b (t c)
traverseArr mkArrow xs =
  unSameInputArrow (traverse (\x -> SameInputArrow (mkArrow x)) xs)

-- | Feeds the same input to every arrow in the 'Foldable', in order,
--   discarding their outputs. The result-discarding counterpart of
--   'sequenceArr'.
sequenceArr_ :: (Foldable t, Arrow a) => t (a b any) -> a b ()
sequenceArr_ = unSameInputArrow . traverse_ SameInputArrow

-- | Feeds the same input to every arrow in the 'Traversable' and gathers
--   their outputs in a container of the same shape.
--
--   @sequenceArr [(+1), (+10)] 1 == [2,11]@
sequenceArr :: (Traversable t, Arrow a) => t (a b c) -> a b (t c)
sequenceArr = unSameInputArrow . traverse SameInputArrow

-- | Fans each input from @Vector n b@ to a separate arrow from the given
--   vector: the i-th input goes to the i-th arrow, and its output lands at
--   position i of the result.
--
--   @zipSequenceArrVec (Vec.generate ((+).fromIntegral) :: Vector 5 (Int -> Int)) (Vec.replicate 1 :: Vector 5 Int) == Vector [1,2,3,4,5]@
zipSequenceArrVec :: (Arrow a, KnownNat n) => Vector n (a b c) -> a (Vector n b) (Vector n c)
zipSequenceArrVec cells =
  arr toList
    >>> zipSequenceArrListUnsafe (toList cells)
    >>> arr toSized
  where
    -- The arrow vector and the input vector both have length n, so the
    -- list helper returns exactly n results and 'fromList' always succeeds.
    -- The 'Nothing' branch is unreachable; it names the invariant instead of
    -- failing with an anonymous 'fromJust' error.
    toSized results = case fromList results of
      Just vec -> vec
      Nothing ->
        error "Control.Arrow.Utils.zipSequenceArrVec: impossible, result length differs from n"

-- Zips a list of arrows with a list of inputs: the i-th input is fed to the
-- i-th arrow and the outputs are collected in order.
--
-- Not safe: list lengths are not checked. If there are fewer inputs than
-- arrows it fails with an 'error'; surplus inputs are ignored. When used in
-- 'zipSequenceArrVec' it is safe, as both vectors have the same length n.
-- Not to export.
zipSequenceArrListUnsafe :: Arrow a => [a b c] -> a [b] [c]
zipSequenceArrListUnsafe [] = constantly []
zipSequenceArrListUnsafe (x : xs) = proc inputs -> do
  (y, ys) <- returnA -< splitHead inputs
  xres <- x -< y
  xsres <- zipSequenceArrListUnsafe xs -< ys
  returnA -< xres : xsres
  where
    -- Total-shaped replacement for the former incomplete @proc (y:ys)@
    -- binder, so a length mismatch reports a meaningful message.
    splitHead (i : is) = (i, is)
    splitHead [] =
      error "Control.Arrow.Utils.zipSequenceArrListUnsafe: fewer inputs than arrows"

-- | Fans each input from @[b]@ to a separate arrow from the given list: the
--   i-th input goes to the i-th arrow.
--
--   The output list has the length of the shorter of the two lists:
--
--   * once the inputs run out, the remaining arrows are not run;
--   * once the arrows run out, the remaining inputs are ignored.
--
--   @
--   zipSequenceArrList [(+1), (+10)] [1,2] == [2,12]
--   zipSequenceArrList [(+1), (+10)] [1]   == [2]
--   zipSequenceArrList [(+1)] [1,2,3,4]    == [2]
--   @
zipSequenceArrList :: ArrowChoice a => [a b c] -> a [b] [c]
zipSequenceArrList [] = constantly []
zipSequenceArrList (a : as) = proc bs' -> case bs' of
  -- Branching on the input list is why 'ArrowChoice' is required.
  [] -> returnA -< []
  b : bs -> do
    c <- a -< b
    cs <- zipSequenceArrList as -< bs
    returnA -< c : cs

-- | Arrow analogue of @'when'@ for @'Applicative'@.
--
--   * On @(True, x)@ it runs @cell@ on @x@.
--   * On @(False, x)@ it never invokes @cell@ and outputs @()@.
--
--   Relevant for arrows that embed a Monad, where skipping @cell@ also
--   skips its effects.
whenArr :: ArrowChoice a => a b () -> a (Bool, b) ()
whenArr cell = proc (enabled, input) ->
  if enabled
    then cell -< input
    else returnA -< ()

-- | Arrow analogue of @'unless'@ for @'Applicative'@.
--
--   * On @(False, x)@ it runs @cell@ on @x@.
--   * On @(True, x)@ it skips @cell@ and outputs @()@.
--
--   Relevant for arrows that embed a Monad. Implemented by negating the
--   flag and delegating to 'whenArr'.
unlessArr :: ArrowChoice a => a b () -> a (Bool, b) ()
unlessArr cell = arr flipFlag >>> whenArr cell
  where
    flipFlag (flag, input) = (not flag, input)

-- | An arrow that ignores its input and always outputs the given value.
constantly :: Arrow a => b -> a any b
constantly value = arr (\_ -> value)