-- Copyright (c) 2009, ERICSSON AB
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-- | A high-level interface to the operations in the core language
-- ("Feldspar.Core"). Many of the functions defined here are imitations of
-- Haskell's list operations, and to a first approximation they behave
-- accordingly.

module Feldspar.Vector where



import qualified Prelude
import Control.Arrow ((***),(&&&))
import Data.List (unfoldr)

import Feldspar.Prelude
import Feldspar.Core.Types
import Feldspar.Core.Expr hiding (index)
import Feldspar.Core



-- * Types

-- | Dynamic size of a vector: the number of elements it holds, which may vary
-- during execution (as opposed to the static size carried in its type).
type Size = Int

-- | Vector index. Indexes are zero-based (the first element is at index 0).
type Ix = Int

-- | Empty type denoting a parallel (random) access pattern for elements in a
-- vector. The argument denotes the static size of the vector. It has no
-- values; it is only used as a type-level tag (see the 'Indexed' constructor).
data Par n

-- | Empty type denoting a sequential access pattern for elements in a vector.
-- The argument denotes the static size of the vector. It has no values; it is
-- only used as a type-level tag (see the 'Unfold' constructor).
data Seq n

-- | Symbolic vector. For example,
--
-- > Seq D10 :>> Par D5 :>> Data Int
--
-- is a sequential (symbolic) vector of parallel vectors of integers. The type
-- numbers @D10@ and @D5@ denote the /static size/ of the vector, i.e. the
-- allocated size of the array used if and when the vector gets written to
-- memory (e.g. by 'toPar').
--
-- If it is known that the vector will never be written to memory, it is
-- not needed to specify a static size. In that case, it is possible to use @()@
-- as the static size type. This way, attempting to write to memory will
-- result in a type error.
--
-- The 'Size' argument to the 'Indexed' and 'Unfold' constructors is called the
-- /dynamic/ size, since it can vary freely during execution.
--
-- 'Indexed' only builds @Par n@ vectors and 'Unfold' only builds @Seq n@
-- vectors, so pattern matching on a constructor also fixes the access
-- pattern. The state type @s@ of 'Unfold' does not appear in the result type,
-- so it is hidden from users of the vector.
data n :>> a
  where
    Indexed  -- Constructor for parallel vectors
      :: Data Size
      -> (Data Ix -> a)  -- A mapping from indexes to elements
      -> (Par n :>> a)

    Unfold  -- Constructor for sequential vectors
      :: Computable s
      => Data Size
      -> (s -> (a,s))  -- "Step function": yields one element and the next state
      -> s             -- Initial state
      -> (Seq n :>> a)

-- Right-associative, so nested vector types need no parentheses.
infixr 5 :>>

-- | Non-nested parallel vector with static size @n@ and elements of type
-- @Data a@.
type VectorP n a = Par n :>> Data a

-- | Non-nested sequential vector with static size @n@ and elements of type
-- @Data a@.
type VectorS n a = Seq n :>> Data a



-- | Addition for static vector size (used e.g. for the result type of '++').
--
-- The instances below cover two cases: two type-level decimal numbers, which
-- are added, or two @()@ sizes (vectors that are never written to memory),
-- which stay @()@.
type family (:+) a b

type instance (:+) (Dec a) (Dec b) = Dec a :+: Dec b
type instance (:+) ()      ()      = ()



-- | Multiplication for static vector size.
--
-- The instances below cover two cases: two type-level decimal numbers, which
-- are multiplied, or two @()@ sizes (vectors that are never written to
-- memory), which stay @()@.
type family (:*) a b

type instance (:*) (Dec a) (Dec b) = Dec a :*: Dec b
type instance (:*) ()      ()      = ()



-- * Construction/conversion

-- | A class for generalizing over parallel and sequential vectors.
--
-- @genericVector vecP vecS@ returns whichever of the two candidates matches
-- the access pattern @t@. A function can therefore build both variants and
-- let its result type decide which one is used.
class AccessPattern t
  where
    genericVector :: (Par n :>> a) -> (Seq n :>> a) -> (t n :>> a)

