{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.Data.Internal where

import Control.Concurrent.Async.Lifted (concurrently)
import qualified Control.Concurrent.STM as STM
import Control.Exception.Safe (bracket, finally)
import Control.Monad (when)
import Control.Monad.Base (MonadBase (..))
import Control.Monad.Cont (ContT (ContT))
import Control.Monad.Trans.Control
import Pipes
import Pipes.Concurrent hiding (atomically)
import qualified Pipes.Prelude as P

-- | Run @batchHandler@ against the 'Output' end of a 'bounded' mailbox of
-- capacity @bufferSize@, and hand the matching 'Input' end to the
-- continuation as a 'ListT'.
--
-- The producer side (@batchHandler@) and the consumer side (the 'ContT'
-- continuation) run concurrently via 'withBufferLifted'. The 'ListT' yields
-- every value obtained from the 'Input' until 'recv' reports that nothing
-- more is available (see 'fromInput''). The handler's @()@ result is
-- dropped; only the continuation's result is returned.
runWithBuffer ::
  forall a m b.
  (MonadBaseControl IO m) =>
  Int ->
  (Output a -> m ()) ->
  -- ContT b m (ListT m (a, Int))
  ContT b m (ListT m a)
runWithBuffer bufferSize batchHandler = ContT runBuffered
  where
    -- Feed the consumer continuation with the stream read from the mailbox,
    -- keeping only the consumer's result from the (producer, consumer) pair.
    runBuffered :: (ListT m a -> m b) -> m b
    runBuffered consumer =
      fmap snd $
        withBufferLifted
          (bounded bufferSize)
          batchHandler
          (\input -> consumer (streamFrom input))

    -- Earlier variant (kept for reference): zip the stream with an
    -- iteration counter, i.e. Select $ P.zip (fromInput' input) iters
    streamFrom input = Select (fromInput' input)

-- | 'bracket' lifted to any monad with a 'MonadBaseControl' 'IO' instance.
--
-- Each of @acquire@, @release@ and @action@ is run in 'IO' through the
-- @runInIO@ handle provided by 'control'. The state captured after
-- @acquire@ is restored with 'restoreM' before @release@ or @action@ runs.
-- Release/cleanup guarantees are those of 'Control.Exception.Safe.bracket'.
--
-- Only the 'StM' produced by @action@ is restored into @m@. Whatever
-- monadic state @release@ produces is discarded, because 'bracket' returns
-- the body's result.
liftedBracket :: MonadBaseControl IO m => m a -> (a -> m b) -> (a -> m c) -> m c
liftedBracket acquire release action =
  control $ \runInIO ->
    bracket
      (runInIO acquire)
      (runInIO . (>>= release) . restoreM)
      (runInIO . (>>= action) . restoreM)

-- | Spawn a mailbox from @buffer@ and run @fOutput@ on its 'Output' end and
-- @fInput@ on its 'Input' end concurrently, returning both results.
--
-- The mailbox is created with 'spawn'' inside 'liftedBracket', and its seal
-- action is run by the bracket's release step. In addition, each side seals
-- the mailbox when it finishes, normally or exceptionally, via
-- 'liftedFinally'. Presumably this lets the other side notice that its peer
-- is done instead of waiting forever. The sealing action may therefore run
-- more than once.
withBufferLifted ::
  (MonadBaseControl IO m) =>
  Buffer a ->
  (Output a -> m l) ->
  (Input a -> m r) ->
  m (l, r)
withBufferLifted buffer fOutput fInput =
  liftedBracket (liftBase (spawn' buffer)) closeMailbox runBothEnds
  where
    -- Run the STM seal action from the base monad.
    sealIO seal = liftBase (atomically seal :: IO ())

    -- Bracket release: seal the mailbox once everything is done.
    closeMailbox (_, _, seal) = sealIO seal

    -- Bracket body: drive both ends in parallel, each sealing on exit.
    runBothEnds (output, input, seal) =
      concurrently
        (fOutput output `liftedFinally` sealIO seal)
        (fInput input `liftedFinally` sealIO seal)

-- | Producer that repeatedly 'recv's from @input@ and yields each value.
-- It returns @()@ as soon as 'recv' yields 'Nothing'.
-- Works in any monad with a 'MonadBase' 'IO' instance, which is more
-- general than an 'IO'-only producer.
fromInput' :: (MonadBase IO m) => Input a -> Producer' a m ()
fromInput' input = drain
  where
    drain = liftBase (atomically (recv input)) >>= maybe (return ()) emit
    emit x = yield x >> drain

-- | Consumer that 'await's values and 'send's each one to @output@.
-- It keeps going while 'send' returns 'True' and stops at the first 'False'.
-- Works in any monad with a 'MonadBase' 'IO' instance.
toOutput' :: (MonadBase IO m) => Output a -> Consumer' a m ()
toOutput' output = pump
  where
    pump = await >>= deliver
    deliver x = do
      accepted <- liftBase (atomically (send output x))
      when accepted pump

-- | 'finally' lifted to any monad with a 'MonadBaseControl' 'IO' instance.
--
-- It runs @a@, then always runs @sequel@, with the semantics of
-- 'Control.Exception.Safe.finally'. Only the 'StM' produced by @a@ is
-- restored into @m@. The result and monadic state of @sequel@ are discarded.
liftedFinally :: MonadBaseControl IO m => m a -> m b -> m a
liftedFinally a sequel =
  control $ \runInIO -> runInIO a `finally` runInIO sequel

-- | Run an 'STM' transaction in any 'MonadIO' monad.
-- The module imports "Pipes.Concurrent" with its own @atomically@ hidden and
-- uses this generalised version instead.
atomically :: MonadIO m => STM a -> m a
atomically transaction = liftIO (STM.atomically transaction)

-- | Lift base-monad ('IO') actions into a pipes 'Proxy' by lifting into the
-- underlying monad @m@ first, then through the 'Proxy' layer with 'lift'.
--
-- NOTE: this is an orphan instance. Neither 'MonadBase' nor 'Proxy' is
-- defined in this module. It is what lets 'fromInput'' and 'toOutput'' call
-- 'liftBase' inside a pipe.
instance (MonadBase IO m) => MonadBase IO (Proxy a' a b' b m) where
  liftBase io = lift (liftBase io)

---- TODO: add a runData function that simply calls runContT but zips
---- the ListT with the iteration number. This would be much better.