module Data.Repa.Scalar.Product
(
(:*:) (..)
, IsProdList (..)
, IsKeyValues (..)
, Select (..)
, Discard (..)
, Mask (..)
, Keep (..)
, Drop (..))
where
import Data.Repa.Scalar.Singleton.Nat
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as M
-- | Strict product type, written infix.
--
--   Both fields carry strictness annotations, so each component is evaluated
--   to weak head normal form when the product is built. The operator is
--   right-associative at precedence 9, so @a :*: b :*: c@ parses as
--   @a :*: (b :*: c)@, which is how product lists are spelled.
data a :*: b = !a :*: !b
  deriving (Eq, Show)

infixr 9 :*:
-- | Maps over the right-hand component; the left-hand one is left untouched.
instance Functor ((:*:) a) where
  fmap f (x :*: y) = x :*: f y
-- | Product types that form a list: a right-nested chain of '(:*:)'
--   cells terminated by @()@.
class IsProdList p where
  -- | Answers whether the value is a product list. With the instances
  --   defined here the answer is always 'True'; the useful check is that
  --   the type resolves to one of these instances at all.
  isProdList :: p -> Bool

-- | The empty product list. The argument is not inspected.
instance IsProdList () where
  isProdList _ = True

-- | A non-empty product list: match the head cell, then recurse on the tail.
instance IsProdList fs => IsProdList (f :*: fs) where
  isProdList (_ :*: rest) = isProdList rest
-- | Sequences of key-value pairs: either a single @(key, value)@ tuple, or a
--   '(:*:)' chain of such sequences that all share one key type.
class IsKeyValues p where
  -- | The (common) key type.
  type Keys p

  -- | The value type; for a chain this is a product of the element value
  --   types.
  type Values p

  -- | All keys, in left-to-right order, as an ordinary list.
  keys :: p -> [Keys p]

  -- | All values, as a heterogeneous product.
  values :: p -> Values p

-- | A single pair contributes one key and its value.
instance IsKeyValues (k, v) where
  type Keys (k, v) = k
  type Values (k, v) = v
  keys (key, _) = pure key
  values (_, val) = val

-- | A chain concatenates the keys of its parts and pairs up their values.
--   Both parts must agree on the key type.
instance (IsKeyValues p, IsKeyValues ps, Keys p ~ Keys ps)
      => IsKeyValues (p :*: ps) where
  type Keys (p :*: ps) = Keys p
  type Values (p :*: ps) = Values p :*: Values ps
  keys (kv :*: rest) = keys kv <> keys rest
  values (kv :*: rest) = values kv :*: values rest
-- | Select a single field of a product list by its (type-level) index.
class IsProdList t
   => Select (n :: N) t where
  -- | Type of the field found at index @n@.
  type Select' n t

  -- | Return the field at index @n@, counting from zero at the head.
  select :: Nat n -> t -> Select' n t

-- | Index zero is the head of the list.
instance IsProdList ts
      => Select Z (t1 :*: ts) where
  type Select' Z (t1 :*: ts) = t1
  select Zero (x :*: _) = x

-- | Index @n + 1@ skips the head and selects index @n@ of the tail.
instance Select n ts
      => Select (S n) (t1 :*: ts) where
  type Select' (S n) (t1 :*: ts) = Select' n ts
  select (Succ k) (_ :*: rest) = select k rest
-- | Remove a single field of a product list by its (type-level) index.
class IsProdList t
   => Discard (n :: N) t where
  -- | Type of the product list left after removing the field at index @n@.
  type Discard' n t

  -- | Drop the field at index @n@ (zero is the head), keeping the rest in
  --   order.
  discard :: Nat n -> t -> Discard' n t

-- | Discarding index zero leaves the tail.
instance IsProdList ts
      => Discard Z (t1 :*: ts) where
  type Discard' Z (t1 :*: ts) = ts
  discard Zero (_ :*: rest) = rest

