{-# LANGUAGE CPP #-}
{-# OPTIONS -Wall -fno-warn-orphans #-}
#include "fusion-phases.h"

-- | Scattered Segment Descriptors.
--
--   See "Data.Array.Parallel.Unlifted" for how this works.
--
module Data.Array.Parallel.Unlifted.Sequential.USSegd 
        ( -- * Types
          USSegd(..)
        , valid

        -- * Constructors
        , mkUSSegd
        , empty
        , singleton
        , fromUSegd
        
        -- * Predicates
        , isContiguous
        
        -- * Projections
        , length
        , takeUSegd, takeLengths, takeIndices, takeElements
        , takeSources, takeStarts
        , getSeg
        
        -- * Operators
        , appendWith
        , cullOnVSegids)
where
import Data.Array.Parallel.Unlifted.Sequential.USegd            (USegd)
import Data.Array.Parallel.Unlifted.Sequential.Vector           (Vector)
import Data.Array.Parallel.Pretty                               hiding (empty)
import Prelude                                                  hiding (length)
import qualified Data.Array.Parallel.Unlifted.Sequential.USegd  as USegd
import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as U

-- | Prefix a function name with this module's fully-qualified name.
--   The result is used as the location tag passed to `U.index`, so that
--   out-of-bounds errors report where they originated.
here :: String -> String
here fn = concat ["Data.Array.Parallel.Unlifted.Sequential.USSegd.", fn]


-- USSegd ---------------------------------------------------------------------
-- | Scattered Segment Descriptor.
--
--   Describes a nested array whose segments may each live in a different
--   flat array (identified by a source id), starting at an arbitrary index
--   within that flat array.
data USSegd
        = USSegd
        { -- | Set when the starts coincide with the indices field of the
          --   usegd and every source id is 0.
          --
          --   All data elements then live in one contiguous flat array, so
          --   consumers may skip the real starts and sources fields.
          ussegd_contiguous     :: !Bool

          -- | Index at which each segment begins within its flat array.
          --
          --   IMPORTANT: deliberately lazy, so that it need not be built
          --              when the flat array is contiguous.
        , ussegd_starts         :: Vector Int

          -- | Id of the flat array each segment is taken from.
          --
          --   IMPORTANT: deliberately lazy, so that it need not be built
          --              when the flat array is contiguous.
        , ussegd_sources        :: Vector Int

          -- | Segment descriptor over a contiguous index space; this is
          --   what determines the length of every segment.
        , ussegd_usegd          :: !USegd
        }
        deriving Show


-- | Pretty print the physical representation of a `USSegd`.
--   The output is a header listing its starts and sources vectors, followed
--   by the physical representation of its underlying `USegd`.
--   The contiguity flag is not printed.
instance PprPhysical USSegd where
 pprp (USSegd _ segStarts segSources usegd)
  = vcat [ header, pprp usegd ]
  where header
         = text "USSegd"
           $$ nest 7 (vcat [ text "starts:  " <+> listed segStarts
                           , text "sources: " <+> listed segSources ])
        listed vec = text (show (U.toList vec))


-- Constructors ---------------------------------------------------------------
-- | O(1). Assemble a scattered segment descriptor from its parts.
--   The starts and sources vectors must each hold one entry per segment of
--   the given `USegd`. The result is never flagged as contiguous.
mkUSSegd
        :: Vector Int   -- ^ Starting index of each segment in its flat array.
        -> Vector Int   -- ^ Which array to take each segment from.
        -> USegd        -- ^ Contiguous segment descriptor.
        -> USSegd

mkUSSegd segStarts segSources usegd
        = USSegd
        { ussegd_contiguous     = False
        , ussegd_starts         = segStarts
        , ussegd_sources        = segSources
        , ussegd_usegd          = usegd }
{-# INLINE mkUSSegd #-}


-- | O(1). Check the internal consistency of a scattered segment descriptor.
--   Both the starts and the sources vectors must have exactly one entry per
--   segment of the underlying `USegd`. The contiguity flag is not checked.
valid :: USSegd -> Bool
valid (USSegd _ segStarts segSources usegd)
        = all (== nsegs) [U.length segStarts, U.length segSources]
        where nsegs = USegd.length usegd

{-# NOINLINE valid #-}
--  NOINLINE because it's only enabled during debugging anyway.


-- | O(1). The segment descriptor with no segments and no elements.
--   It is trivially contiguous.
empty :: USSegd
empty   = USSegd
        { ussegd_contiguous     = True
        , ussegd_starts         = U.empty
        , ussegd_sources        = U.empty
        , ussegd_usegd          = USegd.empty }
{-# INLINE_U empty #-}


-- | O(1). A descriptor with exactly one segment of the given length.
--   The segment starts at index 0 of the flat array with source id 0, so
--   the descriptor is flagged as contiguous.
singleton :: Int -> USSegd
singleton n
        = USSegd True zero zero (USegd.singleton n)
        where zero = U.singleton 0
{-# INLINE_U singleton #-}


-- | O(segs). Lift a plain `USegd` into a `USSegd`.
--   Every segment is taken from the flat array with source id 0, starting at
--   its index in the `USegd`, so the result is flagged as contiguous.
fromUSegd :: USegd -> USSegd
fromUSegd usegd
        = USSegd
        { ussegd_contiguous     = True
        , ussegd_starts         = USegd.takeIndices usegd
        , ussegd_sources        = U.replicate (USegd.length usegd) 0
        , ussegd_usegd          = usegd }
{-# INLINE_U fromUSegd #-}


-- Predicates -----------------------------------------------------------------
-- INLINE trivial projections as they'll expand to a single record selector.
-- | O(1). Whether the descriptor is flagged as contiguous.
--
--   The flag is set when the starts equal the usegd indices field and every
--   source id is 0. All data elements then live in one contiguous flat
--   array, and consumers can skip the real starts and sources fields.
--
isContiguous :: USSegd -> Bool
isContiguous (USSegd contiguous _ _ _) = contiguous
{-# INLINE isContiguous #-}


-- Projections ----------------------------------------------------------------
-- INLINE trivial projections as they'll expand to a single record selector.

-- | O(1). Number of segments described, as recorded by the underlying
--   `USegd`.
length :: USSegd -> Int
length ussegd   = USegd.length (ussegd_usegd ussegd)
{-# INLINE length #-}


-- | O(1). The contiguous `USegd` underlying a `USSegd`.
takeUSegd   :: USSegd -> USegd
takeUSegd (USSegd _ _ _ usegd) = usegd
{-# INLINE takeUSegd #-}


-- | O(1). Per-segment lengths, taken from the underlying `USegd`.
takeLengths :: USSegd -> Vector Int
takeLengths ussegd = USegd.takeLengths (ussegd_usegd ussegd)
{-# INLINE takeLengths #-}


-- | O(1). Per-segment indices, taken from the underlying `USegd`.
takeIndices :: USSegd -> Vector Int
takeIndices ussegd = USegd.takeIndices (ussegd_usegd ussegd)
{-# INLINE takeIndices #-}


-- | O(1). Total element count across all segments, taken from the
--   underlying `USegd`.
takeElements :: USSegd -> Int
takeElements (USSegd _ _ _ usegd) = USegd.takeElements usegd
{-# INLINE takeElements #-}


-- | O(1). Index at which each segment begins within its flat array.
takeStarts :: USSegd -> Vector Int
takeStarts (USSegd _ segStarts _ _) = segStarts
{-# INLINE takeStarts #-}


-- | O(1). Id of the flat array each segment is taken from.
takeSources :: USSegd -> Vector Int
takeSources (USSegd _ _ segSources _) = segSources
{-# INLINE takeSources #-}


-- | O(1). Look up one segment.
--   The result is the segment's length, its segment index (both from the
--   underlying `USegd`), its starting index in its flat array, and the
--   source id of that flat array.
getSeg :: USSegd -> Int -> (Int, Int, Int, Int)
getSeg (USSegd _ segStarts segSources usegd) ix
 = (segLen, segIndex, startAt ix, sourceOf ix)
 where  (segLen, segIndex) = USegd.getSeg usegd ix
        startAt  = U.index (here "getSeg") segStarts
        sourceOf = U.index (here "getSeg") segSources
{-# INLINE_U getSeg #-}


-- Operators ==================================================================

-- | O(n). Build the segment descriptor for the concatenation of two nested
--   arrays.
--
--   Segments of the second array keep their starting indices. Their source
--   ids are shifted up by the number of flat arrays of the first, so that
--   they refer past the first array's data. The result is never flagged as
--   contiguous. The flat-array count of the second array is not needed.
appendWith
        :: USSegd               -- ^ Segment descriptor of first nested array.
        -> Int                  -- ^ Number of flat data arrays used to represent first nested array.
        -> USSegd               -- ^ Segment descriptor of second nested array. 
        -> Int                  -- ^ Number of flat data arrays used to represent second nested array.
        -> USSegd
appendWith
        (USSegd _ startsA sourcesA usegdA) nDatasA
        (USSegd _ startsB sourcesB usegdB) _
        = USSegd False starts' sources' usegd'
        where   starts'  = startsA  U.++ startsB
                sources' = sourcesA U.++ U.map (nDatasA +) sourcesB
                usegd'   = USegd.append usegdA usegdB
{-# NOINLINE appendWith #-}
--  NOINLINE because we're worried about code explosion. Might be useful though.


-- | Cull the segments of a `USSegd` down to only those reachable from an array
--   of @vsegids@, and also update the @vsegids@ to point to the same segments
--   in the result.
--
--   Segments not referenced by any vsegid are dropped, and the survivors keep
--   their relative order. The result is always flagged as non-contiguous.
--
--   NOTE(review): every vsegid is assumed to be a valid segment index of the
--   given `USSegd`; confirm that callers guarantee this.
--
cullOnVSegids :: Vector Int -> USSegd -> (Vector Int, USSegd)
cullOnVSegids vsegids (USSegd _ starts sources usegd)
 = {-# SCC "cullOnVSegids" #-}
 let    -- Determine which of the psegs are still reachable from the vsegs.
        -- This produces an array of Bool flags (written as 1/0 below),
        --    with reachable   psegs corresponding to True
        --    and  unreachable psegs corresponding to False.
        -- 
        --  eg  vsegids:        [0 1 1 3 5 5 6 6]
        --   => psegids_used:   [1 1 0 1 0 1 1]
        --  
        --  Note that psegids '2' and '4' do not appear in vsegids.
        psegids_used
         = U.bpermuteDft (USegd.length usegd)
                         (const False)
                         (U.zip vsegids (U.replicate (U.length vsegids) True))

        -- Produce an array of used psegs, in ascending order.
        --  eg  psegids_used:   [1 1 0 1 0 1 1]
        --      psegids_packed: [0 1 3 5 6]
        --
        --  The candidate psegids are [0 .. n-1], exactly one per flag, so
        --  both arguments to pack have the same length.
        psegids_packed
         = U.pack (U.enumFromTo 0 (U.length psegids_used - 1)) psegids_used

        -- Produce an array that maps psegids in the source array onto
        -- psegids in the result array. If a particular pseg isn't present
        -- in the result this maps onto -1.

        --  Note that if psegids_used is False in some position, then
        --  psegids_map has -1 in the same position, corresponding to an
        --  unused pseg.
         
        --  eg  psegids_packed: [0 1 3 5 6]
        --                      [0 1 2 3 4]
        --      psegids_map:    [0 1 -1 2 -1 3 4]
        psegids_map
         = U.bpermuteDft (USegd.length usegd)
                         (const (-1))
                         (U.zip psegids_packed (U.enumFromTo 0 (U.length psegids_packed - 1)))

        -- Use the psegids_map to rewrite the vsegids to point to the 
        -- corresponding psegs in the result. Every vsegid refers to a used
        -- pseg, so none of them is rewritten to -1.
        -- 
        --  eg  vsegids:        [0 1 1 3 5 5 6 6]
        --      psegids_map:    [0 1 -1 2 -1 3 4]
        -- 
        --      vsegids':       [0 1 1 2 3 3 4 4]
        --
        vsegids'  = U.map (U.index (here "cullOnVSegids") psegids_map) vsegids

        -- Rebuild the segment descriptor from the surviving segments only.
        starts'   = U.pack starts  psegids_used
        sources'  = U.pack sources psegids_used

        lengths'  = U.pack (USegd.takeLengths usegd) psegids_used
        usegd'    = USegd.fromLengths lengths'
        
        -- Dropping segments generally breaks contiguity, so the result is
        -- conservatively flagged as non-contiguous.
        ussegd'   = USSegd False starts' sources' usegd'

     in  (vsegids', ussegd')

{-# NOINLINE cullOnVSegids #-}
--  NOINLINE because it's complicated and won't fuse with anything
--  This can also be expensive and we want to see the SCC in profiling builds.