-- | Parallel vectors select the first (parallel) candidate.
instance AccessPattern Par
  where
    genericVector = const

-- | Sequential vectors select the second (sequential) candidate.
instance AccessPattern Seq
  where
    genericVector = const id

-- | Constructs a parallel vector from a dynamic size and an index function.
-- The index function is assumed to be defined for the domain @[0 .. n-1]@,
-- where @n@ is the dynamic size.
indexed :: Data Size -> (Data Ix -> a) -> (Par n :>> a)
indexed sz ixf = Indexed sz ixf

-- | Constructs a sequential vector from a dynamic size, a \"step\" function
-- and an initial state. Each application of the step function yields one
-- element together with the next state.
unfold :: Computable s => Data Size -> (s -> (a,s)) -> s -> (Seq n :>> a)
unfold sz step s = Unfold sz step s



-- | Converts a non-nested vector to a core vector, i.e. writes its elements
-- to an array of static size @n@.
--
-- Both cases produce exactly the dynamic number of elements (@length vec@).
-- A parallel vector is handed to 'parallel' with its dynamic size. A
-- sequential vector runs its step function once per index in
-- @[0 .. sz-1]@. It is not run up to the static size, because the step
-- function is only meaningful for as many steps as the dynamic size. For
-- example, after 'toSeq' each step calls an index function whose domain is
-- @[0 .. sz-1]@.
freezeVector :: forall t n a . (NaturalT n, Storable a) =>
    (t n :>> Data a) -> Data (n :> a)

freezeVector (Indexed sz ixf) = parallel sz ixf