-- | Discarding index @n + 1@ keeps the head and discards index @n@ of the
--   tail.
instance Discard n ts
      => Discard (S n) (t1 :*: ts) where
  type Discard' (S n) (t1 :*: ts) = t1 :*: Discard' n ts
  discard (Succ k) (x :*: rest) = x :*: discard k rest
-- | Mask marker: keep the field in the corresponding position (see 'Mask').
data Keep = Keep

-- | Mask marker: remove the field in the corresponding position (see 'Mask').
data Drop = Drop
-- | Keep or drop the fields of a product list according to a mask.
--
--   The mask is itself a product list of 'Keep' / 'Drop' markers, one per
--   field, terminated by @()@.
class (IsProdList m, IsProdList t) => Mask m t where
  -- | Type of the product list that remains after masking.
  type Mask' m t

  -- | Apply the mask, retaining (in order) only fields marked 'Keep'.
  mask :: m -> t -> Mask' m t

-- | An empty mask applied to an empty list.
instance Mask () () where
  type Mask' () () = ()
  mask () () = ()

-- | 'Keep' at the head: retain this field and mask the tail.
instance Mask ms ts
      => Mask (Keep :*: ms) (t1 :*: ts) where
  type Mask' (Keep :*: ms) (t1 :*: ts) = t1 :*: Mask' ms ts
  mask (_ :*: restMask) (x :*: rest) = x :*: mask restMask rest

-- | 'Drop' at the head: skip this field and mask the tail.
instance Mask ms ts
      => Mask (Drop :*: ms) (t1 :*: ts) where
  type Mask' (Drop :*: ms) (t1 :*: ts) = Mask' ms ts
  mask (_ :*: restMask) (_ :*: rest) = mask restMask rest
-- | Unboxed vectors of products are stored column-wise: a length together
--   with one unboxed vector per component.
data instance U.Vector (a :*: b)
  = V_Prod !Int
           !(U.Vector a)
           !(U.Vector b)

-- | Mutable counterpart: a length together with one mutable unboxed vector
--   per component.
data instance U.MVector s (a :*: b)
  = MV_Prod !Int
            !(U.MVector s a)
            !(U.MVector s b)

-- | A product is unboxable whenever both of its components are.
instance (U.Unbox a, U.Unbox b)
      => U.Unbox (a :*: b)
