module Feldspar.Vector where
import qualified Prelude
import Control.Arrow ((***),(&&&))
import Data.List (unfoldr)
import Feldspar.Prelude
import Feldspar.Core.Types
import Feldspar.Core.Expr hiding (index)
import Feldspar.Core
-- | Dynamic number of elements in a vector.
type Size = Int

-- | Zero-based position of an element within a vector.
type Ix = Int
-- | Access-pattern tag for random-access vectors (built with 'Indexed').
data Par n

-- | Access-pattern tag for sequential vectors (built with 'Unfold').
data Seq n
infixr 5 :>>

-- | A symbolic vector. The left type argument is an access-pattern tag
-- ('Par' or 'Seq') applied to a type-level size; the right argument is the
-- element type.
data n :>> a where
    -- | Random-access vector: dynamic length plus a function from index to
    -- element.
    Indexed :: Data Size -> (Data Ix -> a) -> (Par n :>> a)
    -- | Sequential vector: dynamic length, a step function returning the
    -- next element together with the next state, and the initial state.
    Unfold
        :: Computable s
        => Data Size
        -> (s -> (a, s))
        -> s
        -> (Seq n :>> a)
-- | Random-access vector of scalar Feldspar values.
type VectorP n a = Par n :>> Data a

-- | Sequential vector of scalar Feldspar values.
type VectorS n a = Seq n :>> Data a
-- | Type-level sum of two sizes; used as the size of a concatenation
-- (see the result type of '++').
type family (:+) a b
-- Two decimal sizes add via the type-level '(:+:)' operator.
type instance (:+) (Dec a) (Dec b) = Dec a :+: Dec b
-- Two unit sizes combine to unit.
type instance (:+) () () = ()
-- | Type-level combination of two sizes via '(:*:)' — presumably the
-- product (e.g. for nested vectors); TODO confirm against its users.
type family (:*) a b
-- Two decimal sizes combine via the type-level '(:*:)' operator.
type instance (:*) (Dec a) (Dec b) = Dec a :*: Dec b
-- Two unit sizes combine to unit.
type instance (:*) () () = ()
-- | Access patterns a vector can be produced in.
--
-- 'genericVector' receives both a random-access and a sequential
-- construction of the same vector and returns the one matching @t@.
class AccessPattern t where
    genericVector :: (Par n :>> a) -> (Seq n :>> a) -> (t n :>> a)
-- | Random access: keep the 'Par' construction, ignore the 'Seq' one.
instance AccessPattern Par where
    genericVector = const
-- | Sequential access: keep the 'Seq' construction, ignore the 'Par' one.
instance AccessPattern Seq where
    genericVector = flip const
-- | Build a random-access vector from its length and an index function.
indexed :: Data Size -> (Data Ix -> a) -> (Par n :>> a)
indexed len ixf = Indexed len ixf
-- | Build a sequential vector from its length, a step function
-- (state -> (element, next state)) and an initial state.
unfold :: Computable s => Data Size -> (s -> (a,s)) -> s -> (Seq n :>> a)
unfold len step initial = Unfold len step initial
-- | Materialise a vector into a Feldspar array with static bound @n@.
--
-- A random-access vector is written with 'parallel'. A sequential vector
-- is run for exactly its dynamic length @sz@: each step's element is stored
-- at its index in a fresh array. The loop bound is @sz - 1@ (as in 'fold'),
-- so 'step' is never invoked past the end of the vector.
freezeVector :: forall t n a . (NaturalT n, Storable a) =>
    (t n :>> Data a) -> Data (n :> a)
