{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
module Data.Repa.Scalar.Product
        ( -- * Product type
          (:*:)         (..)
        , IsProdList    (..)
        , IsKeyValues   (..)

          -- * Selecting
        , Select        (..)

          -- * Discarding
        , Discard       (..)

          -- * Masking
        , Mask          (..)
        , Keep          (..)
        , Drop          (..))
where
import Data.Repa.Scalar.Singleton.Nat
import qualified Data.Vector.Unboxed            as U
import qualified Data.Vector.Generic            as G
import qualified Data.Vector.Generic.Mutable    as M


---------------------------------------------------------------------------------------------------
-- | Strict binary product whose constructor is written infix.
--
--   Both fields carry strictness annotations, so each component is forced
--   when the constructor is applied. The operator associates to the right:
--   @a :*: b :*: c@ parses as @a :*: (b :*: c)@.
data (:*:) a b
        = !a :*: !b
        deriving (Eq, Show)

-- Precedence 9 is the level a fixity declaration takes when no digit is
-- given; it is spelled out here so the binding strength is visible.
infixr 9 :*:

-- | Maps over the second component only; the first component is carried
--   through unchanged.
instance Functor ((:*:) a) where
 fmap g (l :*: r) = l :*: g r


-- | Sequences of products that form a valid list,
--   using () for the nil value.
--
--   The instances in this module cover @()@ and @f :*: fs@ where @fs@ is
--   itself a product list, so a chain that does not end in @()@ has no
--   instance.
class IsProdList p where
 -- | Check that the argument is a well-formed product list.
 --
 --   With the instances in this module the result is always 'True';
 --   an ill-formed chain is rejected by the type checker instead.
 --
 -- @
 -- isProdList (1 :*: 4 :*: 5)  ... no instance
 --
 -- isProdList (1 :*: 4 :*: ()) = True
 -- @
 --
 isProdList :: p -> Bool

-- | The unit value is the empty product list.
instance IsProdList () where
 -- The argument is never inspected.
 isProdList = const True
 {-# INLINE isProdList #-}

-- | A product is a list when its right-hand component is one; the head
--   element places no constraint.
instance IsProdList fs => IsProdList (f :*: fs) where
 isProdList (_ :*: rest) = isProdList rest
 {-# INLINE isProdList #-}


-- Key-value pairs---------------------------------------------------------------------------------
-- | Sequences of products and tuples that form heterogeneous key-value pairs.
--
--   In this module a single tuple @(k, v)@ is one pair, and pairs are
--   chained together with @:*:@.
class IsKeyValues p where
 -- | Type of the keys. The product instance in this module requires every
 --   pair in a chain to share the same key type.
 type Keys   p
 -- | Type of the collected values.
 type Values p

 -- | Get a cons-list of all the keys.
 keys   :: p -> [Keys p]

 -- | Get a heterogeneous product-list of all the values.
 values :: p -> Values p


-- | A single tuple is one key-value pair: its first component is the key
--   and its second component is the value.
instance IsKeyValues (k, v) where
 type Keys   (k, v) = k
 type Values (k, v) = v

 -- Singleton list holding the key.
 keys (key, _) = [key]
 {-# INLINE keys #-}

 -- The value is the second component of the tuple.
 values = snd
 {-# INLINE values #-}


-- | Chains of key-value pairs. The key types of the head and of the rest
--   must agree so that all keys fit in one list; the value types may differ
--   and are collected into a product.
instance (IsKeyValues p, IsKeyValues ps, Keys p ~ Keys ps)
       => IsKeyValues (p :*: ps) where

 type Keys   (p :*: ps) = Keys p
 type Values (p :*: ps) = Values p :*: Values ps

 -- Keys of the head come first, followed by the keys of the rest.
 keys (hd :*: rest) = keys hd ++ keys rest
 {-# INLINE keys #-}

 -- Values keep the same head/rest shape as the input.
 values (hd :*: rest) = values hd :*: values rest
 {-# INLINE values #-}


---------------------------------------------------------------------------------------------------
-- | Product lists from which the element at type-level index @n@
--   can be selected.
class    IsProdList t
      => Select  (n :: N) t where
 -- | Type of the element found at index @n@ of @t@.
 type    Select'    n        t
 -- | Return the element with this index from a product list.
 --
 --   With the instances in this module index 'Z' yields the head, and each
 --   'S' steps past one element.
 select ::          Nat n -> t -> Select' n t


-- | Index zero selects the head of the product list.
instance IsProdList ts => Select Z (t1 :*: ts) where
 type Select' Z (t1 :*: ts) = t1

 select Zero (hd :*: _) = hd
 {-# INLINE select #-}


-- | A successor index skips the head and selects from the rest of the list
--   with the predecessor index.
instance Select n ts => Select (S n) (t1 :*: ts) where
 type Select' (S n) (t1 :*: ts) = Select' n ts

 select (Succ ix) (_ :*: rest) = select ix rest
 {-# INLINE select #-}


---------------------------------------------------------------------------------------------------
-- | Product lists from which the element at type-level index @n@
--   can be removed.
class    IsProdList t
      => Discard (n :: N) t where
 -- | Type of the product list that remains once element @n@ of @t@
 --   has been removed.
 type    Discard'   n        t
 -- | Discard the element with this index from a product list.
 --
 --   With the instances in this module index 'Z' removes the head, and each
 --   'S' keeps one leading element before continuing into the rest.
 discard ::         Nat n -> t -> Discard' n t


-- | Discarding index zero removes the head and returns the rest unchanged.
instance IsProdList ts => Discard Z (t1 :*: ts) where
 type Discard' Z (t1 :*: ts) = ts

 discard Zero (_ :*: rest) = rest
 {-# INLINE discard #-}


-- | A successor index keeps the head and discards from the rest of the list
--   with the predecessor index.
instance Discard n ts => Discard (S n) (t1 :*: ts) where
 type Discard' (S n) (t1 :*: ts) = t1 :*: Discard' n ts

 discard (Succ ix) (hd :*: rest) = hd :*: discard ix rest
 {-# INLINE discard #-}


---------------------------------------------------------------------------------------------------
-- | Singleton to indicate a field that should be dropped.
--
--   Used as an element of the mask list given to 'mask': the 'Mask' instance
--   for a mask headed by 'Drop' omits the corresponding field from the result.
data Drop = Drop

-- | Singleton to indicate a field that should be kept.
--
--   Used as an element of the mask list given to 'mask': the 'Mask' instance
--   for a mask headed by 'Keep' retains the corresponding field in the result.
data Keep = Keep


-- | Class of data types that can have parts masked out.
--
--   The mask @m@ and the value @t@ are both product lists. The instances in
--   this module walk them in step, so they only resolve when the mask and
--   the value have the same number of elements.
class (IsProdList m, IsProdList t) => Mask  m t where
 -- | Type that remains after applying mask @m@ to @t@.
 type Mask' m t
 -- | Mask out some component of a type.
 --
 -- @
 -- mask (Keep :*: Drop  :*: Keep :*: ())
 --      (1    :*: \"foo\" :*: \'a\'  :*: ())   =   (1 :*: \'a\' :*: ())
 --
 -- mask (Drop :*: Drop  :*: Drop :*: ())
 --      (1    :*: \"foo\" :*: \'a\'  :*: ())   =   ()
 -- @
 --
 mask :: m -> t -> Mask' m t


-- | Base case: the empty mask applied to the empty product list gives the
--   empty product list.
instance Mask () () where
 type Mask' ()   () = ()
 -- Both arguments are matched against the unit constructor.
 mask       ()   () = ()
 {-# INLINE mask #-}


-- | A mask headed by 'Keep' retains the head field and masks the rest of
--   the list with the rest of the mask.
instance Mask ms ts => Mask (Keep :*: ms) (t1 :*: ts) where
 type Mask' (Keep :*: ms) (t1 :*: ts) = t1 :*: Mask' ms ts

 mask (_ :*: flags) (kept :*: rest) = kept :*: mask flags rest
 {-# INLINE mask #-}


-- | A mask headed by 'Drop' omits the head field and masks the rest of
--   the list with the rest of the mask.
instance Mask ms ts => Mask (Drop :*: ms) (t1 :*: ts) where
 type Mask' (Drop :*: ms) (t1 :*: ts) = Mask' ms ts

 mask (_ :*: flags) (_ :*: rest) = mask flags rest
 {-# INLINE mask #-}


-- Unboxed ----------------------------------------------------------------------------------------
-- Unboxed instance adapted from:
-- http://code.haskell.org/vector/internal/unbox-tuple-instances
--
-- | Immutable unboxed vector of products.
--
--   Holds an element count together with one unboxed vector per component;
--   element @i@ is rebuilt from index @i@ of each component vector.
data instance U.Vector (a :*: b)
  = V_Prod
        -- Number of elements.
        {-# UNPACK #-} !Int
        -- First components of all elements.
        !(U.Vector a)
        -- Second components of all elements.
        !(U.Vector b)


-- | A product can be stored unboxed when both of its components can.
--   The instance body is empty; no methods are defined here.
instance (U.Unbox a, U.Unbox b)
       => U.Unbox (a :*: b)


-- | Mutable unboxed vector of products.
--
--   Holds an element count together with one mutable unboxed vector per
--   component, both in the same state thread @s@.
data instance U.MVector s (a :*: b)
  -- Fields: element count, first components, second components.
  = MV_Prod {-# UNPACK #-} !Int
        !(U.MVector s a)
        !(U.MVector s b)


-- | Mutable-vector operations for products.
--
--   Each operation is forwarded to the two component vectors, first
--   component first, and the stored element count is maintained alongside.
--
--   NOTE(review): this instance defines no @basicInitialize@ method.
--   Releases of @vector@ whose @MVector@ class includes that method would
--   need it added here (presumably by initialising both components) --
--   confirm against the @vector@ version this package is built with.
instance (U.Unbox a, U.Unbox b)
      => M.MVector U.MVector (a :*: b) where

  -- The stored count is the length; the components are not consulted.
  basicLength (MV_Prod len _ _) = len
  {-# INLINE basicLength #-}

  -- Slice both components over the same window and record the requested
  -- length as the new count.
  basicUnsafeSlice start len (MV_Prod _ xs ys)
    = MV_Prod len
        (M.basicUnsafeSlice start len xs)
        (M.basicUnsafeSlice start len ys)
  {-# INLINE basicUnsafeSlice #-}

  -- Two product vectors overlap when either pair of components does.
  basicOverlaps (MV_Prod _ xs1 ys1) (MV_Prod _ xs2 ys2)
    = M.basicOverlaps xs1 xs2 || M.basicOverlaps ys1 ys2
  {-# INLINE basicOverlaps #-}

  -- Allocate one component vector of the requested length for each side.
  basicUnsafeNew len = do
    xs <- M.basicUnsafeNew len
    ys <- M.basicUnsafeNew len
    return (MV_Prod len xs ys)
  {-# INLINE basicUnsafeNew #-}

  -- Replicate each component of the element into its own vector.
  basicUnsafeReplicate len (x :*: y) = do
    xs <- M.basicUnsafeReplicate len x
    ys <- M.basicUnsafeReplicate len y
    return (MV_Prod len xs ys)
  {-# INLINE basicUnsafeReplicate #-}

  -- Read the same index from both components and pair the results.
  basicUnsafeRead (MV_Prod _ xs ys) ix = do
    x <- M.basicUnsafeRead xs ix
    y <- M.basicUnsafeRead ys ix
    return (x :*: y)
  {-# INLINE basicUnsafeRead #-}

  -- Split the element and write each half at the same index.
  basicUnsafeWrite (MV_Prod _ xs ys) ix (x :*: y)
    = M.basicUnsafeWrite xs ix x >> M.basicUnsafeWrite ys ix y
  {-# INLINE basicUnsafeWrite #-}

  basicClear (MV_Prod _ xs ys)
    = M.basicClear xs >> M.basicClear ys
  {-# INLINE basicClear #-}

  basicSet (MV_Prod _ xs ys) (x :*: y)
    = M.basicSet xs x >> M.basicSet ys y
  {-# INLINE basicSet #-}

  -- Copy component-wise, keeping the argument order of the outer call.
  basicUnsafeCopy (MV_Prod _ xs1 ys1) (MV_Prod _ xs2 ys2)
    = M.basicUnsafeCopy xs1 xs2 >> M.basicUnsafeCopy ys1 ys2
  {-# INLINE basicUnsafeCopy #-}

  -- Move component-wise, keeping the argument order of the outer call.
  basicUnsafeMove (MV_Prod _ xs1 ys1) (MV_Prod _ xs2 ys2)
    = M.basicUnsafeMove xs1 xs2 >> M.basicUnsafeMove ys1 ys2
  {-# INLINE basicUnsafeMove #-}

  -- Grow both components by the same amount; the new count is the old
  -- count plus that amount.
  basicUnsafeGrow (MV_Prod len xs ys) extra = do
    xs' <- M.basicUnsafeGrow xs extra
    ys' <- M.basicUnsafeGrow ys extra
    return (MV_Prod (extra + len) xs' ys')
  {-# INLINE basicUnsafeGrow #-}


-- | Immutable-vector operations for products.
--
--   Each operation is forwarded to the two component vectors, first
--   component first, and the stored element count is carried across.
instance  (U.Unbox a, U.Unbox b)
        => G.Vector U.Vector (a :*: b) where

  -- Freeze both components and keep the element count.
  basicUnsafeFreeze (MV_Prod len xs ys) = do
    xs' <- G.basicUnsafeFreeze xs
    ys' <- G.basicUnsafeFreeze ys
    return (V_Prod len xs' ys')
  {-# INLINE basicUnsafeFreeze #-}

  -- Thaw both components and keep the element count.
  basicUnsafeThaw (V_Prod len xs ys) = do
    xs' <- G.basicUnsafeThaw xs
    ys' <- G.basicUnsafeThaw ys
    return (MV_Prod len xs' ys')
  {-# INLINE basicUnsafeThaw #-}

  -- The stored count is the length; the components are not consulted.
  basicLength (V_Prod len _ _) = len
  {-# INLINE basicLength #-}

  -- Slice both components over the same window and record the requested
  -- length as the new count.
  basicUnsafeSlice start len (V_Prod _ xs ys)
    = V_Prod len
        (G.basicUnsafeSlice start len xs)
        (G.basicUnsafeSlice start len ys)
  {-# INLINE basicUnsafeSlice #-}

  -- Index both components at the same position and pair the results.
  basicUnsafeIndexM (V_Prod _ xs ys) ix = do
    x <- G.basicUnsafeIndexM xs ix
    y <- G.basicUnsafeIndexM ys ix
    return (x :*: y)
  {-# INLINE basicUnsafeIndexM #-}

  -- Copy component-wise from the immutable vector into the mutable one.
  basicUnsafeCopy (MV_Prod _ xs1 ys1) (V_Prod _ xs2 ys2)
    = G.basicUnsafeCopy xs1 xs2 >> G.basicUnsafeCopy ys1 ys2
  {-# INLINE basicUnsafeCopy #-}

  -- Apply the component instances' elemseq to each half in turn. The
  -- 'undefined' arguments are there only to fix which instance is used
  -- (presumably never evaluated -- confirm against the component
  -- instances); the type annotations refer to the @a@ and @b@ of the
  -- instance head.
  elemseq _ (x :*: y)
    = G.elemseq (undefined :: U.Vector a) x
    . G.elemseq (undefined :: U.Vector b) y
  {-# INLINE elemseq #-}