-- Copyright (c) 2009-2010, ERICSSON AB
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-- | A high-level interface to the operations in the core language
-- ("Feldspar.Core"). Many of the functions defined here are imitations of
-- Haskell's list operations, and to a first approximation they behave
-- accordingly.
--
-- A symbolic vector ('Vector') can be thought of as a representation of a
-- 'parallel' core array. This view is made precise by the function
-- 'freezeVector', which converts a symbolic vector to a core vector using
-- 'parallel'.
--
-- 'Vector' is instantiated under the 'Computable' class, which means that
-- symbolic vectors can be used quite seamlessly with the interface in
-- "Feldspar.Core".
--
-- Note that the only operations in this module that introduce storage (through
-- core arrays) are
--
--   * 'freezeVector'
--
--   * 'memorize'
--
--   * 'vector'
--
--   * 'unfoldVec'
--
--   * 'unfold'
--
--   * 'scan'
--
--   * 'mapAccum'
--
-- This means that vector operations not involving these operations will be
-- completely \"fused\" without using any intermediate storage.
--
-- Note also that most operations only introduce a small constant overhead on
-- the vector. The exceptions are
--
--   * 'dropWhile'
--
--   * 'fold'
--
--   * 'fold1'
--
--   * Functions that introduce storage (see above)
--
--   * \"Folding\" functions: 'sum', 'maximum', etc.
--
-- These functions introduce overhead that is linear in the length of the
-- vector.
--
-- Finally, note that 'freezeVector' can be introduced implicitly by functions
-- overloaded by the 'Computable' class. This means that, for example,
-- @`printCore` f@, where @f :: Vector (Data Int) -> Vector (Data Int)@, will
-- introduce storage for the input and output of @f@.

module Feldspar.Vector where



import qualified Prelude
import Control.Arrow ((&&&))
import qualified Data.List  -- Only for documentation of 'unfold'

import Feldspar.Prelude
import Feldspar.Range
import Feldspar.Core.Expr
import Feldspar.Core



-- * Types

-- | Vector index
type Ix = Int

-- | Symbolic vector: a length paired with an indexing function. Elements are
-- not stored in the vector itself; element @i@ is obtained by applying
-- 'index' to @i@. Valid indices run from @0@ to @length - 1@.
data Vector a = Indexed
  { length :: Data Length  -- ^ Number of elements
  , index  :: Data Ix -> a  -- ^ Computes the element at a given index
  }

-- | Short-hand for non-nested parallel vector
type DVector a = Vector (Data a)



-- * Construction/conversion

-- | Converts a non-nested symbolic vector to a core array by handing its
-- length and indexing function to 'parallel'. This introduces storage.
freezeVector :: Storable a => Vector (Data a) -> Data [a]
freezeVector vec = case vec of
  Indexed len elemAt -> parallel len elemAt

-- | Views a non-nested core array as a symbolic vector of the given length.
-- Each element access becomes a 'getIx' on the array.
unfreezeVector :: Storable a => Data Length -> Data [a] -> Vector (Data a)
unfreezeVector len arr = Indexed {length = len, index = getIx arr}

-- | Optimizes vector lookup by computing every element once, storing the
-- results in a core array ('freezeVector'), and then reading elements back
-- from that array ('unfreezeVector'). The length is kept unchanged.
memorize :: Storable a => Vector (Data a) -> Vector (Data a)
memorize vec = unfreezeVector len stored
  where
    len    = length vec
    stored = freezeVector vec
  -- XXX Should be generalized to arbitrary dimensions.

-- | Builds a symbolic vector from a length and an indexing function. Nothing
-- is stored; element @i@ is computed by applying the function to @i@.
indexed :: Data Length -> (Data Ix -> a) -> Vector a
indexed len elemAt = Indexed len elemAt

-- | Constructs a non-nested vector from a Haskell list. The list becomes a
-- core array constant (via 'value'), and the vector's length is the length
-- of the list.
vector :: Storable a => [a] -> Vector (Data a)
vector xs = unfreezeVector len arr
  where
    arr = value xs
    len = value (Prelude.length xs)
  -- XXX Should be generalized to arbitrary dimensions.

-- | Applies a function to the vector's length, leaving the indexing function
-- unchanged.
modifyLength :: (Data Length -> Data Length) -> Vector a -> Vector a
modifyLength f (Indexed len elemAt) = Indexed (f len) elemAt

