--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

-- | A high-level interface to the operations in the core language
-- ("Feldspar.Core"). Many of the functions defined here are imitations of
-- Haskell's list operations, and to a first approximation they behave
-- accordingly.
--
-- A symbolic vector ('Vector') can be thought of as a representation of a
-- 'parallel' core array. This view is made precise by the function
-- 'freezeVector', which converts a symbolic vector to a core vector using
-- 'parallel'.
--
-- 'Vector' is instantiated under the 'Computable' class, which means that
-- symbolic vectors can be used quite seamlessly with the interface in
-- "Feldspar.Core".
--
-- Unlike core arrays, vectors don't use any physical memory. All
-- operations on vectors are \"fused\", which means that intermediate vectors
-- are removed. As an example, the following function uses only constant
-- space despite using two intermediate vectors of length @n@.
--
-- > sumSq n = sum (map (^2) (1...n))
--
-- Memory is only introduced when a vector is explicitly
-- written to memory using the function 'memorize' or converted to a core
-- array using 'freezeVector'. The function 'vector' for creating a
-- vector also allocates memory.
--
-- Note also that most operations only introduce a small constant overhead on
-- the vector. The exceptions are
--
--   * 'dropWhile'
--
--   * 'fold'
--
--   * 'fold1'
--
--   * Functions that introduce storage (see above)
--
--   * \"Folding\" functions: 'sum', 'maximum', etc.
--
-- These functions introduce overhead that is linear in the length of the
-- vector.
--
-- Finally, note that 'freezeVector' can be introduced implicitly by functions
-- overloaded by the 'Computable' class. This means that, for example,
-- @`printCore` f@, where @f :: Vector (Data Int) -> Vector (Data Int)@, will
-- introduce storage for the input and output of @f@.

module Feldspar.Vector where



import qualified Prelude
import Control.Arrow ((&&&))
import qualified Data.List  -- Only for documentation of 'unfold'

import Feldspar.Prelude
import Feldspar.Range
import Feldspar.Core.Expr
import Feldspar.Core



-- * Types

-- | Vector index. Indices are zero-based: 'head' reads index 0 and 'last'
-- reads index @length - 1@.
type Ix = Int

-- | Symbolic vector: a length paired with a function from index to element.
-- No elements are stored; a lookup simply applies 'index' to the position.
data Vector a = Indexed
  { length :: Data Length  -- ^ Number of elements
  , index  :: Data Ix -> a  -- ^ Element at a given (zero-based) index
  }

-- | Short-hand for non-nested parallel vector
type DVector a = Vector (Data a)



-- * Construction/conversion

-- | Turns a non-nested symbolic vector into a core array by computing every
-- element with 'parallel'. This is one of the places where storage is
-- introduced.
freezeVector :: Storable a => Vector (Data a) -> Data [a]
freezeVector (Indexed len elemAt) = parallel len elemAt

-- | Views a non-nested core array as a symbolic vector of the given length;
-- looking up position @i@ reads the array with 'getIx'.
unfreezeVector :: Storable a => Data Length -> Data [a] -> Vector (Data a)
unfreezeVector len = Indexed len . getIx

-- | Optimizes vector lookup by computing all elements once, storing them in a
-- core array, and reading that array back as a vector of the same length.
memorize :: Storable a => Vector (Data a) -> Vector (Data a)
memorize vec = unfreezeVector len stored
  where
    len    = length vec
    stored = freezeVector vec
  -- XXX Should be generalized to arbitrary dimensions.

-- | Builds a symbolic vector from a length and an index function.
indexed :: Data Length -> (Data Ix -> a) -> Vector a
indexed len ixf = Indexed len ixf

-- | Constructs a non-nested vector from a Haskell list. The list is embedded
-- with 'value' as a core array, and the vector's length is the list's length.
vector :: Storable a => [a] -> Vector (Data a)
vector xs = unfreezeVector (value (Prelude.length xs)) (value xs)
  -- XXX Should be generalized to arbitrary dimensions.

-- | Applies a function to the length of a vector; the index function is kept
-- as it is.
modifyLength :: (Data Length -> Data Length) -> Vector a -> Vector a
modifyLength f (Indexed len ixf) = Indexed (f len) ixf

-- | Replaces the length of a vector; the index function is kept as it is.
setLength :: Data Length -> Vector a -> Vector a
setLength len = modifyLength (const len)

-- | Caps the vector's length with 'cap', using a range built from @maxLen@.
-- NOTE(review): presumably this bounds the length to at most @maxLen@ for
-- range analysis; confirm against "Feldspar.Range".
boundVector :: Int -> Vector a -> Vector a
boundVector maxLen vec = modifyLength (cap lenRange) vec
  where
    lenRange = negativeRange + singletonRange (fromIntegral maxLen) + 1
      -- XXX fromIntegral might not be needed in future.



