{-# LANGUAGE BangPatterns
           , FlexibleContexts #-}

module Vision.Image.Parallel (computeP) where

import Control.Concurrent (
    forkIO, getNumCapabilities, newEmptyMVar, putMVar, takeMVar)
import Control.Monad.ST (ST, stToIO)
import Data.Vector (enumFromN, forM, forM_)
import Foreign.Storable (Storable)
import System.IO.Unsafe (unsafePerformIO)

import Vision.Image.Class (MaskedImage (..), Image (..), (!))
import Vision.Image.Type (Manifest (..))
import Vision.Image.Mutable (MutableManifest, linearWrite, new, unsafeFreeze)
import Vision.Primitive (Z (..), (:.) (..), ix2)


-- | Parallel version of 'compute'.
--
-- Computes the value of an image into a manifest representation in parallel.
--
-- The rows of the image are split into one contiguous band per capability
-- (as reported by 'getNumCapabilities') and each band is filled by its own
-- thread. When the height is not a multiple of the number of capabilities,
-- the first band also receives the leftover rows and the following bands are
-- shifted accordingly, so that every row is written exactly once.
--
-- The resulting image is forced with '$!' before being wrapped with 'return',
-- so that the image is fully computed before the monadic computation
-- continues.
computeP :: (Monad m, Image i, Storable (ImagePixel i))
        => i -> m (Manifest (ImagePixel i))
computeP !src =
    -- 'unsafePerformIO' hides the forking behind a pure-looking interface.
    -- The mutable destination is allocated inside this action, is only
    -- written by the worker threads (each one on its own band of rows), and
    -- is frozen only after every worker has signalled its completion, so no
    -- mutable state escapes.
    return $! unsafePerformIO $ do
        dst <- stToIO newManifest

        nCapabilities <- getNumCapabilities
        let !(nLinesPerThread, remain) = h `quotRem` nCapabilities

        -- Forks 'nCapabilities' threads.
        childs <- forM (enumFromN 0 nCapabilities) $ \c -> do
            child <- newEmptyMVar

            _ <- forkIO $ do
                -- The first thread takes the 'remain' extra lines. The other
                -- threads must therefore start 'remain' lines further down,
                -- otherwise they would overlap the first band and leave the
                -- last 'remain' lines of the image unwritten.
                let firstLine | c == 0    = 0
                              | otherwise = c * nLinesPerThread + remain
                    nLines    | c == 0    = nLinesPerThread + remain
                              | otherwise = nLinesPerThread

                stToIO $ fillFromN dst firstLine nLines

                -- Sends a signal to the main thread.
                putMVar child ()

            return child

        -- Waits for all threads to finish.
        forM_ childs takeMVar

        stToIO $ unsafeFreeze dst
  where
    !size@(Z :. h :. w) = shape src

    -- Computes 'n' lines starting at 'from' of the image, writing each pixel
    -- at the linear offset @y * w + x@ of the destination.
    fillFromN !dst !from !n =
        forM_ (enumFromN from n) $ \y -> do
            let !lineOffset = y * w
            forM_ (enumFromN 0 w) $ \x -> do
                let !offset = lineOffset + x
                    !val    = src ! (ix2 y x)
                linearWrite dst offset val

    -- Allocates the mutable destination with the same size as the source.
    newManifest :: Storable p => ST s (MutableManifest p s)
    newManifest = new size
{-# INLINE computeP #-}