-- | Replaces the vector's length, leaving the indexing function unchanged.
setLength :: Data Length -> Vector a -> Vector a
setLength len vec = modifyLength (const len) vec

-- | Attaches an upper bound to the vector's length by applying 'cap' to it.
-- The elements are left untouched.
--
-- The range passed to 'cap' is @negativeRange + singletonRange maxLen + 1@.
-- NOTE(review): presumably this denotes all values up to and including
-- @maxLen@; confirm against "Feldspar.Range".
boundVector :: Int -> Vector a -> Vector a
boundVector maxLen vec = modifyLength (cap lenRange) vec
  where
    lenRange = negativeRange + singletonRange (fromIntegral maxLen) + 1
      -- XXX fromIntegral might not be needed in future.



-- A one-dimensional symbolic vector is represented internally as a pair of
-- its length and a core array holding its elements. Internalizing a vector
-- therefore introduces storage via 'freezeVector'.
instance Storable a => Computable (Vector (Data a))
  where
    type Internal (Vector (Data a)) = (Length, [Internal (Data a)])

    -- Pair the length with all elements (each internalized) stored in a core
    -- array.
    internalize vec =
      internalize (length vec, freezeVector $ map internalize vec)

    -- Rebuild a vector backed by the stored array. 'Get21' / 'Get22'
    -- presumably project the first / second component of the internal pair
    -- (TODO confirm in "Feldspar.Core.Expr").
    externalize l_a = map externalize $ unfreezeVector l a
      where
        l = externalize $ exprToData $ Get21 l_a
        a = externalize $ exprToData $ Get22 l_a

-- A nested (two-level) symbolic vector is represented internally as a
-- triple: the outer length, a core array with the length of each inner
-- vector, and a core array of core arrays holding the elements.
instance Storable a => Computable (Vector (Vector (Data a)))
  where
    type Internal (Vector (Vector (Data a))) =
           (Length, [Length], [[Internal (Data a)]])

    internalize vec = internalize
      ( length vec                                               -- outer length
      , freezeVector $ map length vec                            -- inner lengths
      , freezeVector $ map (freezeVector . map internalize) vec  -- elements
      )

    -- Pair each stored inner length with its stored inner array, and rebuild
    -- every inner vector with 'unfreezeVector'. Both outer vectors have
    -- length @l1@. 'Get31' / 'Get32' / 'Get33' presumably project the
    -- components of the internal triple (TODO confirm).
    externalize inp
        = map (map externalize . uncurry unfreezeVector)
        $ zip l2sV (unfreezeVector l1 a)
      where
        l1   = externalize $ exprToData $ Get31 inp
        l2s  = externalize $ exprToData $ Get32 inp
        a    = externalize $ exprToData $ Get33 inp
        l2sV = unfreezeVector l1 l2s



-- * Operations

-- Indexing a symbolic vector applies its indexing function directly; no
-- bounds check is performed here.
instance RandomAccess (Vector a)
  where
    type Element (Vector a) = a
    (!) vec i = index vec i



-- | Concatenates two vectors. Indices below the first vector's length read
-- from the first vector; the rest read from the second vector, shifted by
-- that length. The choice is made with 'ifThenElse' on every element access,
-- so this introduces an 'ifThenElse' for each element; use with care!
(++) :: Computable a => Vector a -> Vector a -> Vector a
Indexed len1 at1 ++ Indexed len2 at2 = Indexed (len1 + len2) pick
  where
    fromSecond j = at2 (j - len1)
    pick i       = ifThenElse (i < len1) at1 fromSecond i

infixr 5 ++

-- | Corresponds to Haskell's 'take': keeps at most the first @n@ elements.
--
-- The resulting length is clamped to be non-negative, in the same way 'drop'
-- clamps its length. A negative @n@ therefore yields an empty vector instead
-- of a vector of negative length.
take :: Data Int -> Vector a -> Vector a
take n (Indexed l ixf) = Indexed (maxX 0 (minX n l)) ixf

-- | Corresponds to Haskell's 'drop': removes the first @n@ elements. The
-- resulting length is clamped at zero, and the remaining elements are read
-- from the original vector at an offset of @n@.
--
-- NOTE(review): a negative @n@ is not clamped, so the result would be longer
-- than the input and would read at negative indices. Confirm that callers
-- never pass one.
drop :: Data Int -> Vector a -> Vector a
drop n (Indexed l ixf) = Indexed remaining (ixf . (+ n))
  where
    remaining = maxX 0 (l - n)

