{-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
{-# LANGUAGE CPP #-}
#include "fusion-phases.h"

-- | Operations on Distributed Segment Descriptors
module Data.Array.Parallel.Unlifted.Distributed.USegd 
        ( splitSegdOnSegsD
        , splitSegdOnElemsD
        , splitSD
        , joinSegdD
        , glueSegdD)
where
import Data.Array.Parallel.Unlifted.Distributed.Arrays
import Data.Array.Parallel.Unlifted.Distributed.Combinators
import Data.Array.Parallel.Unlifted.Distributed.Types
import Data.Array.Parallel.Unlifted.Distributed.Gang
import Data.Array.Parallel.Unlifted.Sequential.USegd                    (USegd)
import Data.Array.Parallel.Unlifted.Sequential.Vector                   (Vector, Unbox)
import Data.Array.Parallel.Base
import Data.Bits                                                        (shiftR)
import Control.Monad                                                    (when)
import qualified Data.Array.Parallel.Unlifted.Distributed.Types.USegd   as DUSegd
import qualified Data.Array.Parallel.Unlifted.Sequential.USegd          as USegd
import qualified Data.Array.Parallel.Unlifted.Sequential.Vector         as Seq

-- Qualify a function name with the name of this module. The result is
-- passed as the location argument of the indexing operations used below
-- (presumably so that their failure messages identify the call site).
here :: String -> String
here = ("Data.Array.Parallel.Unlifted.Distributed.USegd." ++)

-------------------------------------------------------------------------------
-- | Split a segment descriptor across the gang, segment wise.
--   Whole segments are placed on each thread, and we try to balance out
--   the segments so each thread has the same number of array elements.
--
--   We don't split segments across threads, as this would limit our ability
--   to perform intra-thread fusion of lifted operations. The down side
--   of this is that if we have few segments with an un-even size distribution
--   then large segments can cause the gang to become unbalanced.
--
--   In the following example the segment with size 100 dominates and
--   unbalances the gang. There is no reason to put any segments on the
--   last thread because we need to wait for the first to finish anyway.
--
--   @ > pprp $ splitSegdOnSegsD theGang
--            $ lengthsToUSegd $ fromList [100, 10, 20, 40, 50  :: Int]
-- 
--     DUSegd lengths:   DVector lengths:  [ 1,    3,         1,  0]
--                                chunks:  [[100],[10,20,40],[50],[]]
-- 
--            indices:   DVector lengths:  [1,3,1,0]
--                                chunks:  [[0],  [0,10,30], [0], []]
--
--            elements:  DInt [100,70,50,0]
--   @
--
--  NOTE: This splitSegdOnSegsD function isn't currently used.
--
splitSegdOnSegsD :: Gang -> USegd -> Dist USegd
splitSegdOnSegsD gang !segd
  = mapD gang USegd.fromLengths (splitAsD gang segsPerThread segLens)
  where
    -- Number of whole segments given to each thread. The accumulator
    -- threaded through the gang is the index of the next unassigned segment,
    -- and each thread's input is the element budget produced by splitLenD.
    !segsPerThread
      = snd (mapAccumLD gang assign 0 (splitLenD gang (USegd.takeElements segd)))

    segCount = USegd.length segd
    segLens  = USegd.takeLengths segd

    -- Starting at segment 'from' with an element budget, yield the index of
    -- the first segment left for the next thread, and how many were taken.
    assign !from !budget
      = let !upto = advance from budget
        in  (upto, upto - from)

    -- Walk forward over the segments. Empty segments are always taken;
    -- a non-empty one is only taken while the budget is still positive,
    -- so the final segment taken may overshoot the budget.
    advance !ix !budget
      | ix >= segCount = ix
      | len == 0       = advance (ix + 1) budget
      | budget <= 0    = ix
      | otherwise      = advance (ix + 1) (budget - len)
      where
        len = Seq.index (here "splitSegdOnSegsD") segLens ix
{-# NOINLINE splitSegdOnSegsD #-}


-------------------------------------------------------------------------------
-- | Split a segment descriptor across the gang, element wise.
--   We try to put the same number of elements on each thread, which means
--   that segments are sometimes split across threads.
--
--   Each thread gets a slice of segment descriptor, the segid of the first 
--   slice, and the offset of the first slice in its segment.
--   
--   Example:
--    In this picture each X represents 5 elements, and we have 5 segments in total.
--
-- @  segs:    ----------------------- --- ------- --------------- -------------------
--    elems:  |X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|
--            |     thread1     |     thread2     |     thread3     |     thread4     |
--    segid:  0                 0                 3                 4
--    offset: 0                 45                0                 5
--
--    pprp $ splitSegdOnElemsD theGang4
--          $ lengthsToUSegd $ fromList [60, 10, 20, 40, 50 :: Int]
--
--     segd:    DUSegd lengths:  DVector lengths: [1,3,2,1]
--                                        chunks:  [[45],[15,10,20],[40,5],[45]]
--                     indices:  DVector lengths: [1,3,2,1]
--                                        chunks:  [[0], [0,15,25], [0,40],[0]]
--                    elements:  DInt [45,45,45,45]
--
--     segids: DInt [0,0,3,4]     (segment id of first slice on thread)
--    offsets: DInt [0,45,0,5]    (offset of that slice in its segment)
-- @
--
splitSegdOnElemsD :: Gang -> USegd -> Dist ((USegd,Int),Int)
splitSegdOnElemsD g !segd 
  = {-# SCC "splitSegdOnElemsD" #-} 
    imapD g chunkFor elemRanges
  where 
        -- Index of the last thread in the gang.
        !lastThread = gangSize g - 1

        -- For each thread: the number of elements it receives, paired with
        -- the starting offset of those elements in the flat array.
        elemRanges  = splitLenIdxD g (USegd.takeElements segd)

        -- Build the result for a single thread.
        chunkFor
           :: Int                  -- Thread index.
           -> (Int, Int)           -- Element count and starting offset.
           -> ((USegd, Int), Int)  -- Segd for this thread, segid of first slice,
                                   --   and offset of first slice.
        chunkFor tid (count, start)
         = case getChunk segd start count (tid == lastThread) of
            (# sliceLens, segid, offset #)
                -> ((USegd.fromLengths sliceLens, segid), offset)

{-# NOINLINE splitSegdOnElemsD #-}
--  NOINLINE: the body is large, and inlining it into every client module
--  would bloat their code.


-------------------------------------------------------------------------------
-- | Determine what elements go on a thread.
--   The 'chunk' refers to a chunk of the flat array, and is defined
--   by a set of segment slices. 
--
--   Example:
--    In this picture each X represents 5 elements, and we have 5 segments in total.
--
-- @
--    segs:    ----------------------- --- ------- --------------- -------------------
--    elems:  |X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|
--            |     thread1     |     thread2     |     thread3     |     thread4     |
--    segid:  0                 0                 3                 4
--    offset: 0                 45                0                 5
--    k:               0                 1                 3                 5
--    k':              1                 3                 5                 5
--    left:            0                 15                0                 45
--    right:           45                20                5                 0
--    left_len:        0                 1                 0                 1
--    left_off:        0                 45                0                 5
--    n':              1                 3                 2                 1
-- @
getChunk
        :: USegd          -- ^ Segment descriptor of entire array.
        -> Int            -- ^ Starting offset into the flat array for the first
                          --   slice on this thread.
        -> Int            -- ^ Number of elements in this thread.
        -> Bool           -- ^ Whether this is the last thread in the gang.
        -> (# Vector Int  --   Lengths of segment slices, 
            , Int         --     segid of first slice,
            , Int #)      --     offset of first slice.

getChunk !segd !nStart !nElems is_last
  = (# lens'', k-left_len, left_off #)
  where
    -- Lengths of all segments.
    -- eg: [60, 10, 20, 40, 50]
    !lens = USegd.takeLengths segd

    -- Starting indices of all segments.
    -- eg: [0, 60, 70, 90, 130]
    !idxs = USegd.takeIndices segd
    
    -- Total number of segments defined by segment descriptor.
    -- eg: 5
    !n    = Seq.length lens

    -- Segid of the first segment that starts at or after the left edge of
    -- this chunk, or n if there is no such segment.
    !k    = search nStart idxs

    -- Segid of the first segment that starts at or after the right edge of
    -- this chunk. The last thread always takes every remaining segment.
    !k'       | is_last     = n
              | otherwise   = search (nStart + nElems) idxs

    -- The length of the left-most slice of this chunk: the part of the
    -- chunk that lies before segment k starts. This is zero when the chunk
    -- begins exactly at the start of segment k.
    !left     | k == n      = nElems
              | otherwise   = min ((Seq.index (here "getChunk") idxs k) - nStart) nElems

    -- The length of the right-most slice of this chunk: the part of the
    -- chunk that lies in segment (k'-1). Zero when no segment starts
    -- inside the chunk.
    !right    | k' == k     = 0
              | otherwise   = nStart + nElems - (Seq.index (here "getChunk") idxs (k'-1))

    -- 1 if the first element in this chunk is an internal element of a
    -- segment (that is, it is not the first element of a segment), else 0.
    -- This is the number of extra leading slices in the result.
    !left_len | left == 0   = 0
              | otherwise   = 1

    -- If the first element of the chunk starts within a segment, 
    -- then gives the index within that segment, otherwise 0.
    !left_off | left == 0   = 0
              | otherwise   = nStart - (Seq.index (here "getChunk") idxs (k-1))

    -- How many segments this chunk straddles.
    !n' = left_len + (k'-k)

    -- Create the lengths for this chunk by first copying out the lengths
    -- from the original segment descriptor. If the slices on the left
    -- and right cover partial segments, then we update the corresponding
    -- lengths.
    !lens' 
     = runST (do
            -- Create a new array big enough to hold all the lengths for this chunk.
            !mlens' <- Seq.newM n'

            -- If the first element is inside a segment, 
            --   then update the length to be the length of the slice.
            when (left /= 0) 
             $ Seq.write mlens' 0 left

            -- Copy out array lengths for this chunk, skipping the slot
            -- reserved for the partial left slice (if any).
            -- The location tag is qualified with 'here', like every other
            -- indexing operation in this module.
            Seq.copy (Seq.mdrop left_len mlens')
                     (Seq.slice (here "getChunk") lens k (k'-k))

            -- If the last element is inside a segment, 
            --   then update the length to be the length of the slice.
            when (right /= 0)
             $ Seq.write mlens' (n' - 1) right

            Seq.unsafeFreeze mlens')

    -- Alias for the result, kept so that the debugging trace below can be
    -- re-enabled without touching the rest of the function.
    !lens'' = lens'
{-      = trace 
        (render $ vcat
                [ text "CHUNK"
                , pprp segd
                , text "nStart:  " <+> int nStart
                , text "nElems:  " <+> int nElems
                , text "k:       " <+> int k
                , text "k':      " <+> int k'
                , text "left:    " <+> int left
                , text "right:   " <+> int right
                , text "left_len:" <+> int left_len
                , text "left_off:" <+> int left_off
                , text "n':      " <+> int n'
                , text ""]) lens'
-}

{-# INLINE getChunk #-}
--  INLINE even though it should be inlined into splitSegdOnElemsD anyway
--  because that function contains the only use.


-------------------------------------------------------------------------------
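-- O(log n). Given a monotonically increasing vector of `Int`s,
-- find the index of the first element that is not less than the given
-- value, or the length of the vector if there is no such element.
-- 
-- eg  search 75 [0, 60, 70, 90, 130] = 3     (the index of 90)
--     search 43 [0, 60, 70, 90, 130] = 1     (the index of 60)
--     search 60 [0, 60, 70, 90, 130] = 1     (the index of 60)
--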
--
search :: Int -> Vector Int -> Int
search !x ys = go 0 (Seq.length ys)
  where
    -- 'lo' is the lowest index still under consideration and 'len' is the
    -- number of candidate positions starting from there.
    go lo len
      | len <= 0  = lo
      | otherwise
      = let half = len `shiftR` 1
            mid  = lo + half
        in  if Seq.index (here "search") ys mid < x
                -- The probed element is too small: discard it and
                -- everything before it.
                then go (mid + 1) (len - half - 1)
                -- The probed element qualifies: keep looking to its left.
                else go lo half


-------------------------------------------------------------------------------
-- | time O(segs)
--   Join a distributed segment descriptor into a global one.
--   This simply joins the distributed lengths and indices fields, but does
--   not reconstruct the original segment descriptor as it was before splitting.
-- 
-- @ > pprp $ joinSegdD theGang4 
--         $ fstD $ fstD $ splitSegdOnElemsD theGang4
--         $ lengthsToUSegd $ fromList [60, 10, 20, 40, 50]
-- 
--   USegd lengths:  [45,15,10,20,40,5,45]
--         indices:  [0,45,60,70,90,130,135]
--         elements: 180
-- @
-- 
-- TODO: sequential runtime is O(segs) due to application of lengthsToUSegd
-- 
joinSegdD :: Gang -> Dist USegd -> USegd
joinSegdD gang
        = \dsegd ->
          let -- The segment lengths held by each thread.
              dlens = mapD gang USegd.takeLengths dsegd
          in  -- Concatenate the per-thread lengths and rebuild a descriptor
              -- from the resulting flat vector.
              USegd.fromLengths (joinD gang unbalanced dlens)
{-# INLINE_DIST joinSegdD #-}


-------------------------------------------------------------------------------
-- | Glue a distributed segment descriptor back into the original global one.
--   Prop:  glueSegdD gang $ splitSegdOnElemsD gang usegd = usegd
--
--   NOTE: This runs sequentially and should only be used for testing purposes.
--
glueSegdD :: Gang -> Dist ((USegd, Int), Int)  -> Dist USegd
glueSegdD gang bundle
 = mapD gang USegd.fromLengths gluedLens
 where
        -- The per-thread segment descriptors and their slice lengths.
        !dsegd       = fstD (fstD bundle)
        !sliceLens   = DUSegd.takeLengthsD dsegd

        -- Offset of the first slice on each thread within its segment.
        !firstOffs   = sndD bundle

        -- Whether the last segment in this chunk extends into the next chunk.
        -- This holds when the next chunk starts at a non-zero offset
        -- into its first segment; the final chunk has no successor.
        continues :: Dist Bool
        !continues
         = generateD_cheap gang $ \ix
         -> ix < sizeD sliceLens - 1
         && indexD (here "glueSegdD") firstOffs (ix + 1) /= 0

        -- Combine the lengths of slices that belong to the same segment,
        -- using addition to merge a split segment's parts.
        !gluedLens   = fst (carryD gang (+) 0 continues sliceLens)
{-# INLINE_DIST glueSegdD #-}


-------------------------------------------------------------------------------
-- | Split a flat vector across the gang according to a distributed segment
--   descriptor: the vector is divided with 'splitAsD', using the element
--   count taken from each thread's part of the descriptor.
splitSD :: Unbox a => Gang -> Dist USegd -> Vector a -> Dist (Vector a)
splitSD g dsegd xs
        = splitAsD g elemsPerThread xs
        where
          -- Element counts extracted from the distributed descriptor.
          elemsPerThread = DUSegd.takeElementsD dsegd
{-# INLINE_DIST splitSD #-}

-- Rewrite rules that push 'splitSD' towards the leaves of an expression.
--
--  * "splitSD/splitJoinD": splitting the result of @splitJoinD g f xs@ is
--    replaced by applying @f@ directly to the split of @xs@.
--    NOTE(review): this presumably relies on @f@ producing chunks whose
--    sizes agree with the descriptor @d@ -- confirm against splitJoinD.
--
--  * "splitSD/Seq.zip": splitting a zipped vector is replaced by splitting
--    both components separately and zipping the chunks on each thread.
--
{-# RULES

"splitSD/splitJoinD" forall g d f xs.
  splitSD g d (splitJoinD g f xs) = f (splitSD g d xs)

"splitSD/Seq.zip" forall g d xs ys.
  splitSD g d (Seq.zip xs ys) = zipWithD g Seq.zip (splitSD g d xs)
                                             (splitSD g d ys)

  #-}