-- | Mutable unboxed vectors of products. The two component vectors always
--   have the same length, and every operation is delegated to the
--   corresponding operation on each component.
instance (U.Unbox a, U.Unbox b)
      => M.MVector U.MVector (a :*: b) where

  basicLength (MV_Prod n_ _as _bs) = n_
  {-# INLINE basicLength #-}

  basicUnsafeSlice i_ m_ (MV_Prod _n_ as bs)
   = MV_Prod m_ (M.basicUnsafeSlice i_ m_ as)
                (M.basicUnsafeSlice i_ m_ bs)
  {-# INLINE basicUnsafeSlice #-}

  -- Two product vectors overlap if either pair of components does.
  basicOverlaps (MV_Prod _n_1 as1 bs1) (MV_Prod _n_2 as2 bs2)
   =  M.basicOverlaps as1 as2
   || M.basicOverlaps bs1 bs2
  {-# INLINE basicOverlaps #-}

  basicUnsafeNew n_
   = do as <- M.basicUnsafeNew n_
        bs <- M.basicUnsafeNew n_
        return $ MV_Prod n_ as bs
  {-# INLINE basicUnsafeNew #-}

  -- Part of the class's MINIMAL set since vector 0.11. Without it,
  -- operations that initialise freshly allocated vectors (e.g. 'M.new')
  -- fail at runtime with a missing-method error.
  basicInitialize (MV_Prod _n_ as bs)
   = do M.basicInitialize as
        M.basicInitialize bs
  {-# INLINE basicInitialize #-}

  basicUnsafeReplicate n_ (a :*: b)
   = do as <- M.basicUnsafeReplicate n_ a
        bs <- M.basicUnsafeReplicate n_ b
        return $ MV_Prod n_ as bs
  {-# INLINE basicUnsafeReplicate #-}

  basicUnsafeRead (MV_Prod _n_ as bs) i_
   = do a <- M.basicUnsafeRead as i_
        b <- M.basicUnsafeRead bs i_
        return (a :*: b)
  {-# INLINE basicUnsafeRead #-}

  basicUnsafeWrite (MV_Prod _n_ as bs) i_ (a :*: b)
   = do M.basicUnsafeWrite as i_ a
        M.basicUnsafeWrite bs i_ b
  {-# INLINE basicUnsafeWrite #-}

  basicClear (MV_Prod _n_ as bs)
   = do M.basicClear as
        M.basicClear bs
  {-# INLINE basicClear #-}

  basicSet (MV_Prod _n_ as bs) (a :*: b)
   = do M.basicSet as a
        M.basicSet bs b
  {-# INLINE basicSet #-}

  -- First argument is the target, second the source, as in 'M.basicUnsafeCopy'.
  basicUnsafeCopy (MV_Prod _n_1 as1 bs1) (MV_Prod _n_2 as2 bs2)
   = do M.basicUnsafeCopy as1 as2
        M.basicUnsafeCopy bs1 bs2
  {-# INLINE basicUnsafeCopy #-}

  basicUnsafeMove (MV_Prod _n_1 as1 bs1) (MV_Prod _n_2 as2 bs2)
   = do M.basicUnsafeMove as1 as2
        M.basicUnsafeMove bs1 bs2
  {-# INLINE basicUnsafeMove #-}

  -- Grow both components by m_ elements; the result has length n_ + m_.
  basicUnsafeGrow (MV_Prod n_ as bs) m_
   = do as' <- M.basicUnsafeGrow as m_
        bs' <- M.basicUnsafeGrow bs m_
        return $ MV_Prod (m_ + n_) as' bs'
  {-# INLINE basicUnsafeGrow #-}
-- | Immutable unboxed vectors of products. Every operation is delegated to
--   the corresponding operation on each component vector. All methods are
--   INLINE so the component operations can specialise at use sites.
instance (U.Unbox a, U.Unbox b)
      => G.Vector U.Vector (a :*: b) where

  basicUnsafeFreeze (MV_Prod n_ as bs)
   = do as' <- G.basicUnsafeFreeze as
        bs' <- G.basicUnsafeFreeze bs
        return $ V_Prod n_ as' bs'
  {-# INLINE basicUnsafeFreeze #-}

  basicUnsafeThaw (V_Prod n_ as bs)
   = do as' <- G.basicUnsafeThaw as
        bs' <- G.basicUnsafeThaw bs
        return $ MV_Prod n_ as' bs'
  {-# INLINE basicUnsafeThaw #-}

  basicLength (V_Prod n_ _as _bs)
   = n_
  {-# INLINE basicLength #-}

  basicUnsafeSlice i_ m_ (V_Prod _n_ as bs)
   = V_Prod m_ (G.basicUnsafeSlice i_ m_ as)
               (G.basicUnsafeSlice i_ m_ bs)
  {-# INLINE basicUnsafeSlice #-}

  basicUnsafeIndexM (V_Prod _n_ as bs) i_
   = do a <- G.basicUnsafeIndexM as i_
        b <- G.basicUnsafeIndexM bs i_
        return (a :*: b)
  {-# INLINE basicUnsafeIndexM #-}

  -- Copy an immutable product vector (second argument) into a mutable one
  -- (first argument), component by component.
  basicUnsafeCopy (MV_Prod _n_1 as1 bs1) (V_Prod _n_2 as2 bs2)
   = do G.basicUnsafeCopy as1 as2
        G.basicUnsafeCopy bs1 bs2
  {-# INLINE basicUnsafeCopy #-}

  -- Force both components using each component vector type's own elemseq.
  elemseq _ (a :*: b)
   = G.elemseq (undefined :: U.Vector a) a
   . G.elemseq (undefined :: U.Vector b) b
  {-# INLINE elemseq #-}