-- SPDX-FileCopyrightText: 2021 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

module Morley.Util.Batching
  ( BatchingM
  , runBatching
  , unsafeRunBatching
  , submitThenParse

  , BatchingError (..)
  ) where

import Control.Monad.Except (Except, runExcept, throwError)
import Fmt (Buildable(..))

-- | Failures detected while matching the output of a batch against its input.
--
-- These typically mean that the function executing the batch produced output
-- that does not line up with the input it was given.
data BatchingError e
  = InsufficientOutput
    -- ^ The batch executor returned fewer output elements than there were
    -- input elements.
  | ExtraOutput
    -- ^ The batch executor returned more output elements than there were
    -- input elements.
  | UnexpectedElement e
    -- ^ A user-supplied parser rejected an output element; the parser's
    -- error is carried along. Usually this means the output does not
    -- correspond to the provided input.

-- | Renders a 'BatchingError' as a short human-readable message. For
-- 'UnexpectedElement', the wrapped parser error is rendered with its own
-- 'Buildable' instance.
instance Buildable e => Buildable (BatchingError e) where
  build InsufficientOutput =
    "Too few elements in output of batch operation"
  build ExtraOutput =
    "Too many elements in output of batch operation"
  build (UnexpectedElement err) =
    "Unexpected element: " <> build err

-- | Accumulates operations that will later be executed together as one batch.
--
-- Running a batch conceptually takes three steps, in this order:
--
-- * collect the list of input items @i@;
-- * execute the batch operation on them;
-- * parse the output items @o@ into a result @a@, possibly failing with @e@.
--
-- In client code it is usually convenient to keep steps 1 and 3 next to each
-- other while postponing step 2; 'BatchingM' provides exactly that split.
--
-- Despite its name, 'BatchingM' is only an 'Applicative', not a 'Monad':
-- within a single batch, the result of one operation cannot be used as input
-- to another operation.
data BatchingM i o e a = BatchingM
  { bInput :: Endo [i]
    -- ^ Every submitted input item, kept as a difference list so that
    -- combining batches only composes functions.
  , bParseOutput :: StateT [o] (Except (BatchingError e)) a
    -- ^ Consumes the batch output, once it is available, and produces the
    -- final result.
  }
  deriving stock (Functor)

-- | Combining two batches concatenates their inputs (left batch first) and
-- runs their output parsers in the same left-to-right order. Outputs are
-- therefore consumed in the order in which the inputs were submitted.
instance Applicative (BatchingM i o e) where
  -- An empty batch: no input, and a parser that consumes nothing.
  pure x = BatchingM mempty (pure x)

  -- Irrefutable patterns keep this method exactly as lazy in its arguments
  -- as a definition via the record field accessors would be.
  ~(BatchingM inputsF parseF) <*> ~(BatchingM inputsX parseX) =
    BatchingM (inputsF <> inputsX) (parseF <*> parseX)

-- | Execute all recorded operations as a single batch using the given
-- executor.
--
-- The executor receives the submitted inputs in submission order. Its outputs
-- are parsed in that same order. The extra result @r@ it returns is passed
-- through untouched. A mismatch between inputs and outputs, or a parser
-- failure, is reported as a 'BatchingError' in the returned 'Either'.
runBatching
  :: (Functor m)
  => ([i] -> m (r, [o]))
  -> BatchingM i o e a
  -> m (r, Either (BatchingError e) a)
runBatching execBatch (BatchingM inputs parser) =
  fmap (second finish) (execBatch (appEndo inputs []))
  where
    -- Run the parser over the outputs and insist that it consumed all of
    -- them; leftover outputs mean the executor returned too many.
    finish outputs =
      case runExcept (runStateT parser outputs) of
        Left err -> Left err
        Right (result, []) -> pure result
        Right (_, _ : _) -> throwError ExtraOutput

-- | Variant of 'runBatching' for executors that are known to return output
-- matching the provided input exactly.
--
-- Instead of returning a 'BatchingError', this fails at runtime via 'unsafe'
-- when one occurs. That is why 'Buildable' is required on the parser's error
-- type.
unsafeRunBatching
  :: (Functor m, Buildable e)
  => ([i] -> m (r, [o]))
  -> BatchingM i o e a
  -> m (r, a)
unsafeRunBatching execBatch batch =
  second unsafe <$> runBatching execBatch batch

-- | The primitive from which every 'BatchingM' action is built.
--
-- It submits a single input item to the batch. Once the batch has been
-- executed, it takes the next output item and parses it with the given
-- function. Running out of output yields 'InsufficientOutput'; a parser
-- failure is wrapped in 'UnexpectedElement'.
--
-- Declared @infix 1@, so it binds more loosely than most operators
-- (e.g. '<>', '.') but more tightly than '$'.
submitThenParse :: i -> (o -> Either e a) -> BatchingM i o e a
submitThenParse inp parse =
  BatchingM (Endo (inp :)) (StateT takeNext)
  where
    -- Pop one output element and hand the rest back as the new state.
    takeNext [] = throwError InsufficientOutput
    takeNext (out : rest) =
      case parse out of
        Left err -> throwError (UnexpectedElement err)
        Right val -> pure (val, rest)
infix 1 `submitThenParse`