freezeVector (Indexed sz ixf) = parallel sz ixf
freezeVector (Unfold sz step s) = snd $ for 0 (sz - 1) (s, arr) body
  where
    arr = array [] :: Data (n :> a)
    -- Advance the unfold state once and write the produced element at i.
    body i (st, acc :: Data (n :> a)) = (st', setIx acc i x)
      where
        (x, st') = step st
-- | View the first @sz@ elements of an array as a vector in either access
-- pattern. The sequential form is derived from the random-access one.
unfreezeVector :: (NaturalT n, Storable a, AccessPattern t) =>
    Data Size -> Data (n :> a) -> (t n :>> Data a)
unfreezeVector sz arr =
    let random = Indexed sz (getIx arr)
    in  genericVector random (toSeq random)
-- | Build a vector from a Haskell list literal; its length is the list's
-- length.
vector :: (NaturalT n, Storable a, AccessPattern t, ListBased a ~ a) =>
    [a] -> (t n :>> Data a)
vector xs = unfreezeVector (value (Prelude.length xs)) (array xs)
-- | One-dimensional vectors are represented internally as a pair of the
-- dynamic length and the frozen element array.
instance (NaturalT n, Storable a, AccessPattern t)
      => Computable (t n :>> Data a)
  where
    type Internal (t n :>> Data a) = (Int, n :> Internal (Data a))

    -- Pack: (length, frozen elements).
    internalize vec =
        internalize (length vec, freezeVector (map internalize vec))

    -- Unpack: rebuild the vector from the stored length and array.
    externalize pair = map externalize (unfreezeVector len elems)
      where
        len   = externalize $ ref $ GetTuple (T::T D0) pair
        elems = externalize $ ref $ GetTuple (T::T D1) pair
-- | Two-dimensional vectors are represented internally as a triple: the
-- outer length, the array of row lengths, and the frozen 2-D element array.
instance
    ( NaturalT n1
    , NaturalT n2
    , Storable a
    , AccessPattern t1
    , AccessPattern t2
    ) =>
    Computable (t1 n1 :>> t2 n2 :>> Data a)
  where
    type Internal (t1 n1 :>> t2 n2 :>> Data a) =
        (Int, n1 :> Int, n1 :> n2 :> Internal (Data a))

    -- Pack: (outer length, per-row lengths, frozen rows).
    internalize vec = internalize (outerLen, rowLens, rows)
      where
        outerLen = length vec
        rowLens  = freezeVector $ map length vec
        rows     = freezeVector $ map (freezeVector . map internalize) vec

    -- Unpack: pair each stored row length with its row and rebuild it.
    externalize triple = zipWith rebuild lenVec (unfreezeVector outerLen rows)
      where
        outerLen = externalize $ ref $ GetTuple (T::T D0) triple
        rowLens  = externalize $ ref $ GetTuple (T::T D1) triple
        rows     = externalize $ ref $ GetTuple (T::T D2) triple
        -- The annotation fixes the access pattern of the row-length vector.
        lenVec   = unfreezeVector outerLen rowLens :: t1 n1 :>> Data Int
        rebuild len row = map externalize (unfreezeVector len row)
-- | Convert any vector to sequential form. A random-access vector becomes
-- an unfold whose state is the current index, starting at 0.
toSeq :: (t n :>> a) -> (Seq n :>> a)
toSeq vec = case vec of
    Indexed sz ixf   -> Unfold sz (\i -> (ixf i, i + 1)) 0
    Unfold sz step s -> Unfold sz step s
-- | Change only the static size index of a vector; its contents and dynamic
-- length are untouched.
resize :: NaturalT n => (t m :>> a) -> (t n :>> a)
resize vec = case vec of
    Indexed sz ixf   -> Indexed sz ixf
    Unfold sz step s -> Unfold sz step s
-- | Convert any vector to random-access form by freezing it to an array
-- and reading the array back.
toPar :: (NaturalT n, Storable a) => (t n :>> Data a) -> VectorP n a
toPar vec = unfreezeVector len frozen
  where
    len    = length vec
    frozen = freezeVector vec
-- | Element at position @i@. For a sequential vector, the state is
-- advanced @i@ times with a 'while' loop and the next element is taken.
index :: (t :>> a) -> Data Ix -> a
index vec i = case vec of
    Indexed _ ixf   -> ixf i
    Unfold _ step s ->
        let advance (st, k) = (snd (step st), k + 1)
            notThere        = (< i) . snd
        in  fst $ step $ fst $ while notThere advance (s, 0)
-- | Random-access vectors support '!' via 'index'.
instance RandomAccess (Par n :>> a) where
    type Elem (Par n :>> a) = a
    vec ! i = index vec i
-- | Dynamic length of a vector.
length :: (t n :>> a) -> Data Size
length vec = case vec of
    Indexed sz _  -> sz
    Unfold sz _ _ -> sz
infixr 5 ++

-- | Concatenate two vectors of the same access pattern.
--
-- Random access picks the source by comparing the index with the left
-- length. Sequential access carries a position counter plus both states and
-- steps the left vector while the counter is below the left length.
(++) :: Computable a => (t m :>> a) -> (t n :>> a) -> (t (m :+ n) :>> a)
Indexed lenL ixfL ++ Indexed lenR ixfR = Indexed (lenL + lenR) ixf
  where
    ixf i = ifThenElse (i < lenL) ixfL (\j -> ixfR (j - lenL)) i
Unfold lenL stepL sL ++ Unfold lenR stepR sR =
    Unfold (lenL + lenR) step (0, (sL, sR))
  where
    step (k, (l, r)) = k < lenL ? (fromLeft, fromRight)
      where
        fromLeft  = let (x, l') = stepL l in (x, (k + 1, (l', r)))
        fromRight = let (x, r') = stepR r in (x, (k + 1, (l, r')))
