{-# language FlexibleContexts, TypeFamilies #-}
module Data.Sparse.Internal.SVector.Mutable where

import Data.Foldable
import qualified Data.Vector as V 
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Generic as VG

import Control.Arrow ((&&&))
import Control.Monad.Primitive

-- | Sparse mutable vector: an 'Int' size tag together with two boxed mutable
-- arrays, one of indices and one of values, both tied to the state token of
-- the monad @m@.
--
-- As populated by 'fromListOverwrite', slot @k@ of 'smvIx' holds the index
-- that belongs to the value stored in slot @k@ of 'smvVal'.
data SMVector m a = SMV
  { smvDim :: {-# UNPACK #-} !Int
    -- ^ Size recorded at construction time ('fromList' stores its first
    -- argument here). Presumably the nominal dimension of the vector, as
    -- opposed to the number of stored entries -- TODO confirm.
  , smvIx :: VM.MVector (PrimState m) Int
    -- ^ Indices of the stored entries.
  , smvVal :: VM.MVector (PrimState m) a
    -- ^ Values of the stored entries, slot-aligned with 'smvIx'.
  }


-- | Allocate a fresh sparse mutable vector and fill it from a list of
-- @(index, value)@ pairs.
--
-- The first argument is stored as 'smvDim'. Both backing arrays are
-- allocated with as many slots as the list has elements, and the actual
-- filling is delegated to 'fromListOverwrite'.
fromList :: (PrimMonad m) => Int -> [(Int, a)] -> m (SMVector m a)
fromList n ixv = do
  vals <- VM.new nnz
  ixs <- VM.new nnz
  fromListOverwrite ixs vals n ixv
  where
    -- number of (index, value) pairs to store
    nnz = length ixv
  

-- | Fill preallocated index and value arrays from a list of
-- @(index, value)@ pairs and wrap them in an 'SMVector'.
--
-- The @k@-th pair of the list (counting from 0) is written to slot @k@ of
-- both arrays, overwriting whatever was stored there; slots past the end of
-- the list are not touched. The 'Int' argument is stored as 'smvDim'. The
-- arrays are not copied: the result refers to the very arrays passed in.
--
-- No size check is done here, so the arrays are expected to have at least
-- as many slots as the list has pairs.
fromListOverwrite
  :: PrimMonad m
  => VM.MVector (PrimState m) Int
  -> VM.MVector (PrimState m) a
  -> Int
  -> [(Int, a)]
  -> m (SMVector m a)
fromListOverwrite ixm vm n ixv = do
  -- pair every list element with its destination slot
  for_ (zip [0 ..] ixv) $ \(slot, (i, x)) -> do
    VM.write vm slot x
    VM.write ixm slot i
  pure (SMV n ixm vm)
      


-- instance Foldable (SMVector m) where
--   foldr f z (SMV _ _ v) = foldr f z v

-- instance Traversable (SMVector m) where
--   traverse f (SMV n ix v) = SMV n ix <$> traverse f v



-- | Traverse the sparse mutable vector @vm@ and operate on the union of the
-- index set of the immutable vector @v@ and that of @vm@, /overwriting/ the
-- values in @vm@.
--
-- Invariant: the index set of @v@ is a strict subset of that of @vm@ (i.e.
-- we assume that @vm@ is preallocated properly, or, equivalently, that no
-- out-of-bound writes will be attempted).
--
-- NOTE: this function is not implemented yet; evaluating a call to it fails
-- with an error that names the function.
--
-- NOTE(review): @g@ and @z@ are presumably the binary combining function
-- and its neutral element, as in the commented-out @unionWithSMV@ further
-- down in this module -- confirm when implementing. It is also worth
-- checking whether the invariant really needs a /strict/ subset.
unionWithM_prealloc g z v vm@(SMV nvm vmix vmv) =
  error "unionWithM_prealloc: not implemented"








-- -- | unionWithSMV takes the union of two sparse mutable vectors given a binary function and a neutral element.

-- -- unionWithSMV :: PrimMonad m =>
-- --      (a -> a -> b)
-- --      -> a
-- --      -> SMVector m a
-- --      -> SMVector m a
-- --      -> m (VM.MVector (PrimState m) (Int, b))

-- unionWithSMV g z (SMV n1 ixu uu) (SMV n2 ixv vv) = do
--   vm0 <- VM.new n
--   (vm, nfin) <- go ixu uu ixv vv 0 vm0  -- populate
--   -- let (ixm, vm') = VG.unzip (VM.take nfin vm')
--   let vm_trim = VM.take nfin vm
--       (ixm, vm') = VG.unzip $ VG.convert vm_trim
--   vOut <- V.freeze ixm
--   return (vOut, vm')
--   where
--     headTail = V.head &&& V.tail
--     n = min n1 n2
--     go iu u_ iv v_ i vm
--           | (VM.null u_ && VM.null v_) || i == n = return (vm , i)
--           | VM.null u_ = do
--               v0 <- VM.read v_ 0
--               VM.write vm i (V.head iv, g z v0)
--               go iu u_ (V.tail iv) (VM.tail v_) (i + 1) vm
--           | VM.null v_ = do
--               u0 <- VM.read u_ 0
--               VM.write vm i (V.head iu, g u0 z)
--               go (V.tail iu) (VM.tail u_) iv v_ (i + 1) vm
--           | otherwise =  do
--              u <- VM.read u_ 0
--              v <- VM.read v_ 0
--              let us = VM.tail u_
--                  vs = VM.tail v_
--              let (iu1, ius) = headTail iu
--                  (iv1, ivs) = headTail iv
--              if iu1 == iv1 then do VM.write vm i (iu1, g u v)
--                                    go ius us ivs vs (i + 1) vm
--                            else if iu1 < iv1 then do VM.write vm i (iu1, g u z)
--                                                      go ius us iv v_ (i + 1) vm
--                                              else do VM.write vm i (iv1, g z v)
--                                                      go iu u_ ivs vs (i + 1) vm










-- -- * helpers

-- -- both :: Arrow arr => arr b c -> arr (b, b) (c, c)
-- -- both f = f *** f