{-# language TypeFamilies, FlexibleInstances, MultiParamTypeClasses, CPP #-}
module Data.Sparse.Internal.SVector where

import Control.Arrow
import Control.Monad (unless)
import Control.Monad.IO.Class
import qualified Data.Foldable as F -- (foldl')
-- import Data.List (group, groupBy)

import qualified Data.Vector as V 
-- import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Mutable as VM
-- import qualified Data.Vector.Algorithms.Merge as VA (sortBy)
-- import qualified Data.Vector.Generic as VG (convert)

import Data.Complex
import Foreign.C.Types (CSChar, CInt, CShort, CLong, CLLong, CIntMax, CFloat, CDouble)
import Control.Monad.Primitive

-- import Data.Sparse.Utils
-- import Data.Sparse.Types
import Data.Sparse.Internal.SVector.Mutable hiding (fromList)
import qualified Data.Sparse.Internal.SVector.Mutable as SMV (fromList)


import Numeric.LinearAlgebra.Class





-- | Sparse vector stored as two parallel boxed vectors.
--
-- The constructors in this module ('fromDense', 'fromVector') build 'svIx' and
-- 'svVal' with the same length, so entry @k@ of 'svVal' is the value stored at
-- index @svIx ! k@.
data SVector a = SV
  { svDim :: {-# UNPACK #-} !Int  -- ^ nominal dimension of the vector
  , svIx  :: V.Vector Int         -- ^ indices of the stored entries
  , svVal :: V.Vector a           -- ^ values of the stored entries
  } deriving Eq

-- | Renders the dimension, the number of stored entries and the
-- (index, value) pairs.
instance Show a => Show (SVector a) where
  show (SV dim ixs vals) = unwords pieces
    where
      pieces = ["SV (", show dim, "),", show stored, "NZ:", show entries]
      stored = V.length ixs
      entries = V.zip ixs vals

-- | Maps over the stored values only; dimension and indices are kept as they are.
instance Functor SVector where
  fmap f sv = sv { svVal = V.map f (svVal sv) }

-- | Folds over the stored values only (indices and dimension are ignored).
instance Foldable SVector where
  foldr step acc sv = V.foldr step acc (svVal sv)

-- | Traverses the stored values, rebuilding the vector with the same
-- dimension and indices.
instance Traversable SVector where
  traverse f (SV dim ixs vals) = fmap (SV dim ixs) (traverse f vals)

-- | The number of stored entries is the length of the index vector.
instance HasData (SVector a) where
  nnz sv = V.length (svIx sv)
  -- dat (SV _ _ x) = x



-- ** Construction 

-- | Wrap a dense vector: every position @0 .. length - 1@ becomes a stored
-- entry, and the dimension is the length of the input.
fromDense :: V.Vector a -> SVector a
fromDense xs = SV len positions xs
  where
    len = V.length xs
    positions = V.enumFromTo 0 (len - 1)

-- | Build a sparse vector of dimension @n@ from a vector of (index, value)
-- pairs. The pairs are split as given; no sorting or validation is performed.
fromVector :: Int -> V.Vector (Int, a) -> SVector a
fromVector n ixs = uncurry (SV n) (V.unzip ixs)

-- | Build a sparse vector of dimension @n@ from a list of (index, value)
-- pairs, via 'fromVector'.
fromList :: Int -> [(Int, a)] -> SVector a
fromList n entries = fromVector n (V.fromList entries)

-- * Query

-- | O(N) : Look up the value stored at index @i@ of a sparse vector.
--
-- The index vector is searched linearly for @i@; the /position/ of the first
-- match (not the index value itself) is then used to read the value vector.
-- Returns 'Nothing' when @i@ is not among the stored indices, or when the
-- value vector is too short to hold that position.
index :: SVector a -> Int -> Maybe a
index cv i = V.elemIndex i (svIx cv) >>= (svVal cv V.!?)
      


-- * Lifted binary functions

-- | O(N) : Applies a function to the index _intersection_ of two 'SVector's.
-- Useful e.g. to compute the inner product of two sparse vectors.
--
-- The result carries the smaller of the two operand dimensions (not the
-- number of stored entries), so it remains a vector of the same space when
-- the operands agree in dimension.
intersectWith :: (a -> b -> c) -> SVector a -> SVector b -> SVector c
intersectWith g v1 v2 = SV dim ixf vf where
   dim = min (svDim v1) (svDim v2)
   (ixf, vf) = V.unzip $ V.create (intersectWithM g v1 v2)

-- | Monadic worker for the index intersection of two sparse vectors.
--
-- Walks both operands with one cursor each and writes @(ix, g x y)@ into a
-- fresh mutable vector whenever the two current indices are equal; on a
-- mismatch only the cursor holding the smaller index advances. The returned
-- mutable vector is trimmed to the number of matches.
--
-- The merge is only meaningful when both index vectors are sorted in
-- ascending order.
intersectWithM :: PrimMonad m =>
     (a -> b -> c)
     -> SVector a
     -> SVector b
     -> m (VM.MVector (PrimState m) (Int, c))
intersectWithM g (SV _ ixu u) (SV _ ixv v) = do
  vm <- VM.new cap
  nfin <- go 0 0 0 vm
  return $ VM.take nfin vm
  where
    -- number of usable entries of each operand
    lu = min (V.length ixu) (V.length u)
    lv = min (V.length ixv) (V.length v)
    -- every match consumes one entry from each operand, so the number of
    -- matches can never exceed the shorter operand
    cap = min lu lv
    -- i, j : read cursors into the two operands; k : write cursor
    go i j k vm
      | i >= lu || j >= lv = return k
      | otherwise =
          case compare iu iv of
            EQ -> do VM.write vm k (iu, g (u V.! i) (v V.! j))
                     go (i + 1) (j + 1) (k + 1) vm
            LT -> go (i + 1) j k vm
            GT -> go i (j + 1) k vm
      where
        iu = ixu V.! i
        iv = ixv V.! j

      
-- | O(N) : Applies a function to the index _union_ of two 'SVector's, using
-- @z@ in place of a missing operand entry. Useful e.g. to compute the vector
-- sum of two sparse vectors. The result has the larger of the two operand
-- dimensions.
unionWith :: (t -> t -> a) -> t -> SVector t -> SVector t -> SVector a
unionWith g z v1 v2 = uncurry (SV dim) merged
  where
    dim = max (svDim v1) (svDim v2)
    merged = V.unzip (V.create (unionWithM g z v1 v2))

-- | Monadic worker for the index union of two sparse vectors.
--
-- Merges the two operands with one cursor each. For an index present in both
-- operands the entry @(ix, g x y)@ is written; for an index present in only
-- one of them the missing side is replaced by @z@ (@g x z@ or @g z y@). The
-- returned mutable vector is trimmed to the number of entries written.
--
-- The output buffer is sized by the number of stored entries of the operands
-- (not by their dimensions), so the union with an empty, dimension-0 vector
-- returns the other operand's entries untouched by truncation.
--
-- The merge is only meaningful when both index vectors are sorted in
-- ascending order.
unionWithM :: PrimMonad m =>
     (a -> a -> b)
     -> a
     -> SVector a
     -> SVector a
     -> m (VM.MVector (PrimState m) (Int, b))
unionWithM g z (SV _ ixu u) (SV _ ixv v) = do
  vm <- VM.new cap
  nfin <- go 0 0 0 vm
  return $ VM.take nfin vm
  where
    -- number of usable entries of each operand
    lu = min (V.length ixu) (V.length u)
    lv = min (V.length ixv) (V.length v)
    -- worst case: the two index sets are disjoint
    cap = lu + lv
    -- i, j : read cursors into the two operands; k : write cursor
    go i j k vm
      | i >= lu && j >= lv = return k
      | i >= lu = emitRight
      | j >= lv = emitLeft
      | otherwise =
          case compare iu iv of
            EQ -> do VM.write vm k (iu, g x y)
                     go (i + 1) (j + 1) (k + 1) vm
            LT -> emitLeft
            GT -> emitRight
      where
        -- these are only forced in branches where the cursor is in range
        iu = ixu V.! i
        iv = ixv V.! j
        x = u V.! i
        y = v V.! j
        -- entry present only in the first operand
        emitLeft = do
          VM.write vm k (iu, g x z)
          go (i + 1) j (k + 1) vm
        -- entry present only in the second operand
        emitRight = do
          VM.write vm k (iv, g z y)
          go i (j + 1) (k + 1) vm
  


-- CPP macro: for a scalar type t, generates AdditiveGroup, VectorSpace and
-- InnerSpace instances for (SVector t).
--   * addition is unionWith (+) with 0 as the filler for missing entries
--   * the zero vector is the empty vector of dimension 0
--   * scaling and negation map over the stored values
--   * the inner product sums the products over the index intersection
-- Every use of this macro visible in this part of the file is commented out.
#define SVType(t) \
  instance AdditiveGroup (SVector t) where {zeroV = SV 0 V.empty V.empty; (^+^) = unionWith (+) 0 ; negateV = fmap negate};\
  instance VectorSpace (SVector t) where {type Scalar (SVector t) = (t); n .* x = fmap (* n) x };\
  instance InnerSpace (SVector t) where {a <.> b = sum $ intersectWith (*) a b}
  -- instance Normed (CsrVector t) where {norm p v = norm' p v}
  -- instance Hilbert (CsrVector t) where {x `dot` y = sum $ intersectWithCV (*) x y };\
  

-- CPP macro: complex-valued counterpart of the real-valued instance macro.
-- For a component type t it generates AdditiveGroup, VectorSpace and
-- InnerSpace instances for (SVector (Complex t)).
--   * addition uses (0 :+ 0) as the filler for missing entries
--   * the inner product conjugates the first operand before multiplying
-- NOTE: the last macro line ends with a continuation backslash, so the line
-- that follows the macro is swallowed into its body and must stay blank.
-- Every use of this macro visible in this part of the file is commented out.
#define SVTypeC(t) \
  instance AdditiveGroup (SVector (Complex t)) where {zeroV = SV 0 V.empty V.empty; (^+^) = unionWith (+) (0 :+ 0) ; negateV = fmap negate};\
  instance VectorSpace (SVector (Complex t)) where {type Scalar (SVector (Complex t)) = (Complex t); n .* x = fmap (* n) x };\
  instance InnerSpace (SVector (Complex t)) where {x <.> y = sum $ intersectWith (*) (conjugate <$> x) y};\

-- #define NormedType(t) \
--   instance Normed (CsrVector t) where { norm p v | p==1 = norm1 v | otherwise = norm2 v ; normalize p v = v ./ norm p v};\
--   instance Normed (CsrVector (Complex t)) where { norm p v | p==1 = norm1 v | otherwise = norm2 v ; normalize p v = v ./ norm p v}


-- -- CVType(Int)
-- -- CVType(Integer)
-- SVType(Float)
-- SVType(Double)
-- -- CVType(CSChar)
-- -- CVType(CInt)
-- -- CVType(CShort)
-- -- CVType(CLong)
-- -- CVType(CLLong)
-- -- CVType(CIntMax)
-- SVType(CFloat)
-- SVType(CDouble)

-- SVTypeC(Float)
-- SVTypeC(Double)  
-- SVTypeC(CFloat)
-- SVTypeC(CDouble)

-- -- NormedType(Float)
-- -- NormedType(Double)
-- -- NormedType(CFloat)
-- -- NormedType(CDouble)    








{-| Apply a binary function over the index union of an immutable and a mutable
sparse vector, returning the result as a freshly allocated mutable sparse
vector. Neither operand is modified.

For an index present in both operands the result holds @g u v@; for an index
present in only one operand the missing side is replaced by @z@
(@g u z@ or @g z v@).

e.g.

g = (+)
z = 0
u = [(1, a), (2, b)]
v = [(0, d), (2, e)]

unionWithSMV g z u v = [(0, d), (1, a), (2, b + e)]

invariants :

* uu, vv nonzero values
* ixu, ixv nonzero indices
* ixu, ixv sorted in ascending order
* n1 == n2
* length ixu == length uu
* length ixv == length vv

Calls 'error' when the two operands do not have the same dimension.

-}
unionWithSMV :: PrimMonad m =>
     (a -> a -> a)
     -> a -> SVector a -> SMVector m a -> m (SMVector m a)
unionWithSMV g z (SV n ixu uu) (SMV n2 ixm_ vm_) = do
  unless (n == n2) (error "unionWithSMV : operand vectors must have the same length")
  ixmnew <- VM.new cap  -- output index and value buffers
  vmnew <- VM.new cap
  nfin <- go 0 0 0 ixmnew vmnew
  return $ SMV n (VM.take nfin ixmnew) (VM.take nfin vmnew)
  where
    lu = V.length ixu
    lv = VM.length ixm_
    -- worst case: the two index sets are disjoint
    cap = lu + lv
    -- write one (index, value) entry at output position k
    emit ixm vm k ix x = do
      VM.write ixm k ix
      VM.write vm k x
    -- i : cursor into the immutable operand
    -- j : cursor into the mutable operand
    -- k : write cursor into the output buffers
    go i j k ixm vm
      | i >= lu && j >= lv = return k
      | i >= lu = do          -- only the mutable operand has entries left
          iv <- VM.read ixm_ j
          v <- VM.read vm_ j
          emit ixm vm k iv (g z v)
          go i (j + 1) (k + 1) ixm vm
      | j >= lv = do          -- only the immutable operand has entries left
          emit ixm vm k (ixu V.! i) (g (uu V.! i) z)
          go (i + 1) j (k + 1) ixm vm
      | otherwise = do
          let iu = ixu V.! i
              u = uu V.! i
          iv <- VM.read ixm_ j
          v <- VM.read vm_ j
          case compare iu iv of
            EQ -> do
              emit ixm vm k iu (g u v)
              go (i + 1) (j + 1) (k + 1) ixm vm
            LT -> do
              emit ixm vm k iu (g u z)
              go (i + 1) j (k + 1) ixm vm
            GT -> do
              emit ixm vm k iv (g z v)
              go i (j + 1) (k + 1) ixm vm
                               
                             

-- -- test data
-- testUnionWithSMV :: IO (SVector Double)
-- testUnionWithSMV = do 
--   let v = fromList 4 [(1, 1), (2, 1)]
--   vm <- SMV.fromList 4 [(0, pi), (2, pi)]
--   (vmres, nfin) <- unionWithSMV (+) 0 v vm
--   liftIO $ print nfin
--   freeze vmres




            
          

-- * To/from SMVector

-- | Copy an immutable sparse vector into a mutable one with the same
-- dimension. The value vector is thawed first, then the index vector.
thaw :: PrimMonad m => SVector a -> m (SMVector m a)
thaw (SV dim ixs vals) =
  (\mvals mixs -> SMV dim mixs mvals) <$> V.thaw vals <*> V.thaw ixs

-- | Copy a mutable sparse vector into an immutable one with the same
-- dimension. The value vector is frozen first, then the index vector.
freeze :: PrimMonad m => SMVector m a -> m (SVector a)
freeze (SMV dim mixs mvals) =
  (\vals ixs -> SV dim ixs vals) <$> V.freeze mvals <*> V.freeze mixs





-- * helpers

-- | Apply the same arrow to both components of a pair.
both :: Arrow arr => arr b c -> arr (b, b) (c, c)
both arrow = (***) arrow arrow

-- | Split a vector into its first element and the remainder.
-- Partial: both components fail on an empty vector.
headTail :: V.Vector a -> (a, V.Vector a)
headTail xs = (V.head xs, V.tail xs)