-- | Keep at most the first @cnt@ elements (the length becomes
-- @min (length vec) cnt@).
take :: Data Int -> (t n :>> a) -> (t n :>> a)
take cnt vec = case vec of
    Indexed sz ixf   -> Indexed (min sz cnt) ixf
    Unfold sz step s -> Unfold (min sz cnt) step s
-- | Remove the first @n@ elements; the length never goes below 0.
--
-- A random-access vector shifts its index function. A sequential vector
-- pre-advances its state, but never past the end of the vector, so a
-- too-large @n@ does not run 'step' on out-of-range positions.
drop :: Data Int -> (t n :>> a) -> (t n :>> a)
drop n (Indexed sz ixf) = Indexed (max 0 (sz - n)) (\x -> ixf (x + n))
drop n (Unfold sz step s) = Unfold (max 0 (sz - n)) step s'
  where
    s' = for 0 (min sz n - 1) s (\_ -> snd . step)
-- | Drop the longest prefix whose elements all satisfy @cont@.
--
-- Both variants count how many leading elements satisfy @cont@, stopping
-- at the vector's length. The sequential bound uses @<@, like the
-- random-access one; with @<=@ the loop could test and consume a
-- non-existent element and produce a negative length.
dropWhile :: (a -> Data Bool) -> (t n :>> a) -> (t n :>> a)
dropWhile cont vec@(Indexed _ _) = drop i vec
  where
    i = while ((< length vec) &&* (cont . (vec !))) (+1) 0
