{-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
{-# LANGUAGE CPP #-}
#include "fusion-phases.h"

-- | Operations on Distributed Segment Descriptors
module Data.Array.Parallel.Unlifted.Distributed.Data.USSegd.Split 
        (splitSSegdOnElemsD)
where
import Data.Array.Parallel.Unlifted.Distributed.Arrays
import Data.Array.Parallel.Unlifted.Distributed.Combinators
import Data.Array.Parallel.Unlifted.Distributed.Primitive
import Data.Array.Parallel.Unlifted.Sequential.USSegd                   (USSegd)
import Data.Array.Parallel.Unlifted.Sequential.Vector                   (Vector)
import Data.Array.Parallel.Base
import Data.Bits                                                        (shiftR)
import Control.Monad                                                    (when)
import Data.Array.Parallel.Unlifted.Distributed.Data.USSegd.DT          ()
import qualified Data.Array.Parallel.Unlifted.Sequential.USegd          as USegd
import qualified Data.Array.Parallel.Unlifted.Sequential.USSegd         as USSegd
import qualified Data.Array.Parallel.Unlifted.Sequential.Vector         as Seq
import Debug.Trace

-- | Qualify a function name with the name of this module.
--   The resulting string is passed as the location argument to the
--   'Seq.index' and 'Seq.slice' calls in this file, so it should name the
--   module the call actually lives in.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Data.USSegd.Split." ++ s

-------------------------------------------------------------------------------
-- | Split a segment descriptor across the gang, element wise.
--   We try to put the same number of elements on each thread, which means
--   that segments are sometimes split across threads.
--
--   Each thread gets a slice of segment descriptor, the segid of the first 
--   slice, and the offset of the first slice in its segment.
--   
--   Example:
--    In this picture each X represents 5 elements, and we have 5 segments in total.
--
-- @   segs:    ----------------------- --- ------- --------------- -------------------
--    elems:  |X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|
--            |     thread1     |     thread2     |     thread3     |     thread4     |
--    segid:  0                 0                 3                 4
--    offset: 0                 45                0                 5
--
--    pprp $ splitSegdOnElemsD theGang 
--          $ lengthsToUSegd $ fromList [60, 10, 20, 40, 50 :: Int]
--
--     segd:    DUSegd lengths:  DVector lengths: [1,3,2,1]
--                                        chunks:  [[45],[15,10,20],[40,5],[45]]
--                     indices:  DVector lengths: [1,3,2,1]
--                                        chunks:  [[0], [0,15,25], [0,40],[0]]
--                    elements:  DInt [45,45,45,45]
--
--     segids: DInt [0,0,3,4]     (segment id of first slice on thread)
--    offsets: DInt [0,45,0,5]    (offset of that slice in its segment)
-- @
--
splitSSegdOnElemsD :: Gang -> USSegd -> Dist ((USSegd,Int),Int)
splitSSegdOnElemsD g !segd 
  = {-# SCC "splitSSegdOnElemsD" #-}
    traceEvent ("dph-prim-par: USSegd.splitSSegdOnElems")
  $ imapD (What "UPSSegd.splitSSegdOnElems/splitLenIx") g sliceFor elemRanges
  where 
        -- Index of the last thread in the gang.
        !lastThread = gangSize g - 1

        -- For each thread: the number of elements it receives, paired with
        -- the starting offset of those elements in the flat array.
        elemRanges
                = splitLenIdxD g
                $ USegd.takeElements
                $ USSegd.takeUSegd segd

        -- Work out the part of the segment descriptor owned by one thread.
        sliceFor
           :: Int                  -- Thread index.
           -> (Int, Int)           -- Number of elements on this thread,
                                   --   and starting offset into the flat array.
           -> ((USSegd, Int), Int) -- Segd for this thread, segid of first slice,
                                   --   and offset of first slice.

        sliceFor tid (count, offset)
         = case chunk segd offset count (tid == lastThread) of
            (# lens, sts, srcs, firstSeg, firstOff #)
             -> let -- Only the lengths, starts and sources come from 'chunk'.
                    -- The indices and elems fields of the contained USegd
                    -- are regenerated from the lengths.
                    usegd   = USegd.fromLengths lens
                in  ((USSegd.mkUSSegd sts srcs usegd, firstSeg), firstOff)

{-# NOINLINE splitSSegdOnElemsD #-}
--  NOINLINE because this function is complicated and won't fuse with anything.
--  Its body is large, and inlining it everywhere would blow up the
--  client modules.


-------------------------------------------------------------------------------
-- | Determine what elements go on a thread.
--   The 'chunk' refers to a chunk of the flat array, and is defined
--   by a set of segment slices. 
--
--   Example:
--    In this picture each X represents 5 elements, and we have 5 segments in total.
--
-- @  segs:    ----------------------- --- ------- --------------- -------------------
--    elems:  |X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|
--            |     thread1     |     thread2     |     thread3     |     thread4     |
--    segid:  0                 0                 3                 4
--    offset: 0                 45                0                 5
--    k:               0                 1                 3                 5
--    k':              1                 3                 5                 5
--    left:            0                 15                0                 45
--    right:           45                20                5                 0
--    left_len:        0                 1                 0                 1
--    left_off:        0                 45                0                 5
--    n':              1                 3                 2                 1
-- @
--
--   The rows from @k@ downwards are the local bindings of this function;
--   the row labelled @right@ corresponds to the binding @length_right@.
--   The returned segid is @k - left_len@ and the returned offset is @left_off@.
chunk   :: USSegd          -- ^ Segment descriptor of entire array.
        -> Int            -- ^ Starting offset into the flat array for the first
                          --    slice on this thread.
        -> Int            -- ^ Number of elements in this thread.
        -> Bool           -- ^ Whether this is the last thread in the gang.
        -> (# Vector Int  --    Lengths of segment slices, 
            , Vector Int  --    Starting index of data in its vector
            , Vector Int  --    Source id
            , Int         --    segid of first slice
            , Int #)      --    offset of first slice.

chunk !ussegd !nStart !nElems is_last
  = (# lengths', starts', sources', k-left_len, left_off #)
  where
    -- Lengths of all segments.
    -- eg: [60, 10, 20, 40, 50]
    lengths     = USSegd.takeLengths ussegd

    -- Starting indices of all segments in the flat array.
    -- eg: [0, 60, 70, 90, 130]
    indices     = USSegd.takeIndices ussegd

    -- Starting indices for all segments.
    starts      = USSegd.takeStarts ussegd

    -- Source ids for all segments.
    sources     = USSegd.takeSources ussegd
    
    -- Total number of segments defined by segment descriptor.
    -- eg: 5
    n    = Seq.length lengths

    -- Segid of the first seg whose index is at or after the left edge
    -- (nStart) of this chunk. Equals n when there is no such segment.
    k    = search nStart indices

    -- Segid of the first seg whose index is at or after the right edge
    -- (nStart + nElems) of this chunk. The last thread always takes every
    -- remaining segment, so there it is simply n.
    k'       | is_last     = n
             | otherwise   = search (nStart + nElems) indices

    -- The length of the left-most slice of this chunk, that is the number of
    -- elements between nStart and the start of segment k, capped at nElems.
    -- When no segment starts at or after nStart (k == n) the whole chunk
    -- is a single slice.
    left     | k == n      = nElems
             | otherwise   = min ((Seq.index (here "chunk") indices k) - nStart) nElems

    -- The length of the right-most slice of this chunk, measured from the
    -- start of segment k'-1 to the right edge of the chunk. It is 0 when
    -- no segment begins inside the chunk (k' == k).
    length_right   
             | k' == k     = 0
             | otherwise   = nStart + nElems - (Seq.index (here "chunk") indices (k'-1))

    -- Whether the first element in this chunk is an internal element
    -- of a segment. Alternatively, indicates that the first element of 
    -- the chunk is not the first element of a segment.
    -- Encoded as 1 or 0 so it can be used directly as a slice count.
    left_len | left == 0   = 0
             | otherwise   = 1

    -- If the first element of the chunk starts within a segment, 
    -- then gives the index within that segment, otherwise 0.
    left_off | left == 0   = 0
             | otherwise   = nStart - (Seq.index (here "chunk") indices (k-1))

    -- How many segment slices this chunk holds: the optional partial slice
    -- on the left, plus every segment that begins inside the chunk.
    n' = left_len + (k'-k)

    -- Create the lengths for this chunk by first copying out the lengths
    -- from the original segment descriptor. If the slices on the left
    -- and right cover partial segments, then we update the corresponding
    -- lengths.
    -- The order of the writes matters: the right-most length is patched
    -- only after the copy, because it overwrites a copied entry.
    (!lengths', !starts', !sources')
     = runST (do
            -- Create new arrays with one entry for each slice in this chunk.
            mlengths' <- Seq.newM n'
            msources' <- Seq.newM n'
            mstarts'  <- Seq.newM n'

            -- If the first element is inside a segment, 
            --   then slot 0 describes the partial slice of segment k-1:
            --   its length is 'left', and its start is the start of that
            --   segment advanced by the offset into it.
            when (left /= 0) 
             $ do Seq.write mlengths' 0 left
                  Seq.write mstarts'  0 (Seq.index (here "chunk") starts  (k - left_len) + left_off)
                  Seq.write msources' 0 (Seq.index (here "chunk") sources (k - left_len))

            -- Copy out the lengths, starts and sources of segments k .. k'-1.
            -- The destination is offset by left_len (presumably what
            -- Seq.mdrop does) so slot 0 written above is left untouched.
            Seq.copy (Seq.mdrop left_len mlengths') (Seq.slice (here "chunk") lengths k (k'-k))
            Seq.copy (Seq.mdrop left_len mstarts')  (Seq.slice (here "chunk")  starts k (k'-k))
            Seq.copy (Seq.mdrop left_len msources') (Seq.slice (here "chunk") sources k (k'-k))

            -- If the last element is inside a segment, 
            --   then update the length to be the length of the slice.
            --   Only the length changes; the start and source copied for
            --   that segment are kept as they are.
            when (length_right /= 0)
             $ do Seq.write mlengths' (n' - 1) length_right

            clengths' <- Seq.unsafeFreeze mlengths'
            cstarts'  <- Seq.unsafeFreeze mstarts'
            csources' <- Seq.unsafeFreeze msources'
            return (clengths', cstarts', csources'))

-- Disabled debugging trace.
-- NOTE(review): it mentions segd, right and lens', none of which are bound
-- in this function as written, so it needs updating before being re-enabled.
{-      = trace 
        (render $ vcat
                [ text "CHUNK"
                , pprp segd
                , text "nStart:  " <+> int nStart
                , text "nElems:  " <+> int nElems
                , text "k:       " <+> int k
                , text "k':      " <+> int k'
                , text "left:    " <+> int left
                , text "right:   " <+> int right
                , text "left_len:" <+> int left_len
                , text "left_off:" <+> int left_off
                , text "n':      " <+> int n'
                , text ""]) lens'
-}

{-# INLINE_DIST chunk #-}
--  INLINE_DIST even though it should be inlined into splitSSegdOnElemsD anyway
--  because that function contains the only use.


-------------------------------------------------------------------------------
-- O(log n).
-- Binary search over a monotonically increasing vector of `Int`s.
-- Returns the index of the first element that is greater than or equal to
-- the given value, or the length of the vector if there is no such element.
-- 
-- eg  search 75  [0, 60, 70, 90, 130] = 3
--     search 43  [0, 60, 70, 90, 130] = 1
--     search 90  [0, 60, 70, 90, 130] = 3
--     search 131 [0, 60, 70, 90, 130] = 5
--
search :: Int -> Vector Int -> Int
search !x ys = loop 0 (Seq.length ys)
  where
    -- The candidate answers are the indices from lo up to and including hi.
    -- Every element before lo is smaller than x, and every element from
    -- hi onwards is not.
    loop lo hi
      | lo >= hi
      = lo

      | Seq.index (here "search") ys mid < x
      = loop (mid + 1) hi

      | otherwise
      = loop lo mid
      where
        mid = lo + ((hi - lo) `shiftR` 1)
{-# INLINE_DIST search #-}
--  INLINE_DIST because we want it inlined into both uses in 'chunk' above.