-- | A non-nested vector is represented internally as its length paired with
-- the core array obtained by freezing its (internalized) elements.
instance Storable a => Computable (Vector (Data a))
  where
    type Internal (Vector (Data a)) = (Length, [Internal (Data a)])

    internalize vec = internalize (len, frozen)
      where
        len    = length vec
        frozen = freezeVector (map internalize vec)

    -- Read the length and the array back out of the pair, view the array as
    -- a vector, and externalize each element on lookup.
    externalize pair =
      map externalize (unfreezeVector (externalize (get21 pair))
                                      (externalize (get22 pair)))

-- | A nested vector is represented internally as a triple: the outer length,
-- a core array holding the length of each inner vector, and a core array of
-- the frozen inner vectors.
instance Storable a => Computable (Vector (Vector (Data a)))
  where
    type Internal (Vector (Vector (Data a))) =
           (Length, [Length], [[Internal (Data a)]])

    -- Freeze the outer length, every inner length, and every inner vector
    -- (with its elements internalized).
    internalize vec = internalize
      ( length vec
      , freezeVector $ map length vec
      , freezeVector $ map (freezeVector . map internalize) vec
      )

    -- Rebuild an outer vector of length l1. Element i pairs the i-th inner
    -- length (from l2s) with the i-th row of a. That row is then viewed as a
    -- vector of that length, with externalized elements.
    externalize inp
        = map (map externalize . uncurry unfreezeVector)
        $ zip l2sV (unfreezeVector l1 a)
      where
        l1   = externalize $ get31 inp  -- outer length
        l2s  = externalize $ get32 inp  -- core array of inner lengths
        a    = externalize $ get33 inp  -- core array of inner arrays
        l2sV = unfreezeVector l1 l2s    -- inner lengths viewed as a vector



-- * Operations

-- | Indexing a symbolic vector applies its index function to the position.
instance RandomAccess (Vector a)
  where
    type Element (Vector a) = a
    (!) (Indexed _ ixf) i = ixf i



-- | Appends two vectors. Every element lookup tests whether the index falls
-- within the first vector, so this introduces an 'ifThenElse' for each
-- element; use with care!
(++) :: Computable a => Vector a -> Vector a -> Vector a
Indexed len1 front ++ Indexed len2 back = Indexed (len1 + len2) pick
  where
    pick i = ifThenElse (i < len1) front (\j -> back (j - len1)) i

infixr 5 ++

-- | The first @n@ elements of a vector. The new length is the smaller of @n@
-- and the original length.
take :: Data Int -> Vector a -> Vector a
take n (Indexed len ixf) = Indexed shortened ixf
  where
    shortened = min n len

-- | Removes the first @n@ elements. The remaining length is clamped at 0, and
-- lookups are shifted forward by @n@.
drop :: Data Int -> Vector a -> Vector a
drop n (Indexed len ixf) = Indexed remaining shifted
  where
    remaining = max 0 (len - n)
    shifted i = ixf (i + n)

-- | Drops the longest prefix whose elements satisfy @cont@. The prefix length
-- is computed by a 'while' loop that counts up from 0 while the index is in
-- bounds and the element satisfies @cont@.
dropWhile :: (a -> Data Bool) -> Vector a -> Vector a
dropWhile cont vec = drop prefixLen vec
  where
    inBounds  = (< length vec)
    satisfies = cont . (vec !)
    prefixLen = while (inBounds &&* satisfies) (+1) 0

-- | Splits a vector at position @n@ into @(take n vec, drop n vec)@.
splitAt :: Data Int -> Vector a -> (Vector a, Vector a)
splitAt n = take n &&& drop n

-- | The element at index 0. The vector's length is not checked.
head :: Vector a -> a
head vec = vec ! 0

-- | The element at index @length - 1@. The vector's length is not checked.
last :: Vector a -> a
last vec = index vec lastIx
  where
    lastIx = length vec - 1

-- | All elements except the first. 'drop' clamps the length at 0, so the
-- tail of an empty vector is empty.
tail :: Vector a -> Vector a
tail vec = drop 1 vec

-- | All elements except the last. Assumes a non-empty vector: the new length
-- @length - 1@ is not clamped here.
init :: Vector a -> Vector a
init vec = take shorter vec
  where
    shorter = length vec - 1

-- | All suffixes of a vector, longest first. Element @n@ is @drop n vec@, and
-- there are @length vec + 1@ of them.
tails :: Vector a -> Vector (Vector a)
tails vec = Indexed (length vec + 1) (`drop` vec)