dropWhile cont vec@(Unfold sz step s) = Unfold (sz - i) step s'
  where
    (s', i) = while condition (\(st, k) -> (snd $ step st, k + 1)) (s, 0)
    -- Keep going while inside the vector and the next element satisfies cont.
    condition = ((\(_, k) -> k < length vec) &&* (cont . fst . step . fst))
-- | Split at position @n@: the first @n@ elements and the remainder.
splitAt :: Data Int -> (t n :>> a) -> (t n :>> a, t n :>> a)
splitAt n = take n &&& drop n
-- | First element (index 0).
head :: (t n :>> a) -> a
head vec = index vec 0
-- | Last element, i.e. the element at index @length vec - 1@.
last :: (t n :>> a) -> a
last vec = index vec (length vec - 1)
-- | All elements except the first.
tail :: (t n :>> a) -> (t n :>> a)
tail vec = drop 1 vec
-- | All elements except the last (keeps the first @length vec - 1@).
init :: (t n :>> a) -> (t n :>> a)
init vec = take (length vec - 1) vec
-- | Vector of suffixes: element @k@ is @drop k vec@, for @k@ from 0 to
-- @length vec - 1@.
tails :: AccessPattern u => (t n :>> a) -> (u n :>> t n :>> a)
tails vec =
    genericVector (Indexed len suffix) (Unfold len (\k -> (suffix k, k + 1)) 0)
  where
    len      = length vec
    suffix k = drop k vec
-- | Vector of prefixes: element @k@ is @take k vec@, for @k@ from 0 to
-- @length vec - 1@.
inits :: AccessPattern u => (t n :>> a) -> (u n :>> t n :>> a)
inits vec =
    genericVector (Indexed len prefix) (Unfold len (\k -> (prefix k, k + 1)) 0)
  where
    len      = length vec
    prefix k = take k vec
-- | Reorder a random-access vector: element @i@ of the result is element
-- @perm sz i@ of the input, where @sz@ is the vector's length.
permute :: (Data Size -> Data Ix -> Data Ix) -> ((Par n :>> a) -> (Par n :>> a))
permute perm (Indexed sz ixf) = Indexed sz (\i -> ixf (perm sz i))
-- | Reverse a random-access vector: index @i@ maps to @sz - 1 - i@.
reverse :: (Par n :>> a) -> (Par n :>> a)
reverse = permute $ \sz i -> sz - 1 - i
-- | A vector of length @n@ whose every element is @a@. The sequential form
-- carries a trivial 'unit' state.
replicate :: AccessPattern t => Data Int -> a -> (t n :>> a)
replicate n a =
    let random     = Indexed n (const a)
        sequential = Unfold n (const (a, unit)) unit
    in  genericVector random sequential
-- | The inclusive integer range @[m .. n]@ as a vector.
--
-- The length is @n - m + 1@, clamped at 0 (as 'drop' does) so an empty
-- range (@n < m@) gives an empty vector instead of a negative length.
enumFromTo :: AccessPattern t => Data Int -> Data Int -> (t n :>> Data Int)
enumFromTo m n = genericVector vecP vecS
  where
    sz   = max 0 (n - m + 1)
    vecP = indexed sz (+m)
    vecS = unfold sz (\x -> (x, x + 1)) m
-- | Pair up corresponding elements; the result has the shorter length.
-- Sequential vectors step both states in lockstep.
zip :: (t n :>> a) -> (t n :>> b) -> (t n :>> (a,b))
zip (Indexed lenA ixfA) (Indexed lenB ixfB) =
    Indexed (min lenA lenB) (\i -> (ixfA i, ixfB i))
zip (Unfold lenA stepA sA) (Unfold lenB stepB sB) =
    Unfold (min lenA lenB) step (sA, sB)
  where
    step (a0, b0) =
        let (x, a1) = stepA a0
            (y, b1) = stepB b0
        in  ((x, y), (a1, b1))
-- | Split a vector of pairs into the vector of first components and the
-- vector of second components.
unzip :: (t n :>> (a,b)) -> (t n :>> a, t n :>> b)
unzip vec = (map fst vec, map snd vec)
-- | Apply @f@ to every element, keeping the length and access pattern.
map :: (a -> b) -> ((t n :>> a) -> (t n :>> b))
map f vec = case vec of
    Indexed sz ixf   -> Indexed sz (\i -> f (ixf i))
    Unfold sz step s -> Unfold sz (\st -> let (x, st') = step st in (f x, st')) s
-- | Combine corresponding elements with @f@; the result has the shorter
-- length.
zipWith :: (a -> b -> c) -> (t n :>> a) -> (t n :>> b) -> (t n :>> c)
zipWith f xs = map (uncurry f) . zip xs
-- | Left fold over all elements, starting from @x@ and visiting indices
-- @0 .. sz - 1@. A sequential vector threads its unfold state alongside
-- the accumulator.
fold :: Computable a => (a -> b -> a) -> a -> (t n :>> b) -> a
fold f x (Unfold sz step s) = fst $ for 0 (sz - 1) (x, s) body
  where
    body _ (acc, st) = (f acc y, st')
      where
        (y, st') = step st
fold f x (Indexed sz ixf) = for 0 (sz - 1) x (\i acc -> f acc (ixf i))
-- | Left fold seeded with the first element and run over the remaining
-- ones, so each element is combined exactly once (consistent with
-- 'scan1'). The vector is assumed non-empty: 'head' of an empty vector has
-- no meaningful value.
fold1 :: Computable a => (a -> a -> a) -> (t n :>> a) -> a
fold1 f vec = fold f (head vec) (tail vec)
-- | Running left fold: element @k@ of the result is the accumulator after
-- combining the first @k + 1@ input elements. The initial value @a@ itself
-- is not emitted. The result is always sequential.
scan :: Computable a => (a -> b -> a) -> a -> (t n :>> b) -> (Seq n :>> a)
scan f a (Indexed sz ixf) = Unfold sz step (0, a)
  where
    step (i, acc) = (acc', (i + 1, acc'))
      where
        acc' = f acc (ixf i)
scan f a (Unfold sz step s) = Unfold sz step' (s, a)
  where
    step' (st, acc) =
        let (b, st') = step st
            acc'     = f acc b
        in  (acc', (st', acc'))
-- | Running fold seeded with the first element and run over the rest.
scan1 :: Computable a => (a -> a -> a) -> (t n :>> a) -> (Seq n :>> a)
scan1 f vec = scan f x0 rest
  where
    x0   = head vec
    rest = tail vec
-- | Sum of all elements (0 for an empty vector).
sum :: (Num a, Computable a) => (t n :>> a) -> a
sum vec = fold (+) 0 vec
-- | Largest element, computed with 'fold1'.
maximum :: Storable a => (t n :>> Data a) -> Data a
maximum vec = fold1 max vec
-- | Smallest element, computed with 'fold1'.
minimum :: Storable a => (t n :>> Data a) -> Data a
minimum vec = fold1 min vec
-- | Dot product: sum of the element-wise products, taken over the shorter
-- of the two lengths.
scalarProd :: (Primitive a, Num a) =>
    (t n :>> Data a) -> (t n :>> Data a) -> Data a
scalarProd xs ys = sum $ zipWith (*) xs ys