freezeVector (Unfold sz step s0) = snd $ for 0 (sz-1) (s0,arr0) body
  where
    -- Array of static size n, constructed from an empty element list.
    arr0 = array [] :: Data (n :> a)

    -- Loop body: run the step function once and store the produced element
    -- at index i, threading the unfold state through the loop.
    body i (st, acc :: Data (n :> a)) = (st', setIx acc i x)
      where
        (x,st') = step st



-- | Converts a non-nested core vector back into a symbolic vector. The first
-- argument is the dynamic size. The access pattern (parallel or sequential)
-- is chosen by the result type. The sequential variant reads the array in
-- index order.
unfreezeVector :: (NaturalT n, Storable a, AccessPattern t) =>
    Data Size -> Data (n :> a) -> (t n :>> Data a)

unfreezeVector sz arr = genericVector asPar asSeq
  where
    asPar = Indexed sz (getIx arr)
    asSeq = toSeq asPar



-- | Constructs a non-nested vector from a Haskell list. The dynamic size is
-- the length of the list.
--
-- The constraint @ListBased a ~ a@ rules out nested element types.
vector :: (NaturalT n, Storable a, AccessPattern t, ListBased a ~ a) =>
    [a] -> (t n :>> Data a)

vector xs = unfreezeVector (value (Prelude.length xs)) (array xs)



-- instance (NaturalT n, Storable (Internal a), Computable a) =>
--     Computable (Par n :>> a)
--   where
--     type Internal (Par n :>> a) = (Int, n :> Internal a)

--     internalize vec =
--      internalize (length vec, freezeVector $ map internalize vec)

--     externalize sz_a = map externalize $ unfreezeVector sz a
--       where
--         sz = externalize $ ref $ GetTuple (T::T D0) sz_a
--         a  = externalize $ ref $ GetTuple (T::T D1) sz_a
  -- XXX This would require first class tuples.

-- | Non-nested vectors are computable. The internal representation is a pair
-- of the dynamic size and the elements written to an array of static size @n@
-- (via 'freezeVector').
instance (NaturalT n, Storable a, AccessPattern t)
      => Computable (t n :>> Data a)
  where
    type Internal (t n :>> Data a) = (Int, n :> Internal (Data a))

    -- Pair the dynamic length with the frozen (in-memory) elements.
    internalize vec =
      internalize (length vec, freezeVector $ map internalize vec)

    -- Project the size (tuple component 0) and the array (component 1) out
    -- of the internal pair, then rebuild a vector with 'unfreezeVector'.
    externalize sz_a = map externalize $ unfreezeVector sz a
      where
        sz = externalize $ ref $ GetTuple (T::T D0) sz_a
        a  = externalize $ ref $ GetTuple (T::T D1) sz_a

-- | Two-level nested vectors are computable. The internal representation is
-- a triple of:
--
-- * the outer dynamic size,
--
-- * an array holding the dynamic size of each inner vector, and
--
-- * an array of arrays holding the inner vectors' elements.
instance
    ( NaturalT n1
    , NaturalT n2
    , Storable a
    , AccessPattern t1
    , AccessPattern t2
    ) =>
      Computable (t1 n1 :>> t2 n2 :>> Data a)
  where
    type Internal (t1 n1 :>> t2 n2 :>> Data a) =
           (Int, n1 :> Int, n1 :> n2 :> Internal (Data a))

    -- Outer size, per-row inner sizes, and every inner vector frozen to
    -- memory.
    internalize vec = internalize
      ( length vec
      , freezeVector $ map length vec
      , freezeVector $ map (freezeVector . map internalize) vec
      )

    -- Unpack the triple, pair each inner size with its row array, and
    -- rebuild every inner vector with 'unfreezeVector'.
    externalize inp
        = map (map externalize . uncurry unfreezeVector)
        $ zip sz2sV (unfreezeVector sz1 a)
      where
        sz1   = externalize $ ref $ GetTuple (T::T D0) inp
        sz2s  = externalize $ ref $ GetTuple (T::T D1) inp
        a     = externalize $ ref $ GetTuple (T::T D2) inp
        -- The annotation fixes the access pattern and static size of the
        -- vector of inner sizes.
        sz2sV = unfreezeVector sz1 sz2s :: t1 n1 :>> Data Int



-- | Convert any vector to a sequential one. This operation is always
-- \"cheap\". A parallel vector becomes an unfold whose state is the next index
-- to read, starting at 0. A sequential vector is rebuilt unchanged.
toSeq :: (t n :>> a) -> (Seq n :>> a)
toSeq vec = case vec of
    Indexed sz ixf   -> Unfold sz (\ix -> (ixf ix, ix+1)) 0
    Unfold sz step s -> Unfold sz step s

-- | Changes the static size of a vector. The elements, the dynamic size and
-- the access pattern are unchanged.
resize :: NaturalT n => (t m :>> a) -> (t n :>> a)
resize vec = case vec of
    Indexed sz ixf   -> Indexed sz ixf
    Unfold sz step s -> Unfold sz step s
  -- The NaturalT constraint is needed because otherwise it would be possible to
  -- make an existing NaturalT constraint disappear. That would ruin the
  -- property that vectors with fully polymorphic sizes do not represent their
  -- elements in memory.

-- | Convert any non-nested vector to a parallel one with cheap lookups.
-- Internally, the vector is written to memory ('freezeVector') and read back
-- with the same dynamic size ('unfreezeVector').
toPar :: (NaturalT n, Storable a) => (t n :>> Data a) -> VectorP n a
toPar vec = unfreezeVector len frozen
  where
    len    = length vec
    frozen = freezeVector vec



-- * Operations

-- | Look up an index in a vector. This operation takes linear time for
-- sequential vectors: the unfold state is advanced @i@ times from the initial
-- state before the element is produced.
index :: (t :>> a) -> Data Ix -> a
index (Indexed _ ixf) i = ixf i
index (Unfold _ step s) i = fst (step (fst (while notYet advance (s, 0))))
  where
    -- Loop state: (unfold state, number of steps taken so far).
    notYet = (< i) . snd
    advance ~(st, k) = (snd (step st), k + 1)

-- | Random access with '!' is only offered for parallel vectors, where it is
-- a direct call of the index function.
instance RandomAccess (Par n :>> a)
  where
    type Elem (Par n :>> a) = a
    (!) = index



-- | The dynamic size of a vector, i.e. the size stored in either constructor.
length :: (t n :>> a) -> Data Size
length vec = case vec of
    Indexed sz _   -> sz
    Unfold  sz _ _ -> sz



-- | Concatenates two vectors with the same access pattern. The dynamic size
-- of the result is the sum of the dynamic sizes. The static size is the
-- type-level sum (':+') of the static sizes.
(++) :: Computable a => (t m :>> a) -> (t n :>> a) -> (t (m :+ n) :>> a)

Indexed len1 ix1 ++ Indexed len2 ix2 = Indexed (len1+len2) joined
  where
    -- Indexes below len1 come from the first vector; the rest are shifted
    -- down into the second one.
    joined k = ifThenElse (k < len1) ix1 (ix2 . subtract len1) k

Unfold len1 next1 st1 ++ Unfold len2 next2 st2 =
    Unfold (len1+len2) next (0, (st1,st2))
  where
    -- The counter c decides which vector to draw from; only the selected
    -- state is advanced.
    next (c, (a1,a2)) = c<len1 ?
      ( let (x,a1') = next1 a1 in (x, (c+1, (a1', a2)))
      , let (x,a2') = next2 a2 in (x, (c+1, (a1, a2')))
      )