-- | Drops the longest prefix whose elements satisfy the predicate.
--
-- The prefix length is found with a 'while' loop that starts at index 0 and
-- advances while the index is in bounds and the element satisfies @cont@.
-- The cost is therefore linear in the length of the dropped prefix.
-- NOTE(review): relies on '&&*' checking the bound before reading the
-- element; confirm it short-circuits.
dropWhile :: (a -> Data Bool) -> Vector a -> Vector a
dropWhile cont vec = drop prefixLen vec
  where
    inBounds  = (< length vec)
    holds     = cont . (vec !)
    prefixLen = while (inBounds &&* holds) (+1) 0

-- | Corresponds to Haskell's 'splitAt': the pair of @take n@ and @drop n@ of
-- the same vector.
splitAt :: Data Int -> Vector a -> (Vector a, Vector a)
splitAt n = take n &&& drop n

-- | The element at index 0. There is no check that the vector is non-empty.
head :: Vector a -> a
head vec = vec ! 0

-- | The element at index @length - 1@. There is no check that the vector is
-- non-empty.
last :: Vector a -> a
last (Indexed len elemAt) = elemAt (len - 1)

-- | All elements except the first, i.e. @drop 1@.
tail :: Vector a -> Vector a
tail vec = drop 1 vec

-- | All elements except the last: the vector truncated to @length - 1@
-- elements with 'take'.
init :: Vector a -> Vector a
init vec = take shorter vec
  where
    shorter = length vec - 1

-- | All suffixes of the vector, @length + 1@ of them; suffix @n@ is
-- @drop n vec@, so the first is the whole vector.
tails :: Vector a -> Vector (Vector a)
tails vec = Indexed (length vec + 1) (`drop` vec)

-- | All prefixes of the vector, @length + 1@ of them; prefix @n@ is
-- @take n vec@, so the first is empty.
inits :: Vector a -> Vector (Vector a)
inits vec = Indexed (length vec + 1) (`take` vec)

-- | The non-empty prefixes: 'inits' without its first (empty) element.
inits1 :: Vector a -> Vector (Vector a)
inits1 vec = tail (inits vec)

-- | Reorders elements. The permutation receives the vector length and an
-- output index, and returns the source index to read from. The length is
-- unchanged.
permute :: (Data Length -> Data Ix -> Data Ix) -> (Vector a -> Vector a)
permute perm (Indexed len elemAt) = Indexed len (\i -> elemAt (perm len i))

-- | Reverses the vector: output index @i@ reads source index @l-1-i@.
reverse :: Vector a -> Vector a
reverse = permute mirror
  where
    mirror l i = l-1-i

-- | A vector of length @n@ whose every element is @a@. Note that @n@ is used
-- as the length directly, without clamping.
replicate :: Data Int -> a -> Vector a
replicate n = Indexed n . const

-- | Corresponds to Haskell's @[m .. n]@: the integers from @m@ to @n@,
-- inclusive; element @i@ is @i + m@.
--
-- The length is clamped at zero, in the same way 'drop' clamps its length,
-- so @n < m@ gives an empty vector rather than one of negative length.
enumFromTo :: Data Int -> Data Int -> Vector (Data Int)
enumFromTo m n = Indexed (maxX 0 (n-m+1)) (+m)
  -- XXX Type should be generalized.

-- | Infix synonym for 'enumFromTo'.
(...) :: Data Int -> Data Int -> Vector (Data Int)
m ... n = enumFromTo m n

-- | Pairs up elements at equal indices. The result's length is the minimum
-- of the two lengths.
zip :: Vector a -> Vector b -> Vector (a,b)
zip (Indexed len1 at1) (Indexed len2 at2) = Indexed (min len1 len2) both
  where
    both i = (at1 i, at2 i)

-- | Splits a vector of pairs into two vectors of the same length.
unzip :: Vector (a,b) -> (Vector a, Vector b)
unzip (Indexed len elemAt) =
  ( Indexed len (\i -> fst (elemAt i))
  , Indexed len (\i -> snd (elemAt i))
  )

-- | Applies a function to every element. The function is composed into the
-- indexing function, so no storage is introduced.
map :: (a -> b) -> Vector a -> Vector b
map f (Indexed len elemAt) = Indexed len (\i -> f (elemAt i))

-- | Combines elements at equal indices with a binary function. The length is
-- that of 'zip'.
zipWith :: (a -> b -> c) -> Vector a -> Vector b -> Vector c
zipWith f xs ys = map (\ ~(x, y) -> f x y) (zip xs ys)

