{-# LANGUAGE UndecidableInstances #-}

-- | A module for /virtual vectors/. Many of the functions defined here are
-- imitations of Haskell's list operations, and to a first approximation they
-- behave accordingly.
--
-- A virtual vector normally doesn't use any physical memory. Memory is only
-- introduced explicitly, either by the function 'force' or by converting the
-- vector to a core array with 'freezeVector'. The function 'vector', which
-- creates a vector from a list, also allocates memory.
--
-- Note also that most operations introduce only a small, constant overhead on
-- the vector. The exceptions are
--
--   * 'fold'
--
--   * 'fold1'
--
--   * Functions that introduce storage (see above)
--
--   * \"Folding\" functions: 'sum', 'maximum', etc.
--
-- These functions introduce overhead that is linear in the length of the
-- vector.
--
-- Finally, note that 'freezeVector' can be introduced implicitly by functions
-- overloaded by the 'Syntactic' class.

module Feldspar.Vector where



import qualified Prelude
import Control.Arrow ((&&&))
import Data.List (genericLength)
import qualified Data.TypeLevel as TL

import Feldspar.DSL.Network hiding (In,Out)
import Feldspar.Prelude
import Feldspar.Core.Representation
import Feldspar.Core



-- * Types

-- | Symbolic vector
--
-- A vector is a chain of segments terminated by 'Empty'. Each 'Indexed'
-- segment stores its length and an index function, and links to the rest of
-- the vector through 'continuation'. The elements of the whole vector are the
-- elements of each segment in order (see 'mergeSegments').
data Vector a
    = Empty
    | Indexed
        { segmentLength :: Data Length          -- ^ Number of elements in this segment
        , segmentIndex  :: Data Index -> a      -- ^ Element at a segment-local index
        , continuation  :: Vector a             -- ^ The remaining segments
        }

-- | Short-hand for a non-nested parallel vector, i.e. a vector whose elements
-- are plain 'Data' values.
type DVector a = Vector (Data a)



-- * Construction/conversion

-- | Builds a single-segment vector from a length and an index function.
indexed :: Data Length -> (Data Index -> a) -> Vector a
indexed len ixf = Indexed
    { segmentLength = len
    , segmentIndex  = ixf
    , continuation  = Empty
    }

-- | Splits a (possibly segmented) vector into the list of its segments, in
-- order. Each segment is returned as a single-segment vector; 'Empty' yields
-- the empty list.
segments :: Vector a -> [Vector a]
segments vec = case vec of
    Empty                -> []
    Indexed len ixf rest -> indexed len ixf : segments rest

-- | Total number of elements: the sum of the lengths of all segments
-- (@0@ for 'Empty').
length :: Vector a -> Data Length
length Empty                  = 0
length (Indexed len _ rest)   = len + length rest

-- | Converts a segmented vector to a vector with a single segment.
--
-- An 'Empty' vector (no segments at all) is returned unchanged. Previously it
-- was turned into a zero-length segment whose index function was undefined,
-- which crashed as soon as that function was applied (for example when the
-- result of 'zip' or 'permute' on an empty vector was frozen).
mergeSegments :: Syntactic a => Vector a -> Vector a
mergeSegments Empty = Empty
mergeSegments vec   = indexed (length vec) (ixFun (segments vec))
  where
    -- Index into the concatenation of the given segments: an index below the
    -- first segment's length selects from that segment, otherwise the index is
    -- shifted by that length and looked up in the remaining segments.
    ixFun (Indexed l ixf _ : vs) = case vs of
      [] -> ixf
      _  -> \i -> condition (i<l) (ixf i) (ixFun vs (i-l))
    -- Unreachable: 'segments' of a non-'Empty' vector is non-empty and only
    -- contains 'Indexed' values, and the recursion above stops at the last
    -- segment.
    ixFun _ = Prelude.error "Feldspar.Vector.mergeSegments: unreachable"

-- | Converts a non-nested vector to a core vector.
--
-- Each segment becomes one 'parallel''' node, chained onto the frozen rest of
-- the vector; 'Empty' becomes the literal empty array. The flag passed to
-- 'parallel''' is 'True' only for the outermost (first) segment.
-- NOTE(review): the meaning of that flag is defined by 'parallel''' —
-- presumably an optimization switch; confirm.
freezeVector :: Type a => Vector (Data a) -> Data [a]
freezeVector = go True
  where
    go _       Empty                  = value []
    go isFirst (Indexed len ixf rest) = parallel'' isFirst len ixf (go False rest)

-- | Converts a non-nested core vector to a single-segment parallel vector
-- whose length and index function read directly from the array.
unfreezeVector :: Type a => Data [a] -> Vector (Data a)
unfreezeVector arr = Indexed (getLength arr) (getIx arr) Empty

-- | Variant of `unfreezeVector` with additional static size information: the
-- array's length range is capped to exactly @len@, while the element size
-- information of the array is kept as is.
unfreezeVector' :: (Type a) => Length -> Data [a] -> Vector (Data a)
unfreezeVector' len arr = unfreezeVector (cap sizeInfo arr)
  where
    sizeInfo          = Range len len :> elemSize
    (_ :> elemSize)   = dataSize arr

-- | Deprecated synonym for 'force': optimizes vector lookup by computing all
-- elements and storing them in a core array.
memorize :: Syntactic (Vector a) => Vector a -> Vector a
memorize vec = force vec
{-# DEPRECATED memorize "Please use `force` instead." #-}

-- | Constructs a non-nested vector from a Haskell list. The elements are
-- stored in a core array, which is then viewed as a vector.
vector :: Type a => [a] -> Vector (Data a)
vector = unfreezeVector . value
  -- TODO Generalize to arbitrary dimensions.

-- | Edge information for a vector is the information of its core-array
-- representation: the vector is first converted with 'toEdge' (which, in the
-- 'MultiEdge' instance, freezes it with 'freezeVector') and the information of
-- that edge is returned.
instance
    ( Syntactic a
    , Role a ~ ()
    , Info a ~ EdgeSize () (Internal a)
    ) => EdgeInfo (Vector a)
  where
    type Info (Vector a) = EdgeSize () [Internal a]
    edgeInfo             = edgeInfo . toEdge

-- | A vector is represented in the network as a core array of the element's
-- internal type.
--
-- * 'toEdge' converts every element with 'edgeCast' and then freezes the
--   vector with 'freezeVector'.
--
-- * 'fromInEdge' / 'fromOutEdge' read the array, unfreeze it with
--   'unfreezeVector' (giving a single-segment vector) and convert every element
--   back with 'edgeCast'.
--
-- NOTE(review): 'edgeCast' is assumed to convert between @a@ and @Data
-- (Internal a)@ in both directions — confirm against its definition.
instance
    ( Syntactic a
    , Role a ~ ()
    , Info a ~ EdgeSize () (Internal a)
    ) =>
    MultiEdge (Vector a) Feldspar EdgeSize
  where
    type Role     (Vector a) = ()
    type Internal (Vector a) = [Internal a]

    toEdge           = toEdge . freezeVector . map edgeCast
    fromInEdge       = map edgeCast . unfreezeVector . fromInEdge
    fromOutEdge info = map edgeCast . unfreezeVector . fromOutEdge info

-- | Vectors are 'Syntactic' under the same constraints as the 'EdgeInfo' and
-- 'MultiEdge' instances; all methods come from the class defaults.
instance (Syntactic a, Role a ~ (), Info a ~ EdgeSize () (Internal a)) =>
    Syntactic (Vector a)



-- * Operations

-- | Element lookup. The segments are merged first, so looking up an element
-- of a segmented vector goes through one 'condition' per segment boundary.
instance Syntactic a => RandomAccess (Vector a)
  where
    type Element (Vector a) = a
    vec ! i = segmentIndex (mergeSegments vec) i

-- | Concatenation. No elements are copied: the segments of the second vector
-- are appended to the segment chain of the first.
(++) :: Vector a -> Vector a -> Vector a
front ++ back = case front of
    Empty                -> back
    Indexed len ixf rest -> case back of
        Empty -> front
        _     -> Indexed len ixf (rest ++ back)

infixr 5 ++

-- | Keeps the first @n@ elements. Each segment is shortened to at most the
-- number of elements still wanted, and the remainder is taken from the rest.
take :: Data Length -> Vector a -> Vector a
take _ Empty = Empty
take n (Indexed len ixf rest) =
    let headLen = n < len ? (n, len)
        restLen = n < len ? (0, n - len)
    in  indexed headLen ixf ++ take restLen rest

-- | Removes the first @n@ elements. Each segment's index function is shifted
-- by @n@; whatever part of @n@ exceeds the segment length is dropped from the
-- remaining segments.
drop :: Data Length -> Vector a -> Vector a
drop _ Empty = Empty
drop n (Indexed len ixf rest) =
    let headLen = n > len ? (0, len - n)
        restLen = len > n ? (0, n - len)
    in  indexed headLen (ixf . (+n)) ++ drop restLen rest

-- | Splits a vector at index @n@ into @('take' n vec, 'drop' n vec)@.
splitAt :: Data Index -> Vector a -> (Vector a, Vector a)
splitAt n = take n &&& drop n

-- | The element at index 0.
head :: Syntactic a => Vector a -> a
head vec = vec ! 0

-- | The element at index @'length' vec - 1@.
last :: Syntactic a => Vector a -> a
last vec = vec ! lastIx
  where
    lastIx = length vec - 1

-- | All elements except the first (implemented as @'drop' 1@).
tail :: Vector a -> Vector a
tail vec = drop 1 vec

-- | All elements except the last (the first @'length' vec - 1@ elements).
init :: Vector a -> Vector a
init vec = take keep vec
  where
    keep = length vec - 1

-- | All suffixes: element @n@ is @'drop' n vec@, for @n@ from 0 to
-- @'length' vec@ inclusive.
tails :: Vector a -> Vector (Vector a)
tails vec = indexed (length vec + 1) (`drop` vec)

-- | All prefixes: element @n@ is @'take' n vec@, for @n@ from 0 to
-- @'length' vec@ inclusive.
inits :: Vector a -> Vector (Vector a)
inits vec = indexed (length vec + 1) (`take` vec)

-- | All non-empty prefixes ('inits' without its first, empty, prefix).
inits1 :: Vector a -> Vector (Vector a)
inits1 vec = tail (inits vec)

-- | Permutes a single-segment vector.
--
-- The permutation function receives the segment length and an output index
-- and returns the input index to read from. Only 'Empty' and single-segment
-- vectors are accepted; use 'permute' for arbitrary vectors (it merges the
-- segments first). A multi-segment argument is a caller error and is reported
-- explicitly instead of through an anonymous pattern-match failure.
permute' :: (Data Length -> Data Index -> Data Index) -> (Vector a -> Vector a)
permute' _    Empty                 = Empty
permute' perm (Indexed l ixf Empty) = indexed l (ixf . perm l)
permute' _    _                     =
    Prelude.error "Feldspar.Vector.permute': expected a single-segment vector (use permute)"

-- | Permutes an arbitrary vector: its segments are merged into one and the
-- result is permuted with 'permute''.
permute :: Syntactic a =>
    (Data Length -> Data Index -> Data Index) -> (Vector a -> Vector a)
permute perm vec = permute' perm (mergeSegments vec)

-- | Reverses a vector: output index @i@ reads input index @len - 1 - i@.
reverse :: Syntactic a => Vector a -> Vector a
reverse = permute (\len i -> len - 1 - i)
  -- TODO Can be optimized (reversing each segment separately, and then
  --      reversing the segment order)

-- | Rotates left by @ix@: output index @i@ reads input index
-- @(i + ix) `rem` len@.
rotateVecL :: Syntactic a => Data Index -> Vector a -> Vector a
rotateVecL ix vec = permute shift vec
  where
    shift len i = (i + ix) `rem` len

-- | Rotates right by @ix@, expressed as a left rotation of the reversed
-- vector, reversed back.
rotateVecR :: Syntactic a => Data Index -> Vector a -> Vector a
rotateVecR ix vec = reverse (rotateVecL ix (reverse vec))

-- | A single-segment vector of length @n@ whose every element is @a@.
replicate :: Data Length -> a -> Vector a
replicate n a = indexed n (const a)

-- | The indices @m, m+1, ..., n@; the vector is empty when @n < m@.
enumFromTo :: Data Index -> Data Index -> Vector (Data Index)
enumFromTo m n = indexed len (\i -> i + m)
  where
    len = n < m ? (0, n - m + 1)
  -- TODO Type should be generalized.

-- | Operator synonym for 'enumFromTo'.
(...) :: Data Index -> Data Index -> Vector (Data Index)
m ... n = enumFromTo m n

-- | Applies a function to every element. The segment structure is kept; only
-- each segment's index function is composed with @f@.
map :: (a -> b) -> Vector a -> Vector b
map f vec = case vec of
    Empty                -> Empty
    Indexed len ixf rest -> Indexed len (f . ixf) (map f rest)

-- | Zips two single-segment vectors. The result's length is the minimum of
-- the two lengths; if either argument is 'Empty', the result is 'Empty'.
--
-- Multi-segment arguments are a caller error (use 'zip', which merges the
-- segments first) and are reported explicitly instead of through an anonymous
-- pattern-match failure.
zip' :: Vector a -> Vector b -> Vector (a,b)
zip' Empty _ = Empty
zip' _ Empty = Empty
zip' (Indexed l1 ixf1 Empty) (Indexed l2 ixf2 Empty) =
    indexed (min l1 l2) (ixf1 &&& ixf2)
zip' _ _ =
    Prelude.error "Feldspar.Vector.zip': expected single-segment vectors (use zip)"

-- | Zips two arbitrary vectors by merging the segments of each and then
-- zipping with 'zip''.
zip :: (Syntactic a, Syntactic b) => Vector a -> Vector b -> Vector (a,b)
zip xs ys = zip' merged1 merged2
  where
    merged1 = mergeSegments xs
    merged2 = mergeSegments ys

-- | Splits a vector of pairs into two vectors that share the original
-- segment structure.
unzip :: Vector (a,b) -> (Vector a, Vector b)
unzip vec = case vec of
    Empty -> (Empty, Empty)
    Indexed len ixf rest ->
        let (firsts, seconds) = unzip rest
        in  (Indexed len (fst . ixf) firsts, Indexed len (snd . ixf) seconds)

-- | Combines two vectors element-wise with @f@ (via 'zip' and 'map').
zipWith :: (Syntactic a, Syntactic b) =>
    (a -> b -> c) -> Vector a -> Vector b -> Vector c
zipWith f aVec bVec = map pairwise (zip aVec bVec)
  where
    pairwise = uncurry f

-- | Corresponds to the standard 'foldl'. Each segment is folded with a
-- 'forLoop' over its indices, and the result seeds the fold of the next
-- segment.
fold :: Syntactic a => (a -> b -> a) -> a -> Vector b -> a
fold step acc vec = case vec of
    Empty -> acc
    Indexed len ixf rest ->
        let acc' = forLoop len acc (\i s -> step s (ixf i))
        in  fold step acc' rest

-- | Corresponds to the standard 'foldl1': folds the 'tail' starting from the
-- 'head' element.
fold1 :: Type a => (Data a -> Data a -> Data a) -> Vector (Data a) -> Data a
fold1 f vec = fold f seed rest
  where
    seed = head vec
    rest = tail vec

-- | Sum of all elements (a left fold of @(+)@ starting at 0).
sum :: Numeric a => Vector (Data a) -> Data a
sum vec = fold (+) 0 vec

-- | Largest element, computed with 'fold1' over 'max'.
maximum :: Ord a => Vector (Data a) -> Data a
maximum vec = fold1 max vec

-- | Smallest element, computed with 'fold1' over 'min'.
minimum :: Ord a => Vector (Data a) -> Data a
minimum vec = fold1 min vec

-- | Scalar (dot) product of two vectors: the sum of element-wise products.
-- Elements beyond the shorter vector's length are ignored (see 'zip'').
scalarProd :: Numeric a => Vector (Data a) -> Vector (Data a) -> Data a
scalarProd xs ys = sum products
  where
    products = zipWith (*) xs ys

-- * Wrapping for vectors

-- | Wrapping a non-nested vector freezes it into a core array.
instance (Type a) => Wrap (Vector (Data a)) (Data [a]) where
    wrap = freezeVector

-- | Wrapping a function whose argument is a 'DVector': the wrapped function
-- takes a 'Data'' value whose type-level natural @s@ fixes the array length.
-- That length is converted to a term-level value and attached to the array
-- with 'unfreezeVector'' before the original function is applied; the result
-- is wrapped recursively.
--
-- NOTE(review): @undefined :: s@ refers to the instance head's @s@, which
-- relies on ScopedTypeVariables (presumably enabled project-wide; confirm).
instance (Wrap t u, Type a, TL.Nat s) => Wrap (DVector a -> t) (Data' s [a] -> u) where
    wrap f = \(Data' d) -> wrap $ f $ unfreezeVector' s' d where
        s' = fromInteger $ toInteger $ TL.toInt (undefined :: s)