infixr 5 ++



-- | Keeps the first @n@ elements of a vector, like Haskell's 'Prelude.take'.
-- The new dynamic size is clamped to @[0, length vec]@, so a negative count
-- yields an empty vector instead of a negative size (matching the clamping
-- done by 'drop'). The static size is unchanged.
take :: Data Int -> (t n :>> a) -> (t n :>> a)

take n (Indexed sz ixf) = Indexed sz' ixf
  where
    sz' = max 0 (min sz n)

take n (Unfold sz step s) = Unfold sz' step s
  where
    sz' = max 0 (min sz n)



-- | Removes the first @n@ elements of a vector, like Haskell's 'Prelude.drop'.
--
-- The number of removed elements is clamped to @[0, length vec]@:
--
-- * A negative count removes nothing. Previously it grew the size and
--   shifted indexes below 0.
--
-- * A count larger than the vector yields an empty vector, and a sequential
--   vector is never stepped past its end.
--
-- The static size is unchanged.
drop :: Data Int -> (t n :>> a) -> (t n :>> a)

drop n (Indexed sz ixf) = Indexed (sz-k) (\x -> ixf (x+k))
  where
    -- Number of elements actually removed.
    k = min sz (max 0 n)

drop n (Unfold sz step s) = Unfold (sz-k) step s'
  where
    -- Number of elements actually removed.
    k  = min sz (max 0 n)
    -- Advance the state exactly k times.
    s' = for 0 (k-1) s (\_ -> snd . step)



-- | Drops the longest prefix of elements satisfying the predicate, like
-- Haskell's 'Prelude.dropWhile'. At most @length vec@ elements are skipped.
dropWhile :: (a -> Data Bool) -> (t n :>> a) -> (t n :>> a)

dropWhile cont vec@(Indexed _ _) = drop i vec
  where
    i = while ((< length vec) &&* (cont . (vec !))) (+1) 0