-- | Corresponds to 'foldl'. Combines the elements from left to right into an
-- accumulator, starting from the given initial value. It uses a core 'for'
-- loop from index 0 to @length - 1@.
fold :: Computable a => (a -> b -> a) -> a -> Vector b -> a
fold step seed (Indexed len elemAt) = for 0 (len - 1) seed body
  where
    body i acc = step acc (elemAt i)

-- | Corresponds to 'foldl1': the first element is the initial accumulator,
-- and the remaining elements are combined into it from left to right.
--
-- Only the tail is folded over. Folding the whole vector would combine the
-- first element twice, which is wrong for non-idempotent functions such as
-- @(+)@.
--
-- The vector must be non-empty, because its first element is read
-- unconditionally.
fold1 :: Computable a => (a -> a -> a) -> Vector a -> a
fold1 f vec = fold f (head vec) (tail vec)



-- | Like 'unfoldCore', but for symbolic vectors. The step function receives
-- the element index and the current state, and returns the element and the
-- next state. The result pairs a vector of length @l@ with the final state.
-- The output elements are stored in a core vector.
unfoldVec
    :: (Computable state, Storable a)
    => Data Length
    -> state
    -> (Data Int -> state -> (Data a, state))
    -> (Vector (Data a), state)

unfoldVec l seed step =
  let (elems, final) = unfoldCore l seed step
  in  (unfreezeVector l elems, final)



-- | Somewhat similar to Haskell's 'Data.List.unfoldr'. The output elements are
-- stored in a core vector.
--
-- @`unfold` l seed step@:
--
--   * @l@ is the length of the resulting vector.
--
--   * @seed@ is the initial state.
--
--   * @step@ is a function computing a new element and the next state from the
--     current state. Unlike with 'unfoldVec', it does not receive the index.
--
-- The final state is discarded.
unfold :: (Computable state, Storable a) =>
    Data Length -> state -> (state -> (Data a, state)) -> Vector (Data a)

unfold l seed step = fst (unfoldVec l seed (const step))



-- | Corresponds to 'scanl'. The result has @length vec + 1@ elements. The
-- first element is the initial value, and each later element is @f@ applied
-- to the previous one and the next input element. The output elements are
-- stored in a core vector.
--
-- NOTE(review): the state computed by the last step reads @vec ! length vec@,
-- one past the end. That state is discarded by 'fst'; confirm that the
-- backend never evaluates the out-of-range read.
scan :: (Storable a, Computable b) =>
    (Data a -> b -> Data a) -> Data a -> Vector b -> Vector (Data a)

scan f seed vec = fst (unfoldVec (length vec + 1) seed emit)
  where
    emit i acc = (acc, f acc (vec ! i))



-- | Corresponds to 'Data.List.mapAccumL'. Threads an accumulator through the
-- vector from left to right and produces one output element per input
-- element. Returns the final accumulator and the outputs, which are stored in
-- a core vector.
mapAccum :: (Storable a, Computable acc, Storable b)
    => (acc -> Data a -> (acc, Data b))
    -> acc -> Vector (Data a) -> (acc, Vector (Data b))

mapAccum f acc0 vecA = (final, vecB)
  where
    (vecB, final) = unfoldVec (length vecA) acc0 step
    -- unfoldVec expects (element, state); f yields (state, element).
    step i acc = (b, acc')
      where
        (acc', b) = f acc (vecA ! i)



-- | Sum of all elements: a left 'fold' with @(+)@ starting from 0.
sum :: (Num a, Computable a) => Vector a -> a
sum vec = fold (+) 0 vec

-- | Largest element, computed with 'fold1' and 'max'. The vector must be
-- non-empty.
maximum :: Storable a => Vector (Data a) -> Data a
maximum vec = fold1 (\x y -> max x y) vec

-- | Smallest element, computed with 'fold1' and 'min'. The vector must be
-- non-empty.
minimum :: Storable a => Vector (Data a) -> Data a
minimum vec = fold1 (\x y -> min x y) vec

-- | Scalar (dot) product of two vectors: the 'sum' of element-wise products.
-- If the lengths differ, the shorter length is used (as in 'zip').
scalarProd :: Numeric a => Vector (Data a) -> Vector (Data a) -> Data a
scalarProd a b = sum products
  where
    products = map (uncurry (*)) (zip a b)