-- | All prefixes of a vector, shortest first. Element @n@ is @take n vec@, and
-- there are @length vec + 1@ of them.
inits :: Vector a -> Vector (Vector a)
inits vec = Indexed (length vec + 1) (`take` vec)

-- | All non-empty prefixes of a vector: 'inits' without its first (empty)
-- prefix.
inits1 :: Vector a -> Vector (Vector a)
inits1 vec = tail (inits vec)

-- | Reorders a vector by rewriting indices. Element @i@ of the result is
-- element @perm l i@ of the input, where @l@ is the length; the length itself
-- is unchanged.
permute :: (Data Length -> Data Ix -> Data Ix) -> (Vector a -> Vector a)
permute perm (Indexed len ixf) = Indexed len reindexed
  where
    reindexed i = ixf (perm len i)

-- | Reverses a vector by mapping index @i@ to @l - 1 - i@.
reverse :: Vector a -> Vector a
reverse = permute mirror
  where
    mirror len i = len - 1 - i

-- | A vector of length @n@ in which every element is @a@.
replicate :: Data Int -> a -> Vector a
replicate n a = Indexed n (\_ -> a)

-- | The vector of consecutive integers from @m@ to @n@ inclusive, mirroring
-- Haskell's @[m .. n]@. Element @i@ is @m + i@. When @n < m@ the result is
-- empty: the length is clamped at 0 (as in 'drop') instead of becoming
-- negative.
enumFromTo :: Data Int -> Data Int -> Vector (Data Int)
enumFromTo m n = Indexed (max 0 (n-m+1)) (+m)
  -- XXX Type should be generalized.

-- | Infix synonym for 'enumFromTo': @m ... n@ is the vector from @m@ to @n@.
(...) :: Data Int -> Data Int -> Vector (Data Int)
m ... n = enumFromTo m n

-- | Pairs up corresponding elements. The result is as long as the shorter
-- input.
zip :: Vector a -> Vector b -> Vector (a,b)
zip (Indexed len1 ixf1) (Indexed len2 ixf2) = Indexed (min len1 len2) both
  where
    both i = (ixf1 i, ixf2 i)

-- | Splits a vector of pairs into two vectors that share the original length.
unzip :: Vector (a,b) -> (Vector a, Vector b)
unzip (Indexed len ixf) = (Indexed len firsts, Indexed len seconds)
  where
    firsts  = fst . ixf
    seconds = snd . ixf

-- | Applies a function to every element. The length is unchanged and no
-- storage is introduced.
map :: (a -> b) -> Vector a -> Vector b
map f (Indexed len ixf) = Indexed len (\i -> f (ixf i))

-- | Combines corresponding elements with @f@. The result is as long as the
-- shorter input.
zipWith :: (a -> b -> c) -> Vector a -> Vector b -> Vector c
zipWith f aVec bVec = map combine (zip aVec bVec)
  where
    combine pair = f (fst pair) (snd pair)

-- | Corresponds to 'foldl'. The accumulator starts at @x@, and a 'for' loop
-- over indices @0 .. l-1@ combines it with each element in turn.
fold :: Computable a => (a -> b -> a) -> a -> Vector b -> a
fold f x (Indexed l ixf) = for 0 (l - 1) x step
  where
    step i acc = f acc (ixf i)

-- | Corresponds to 'foldl1'. The first element seeds the accumulator, and the
-- remaining elements (the 'tail') are folded in from left to right.
--
-- The fold runs over the tail rather than the whole vector. Otherwise the
-- first element would be combined with itself, which gives wrong results for
-- non-idempotent operators such as @(+)@.
--
-- Assumes a non-empty vector: the length is not checked before 'head'.
fold1 :: Computable a => (a -> a -> a) -> Vector a -> a
fold1 f vec = fold f (head vec) (tail vec)

-- | Sum of the elements: a left 'fold' of '+' starting from 0.
sum :: Numeric a => Vector (Data a) -> Data a
sum vec = fold (+) 0 vec

-- | Largest element, computed with 'fold1' and 'max'. Assumes a non-empty
-- vector.
maximum :: Ord a => Vector (Data a) -> Data a
maximum vec = fold1 max vec

-- | Smallest element, computed with 'fold1' and 'min'. Assumes a non-empty
-- vector.
minimum :: Ord a => Vector (Data a) -> Data a
minimum vec = fold1 min vec

-- | Scalar (dot) product of two vectors: the sum of the element-wise
-- products. Extra elements of the longer vector are ignored, because
-- 'zipWith' stops at the shorter length.
scalarProd :: Numeric a => Vector (Data a) -> Vector (Data a) -> Data a
scalarProd a b = sum products
  where
    products = zipWith (*) a b