dropWhile cont (Unfold sz step s) = Unfold (sz-i) step s'
  where
    -- Loop state: (unfold state, number of elements skipped so far).
    (s',i) = while condition advance (s,0)
    -- Keep skipping only while an element remains (j < sz, the same bound as
    -- the parallel case) and that element satisfies the predicate. Using
    -- <= here would inspect an element past the end and could yield a
    -- dynamic size of -1.
    condition = ((\(_,j) -> j < sz) &&* (cont . fst . step . fst))
    advance (st,j) = (snd $ step st, j+1)



-- | Splits a vector into its first @n@ elements and the rest, i.e.
-- @('take' n vec, 'drop' n vec)@.
splitAt :: Data Int -> (t n :>> a) -> (t n :>> a, t n :>> a)
splitAt n = take n &&& drop n

-- | The first element of a vector (index 0).
head :: (t n :>> a) -> a
head vec = index vec 0

-- | The last element of a vector (index @length vec - 1@). Like 'index', this
-- takes linear time for sequential vectors.
last :: (t n :>> a) -> a
last vec = index vec lastIx
  where
    lastIx = length vec - 1

-- | All elements except the first one.
tail :: (t n :>> a) -> (t n :>> a)
tail vec = drop 1 vec

-- | All elements except the last one.
init :: (t n :>> a) -> (t n :>> a)
init vec = take keep vec
  where
    keep = length vec - 1

-- | Like Haskell's 'tails', but does not include the empty vector. This is
-- actually just to make the types simpler (the result is square). Element @k@
-- of the result is @'drop' k vec@, for @k@ in @[0 .. length vec - 1]@.
tails :: AccessPattern u => (t n :>> a) -> (u n :>> t n :>> a)
tails vec = genericVector asPar asSeq
  where
    total    = length vec
    suffix k = drop k vec
    asPar    = Indexed total suffix
    asSeq    = Unfold total (\k -> (suffix k, k+1)) 0

-- | Like Haskell's 'inits', but does not include the empty vector. This is
-- actually just to make the types simpler (the result is square). Element @k@
-- of the result is the prefix of length @k+1@, so the last element is the
-- whole vector. Previously the prefixes had length @k@, which included the
-- empty vector and omitted the full one.
inits :: AccessPattern u => (t n :>> a) -> (u n :>> t n :>> a)
inits vec = genericVector vecP vecS
  where
    sz   = length vec
    vecP = Indexed sz (\n -> take (n+1) vec)
    vecS = Unfold sz (\n -> (take (n+1) vec, n+1)) 0

-- | Reorders a parallel vector. Element @i@ of the result is element
-- @perm sz i@ of the input, where @sz@ is the dynamic size.
permute :: (Data Size -> Data Ix -> Data Ix) -> ((Par n :>> a) -> (Par n :>> a))
permute perm (Indexed sz ixf) = Indexed sz (\i -> ixf (perm sz i))

-- | Reverses a parallel vector: index @i@ of the result reads index
-- @sz-1-i@ of the input.
reverse :: (Par n :>> a) -> (Par n :>> a)
reverse = permute mirror
  where
    mirror sz i = sz-1-i

-- | A vector of @n@ copies of the same element. The sequential variant uses
-- 'unit' as a trivial state.
replicate :: AccessPattern t => Data Int -> a -> (t n :>> a)
replicate n a = genericVector (Indexed n (const a)) (Unfold n step unit)
  where
    step _ = (a, unit)

-- | The integers from @m@ to @n@, inclusive. The dynamic size is @n-m+1@.
enumFromTo :: AccessPattern t => Data Int -> Data Int -> (t n :>> Data Int)
enumFromTo m n =
    genericVector (indexed count (+m)) (unfold count (\x -> (x,x+1)) m)
  where
    count = n-m+1



-- | Pairs up corresponding elements of two vectors with the same access
-- pattern. The dynamic size of the result is the smaller input size.
zip :: (t n :>> a) -> (t n :>> b) -> (t n :>> (a,b))

zip (Indexed len1 ix1) (Indexed len2 ix2) =
    Indexed (min len1 len2) (\i -> (ix1 i, ix2 i))

zip (Unfold len1 next1 st1) (Unfold len2 next2 st2) =
    Unfold (min len1 len2) next (st1, st2)
  where
    -- Step both unfolds in lock-step.
    next (a1,a2) =
      let (x,a1') = next1 a1
          (y,a2') = next2 a2
      in ((x,y), (a1',a2'))



-- | Splits a vector of pairs into a vector of first components and a vector
-- of second components. Both keep the original dynamic size. In the
-- sequential case, each result runs the original step function on its own.
unzip :: (t n :>> (a,b)) -> (t n :>> a, t n :>> b)

unzip (Indexed sz ixf) =
    (Indexed sz (\i -> fst (ixf i)), Indexed sz (\i -> snd (ixf i)))

unzip (Unfold sz step s) = (Unfold sz firsts s, Unfold sz seconds s)
  where
    firsts  st = let ((x,_), st') = step st in (x, st')
    seconds st = let ((_,y), st') = step st in (y, st')



-- | Applies a function to every element, preserving size and access pattern.
map :: (a -> b) -> ((t n :>> a) -> (t n :>> b))
map f vec = case vec of
    Indexed sz ixf   -> Indexed sz (\i -> f (ixf i))
    Unfold sz step s -> Unfold sz (\st -> let (x,st') = step st in (f x, st')) s

-- | Combines corresponding elements with a binary function ('zip' followed
-- by 'map').
zipWith :: (a -> b -> c) -> (t n :>> a) -> (t n :>> b) -> (t n :>> c)
zipWith f xs ys = map combine (zip xs ys)
  where
    combine = uncurry f



-- | Corresponds to Haskell's @foldl@. The accumulator starts at @x@ and is
-- combined with each element from left to right, one loop iteration per
-- element of the dynamic size.
fold :: Computable a => (a -> b -> a) -> a -> (t n :>> b) -> a

fold f x (Indexed sz ixf) = for 0 (sz-1) x (\i acc -> f acc (ixf i))

fold f x (Unfold sz step s) = fst $ for 0 (sz-1) (x,s) body
  where
    -- Thread the accumulator and the unfold state through the loop.
    body _ (acc,st) = (f acc y, st')
      where
        (y,st') = step st



-- | Corresponds to Haskell's @foldl1@. The first element is the initial
-- accumulator, and the function is folded over the remaining elements.
-- Previously the whole vector was folded, so the first element was combined
-- twice (e.g. @fold1 (+)@ counted it twice).
fold1 :: Computable a => (a -> a -> a) -> (t n :>> a) -> a
fold1 f vec = fold f (head vec) (tail vec)



-- | Corresponds to Haskell's @scanl@, except that the initial accumulator is
-- not emitted. The result is a sequential vector with the same dynamic size
-- as the input, and its first element is @f a x0@.
scan :: Computable a => (a -> b -> a) -> a -> (t n :>> b) -> (Seq n :>> a)

scan f a (Indexed sz ixf) = Unfold sz next (0,a)
  where
    -- State: (next index to read, current accumulator).
    next (i,acc) = (acc', (i+1, acc'))
      where
        acc' = f acc (ixf i)

scan f a (Unfold sz step s) = Unfold sz next (s,a)
  where
    -- State: (inner unfold state, current accumulator).
    next (st,acc) = (acc', (st',acc'))
      where
        (y,st') = step st
        acc'    = f acc y



-- | Corresponds to Haskell's @scanl1@. The head of the vector is the initial
-- accumulator and the scan runs over the tail. As with 'scan', the initial
-- accumulator itself is not emitted.
scan1 :: Computable a => (a -> a -> a) -> (t n :>> a) -> (Seq n :>> a)
scan1 f vec = scan f start rest
  where
    start = head vec
    rest  = tail vec

-- | Sum of all elements (0 for an empty vector).
sum :: (Num a, Computable a) => (t n :>> a) -> a
sum vec = fold (+) 0 vec

-- | Largest element, computed with 'fold1' and 'max'.
maximum :: Storable a => (t n :>> Data a) -> Data a
maximum vec = fold1 max vec

-- | Smallest element, computed with 'fold1' and 'min'.
minimum :: Storable a => (t n :>> Data a) -> Data a
minimum vec = fold1 min vec



-- | Scalar (dot) product of two vectors: the sum of the element-wise
-- products. The shorter dynamic size bounds the computation (see 'zip').
scalarProd :: (Primitive a, Num a) =>
    (t n :>> Data a) -> (t n :>> Data a) -> Data a

scalarProd xs ys = sum products
  where
    products = zipWith (*) xs ys