module TypeUnary.Vec
(
module TypeUnary.Nat
, Vec(..), headV, tailV, joinV, (<+>), indices
, Vec0,Vec1,Vec2,Vec3,Vec4,Vec5,Vec6,Vec7,Vec8,Vec9
, Vec10,Vec11,Vec12,Vec13,Vec14,Vec15,Vec16
, vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8
, un1, un2, un3, un4
, get, get0, get1, get2, get3
, update
, set, set0, set1, set2, set3
, getI, setI
, flattenV, swizzle, split, deleteV, elemsV, unzipV
, ToVec(..)
) where
import Prelude hiding (foldr,sum)
import Data.Monoid (Monoid(..))
import Control.Applicative (Applicative(..),liftA2,(<$>))
import Data.Foldable (Foldable(..),toList,sum)
import Data.Traversable (Traversable(..))
import Foreign.Storable
import Foreign.Ptr (Ptr,plusPtr,castPtr)
import Data.VectorSpace
import TypeUnary.Nat
infixr 5 :<

-- | Vectors whose length is tracked in the type.
--
-- 'ZVec' is the empty vector (length 'Z'). @x ':<' xs@ prepends @x@ to
-- @xs@, so the length grows by one ('S').
data Vec :: * -> * -> * where
  ZVec :: Vec Z e
  (:<) :: e -> Vec m e -> Vec (S m) e
-- | First element of a non-empty vector (the type rules out 'ZVec').
headV :: Vec (S n) a -> a
headV v = case v of
  x :< _ -> x
-- | Everything but the first element of a non-empty vector.
tailV :: Vec (S n) a -> Vec n a
tailV (_ :< rest) = rest
-- | Element-wise equality. Both sides have the same length by type.
instance Eq a => Eq (Vec n a) where
  (==) ZVec      ZVec      = True
  (==) (x :< xs) (y :< ys) = x == y && xs == ys
-- | Lexicographic ordering.
--
-- The 'Ordering' monoid returns the first non-'EQ' result, so the tails
-- are compared only when the heads are equal.
instance Ord a => Ord (Vec n a) where
  compare ZVec      ZVec      = EQ
  compare (x :< xs) (y :< ys) = compare x y `mappend` compare xs ys
-- | Shows a vector as the expression @elemsV [x0, x1, ...]@.
--
-- This defines 'showsPrec' rather than only 'show', so that nested
-- occurrences are parenthesised like any function application. For
-- example, @Just v@ renders as @Just (elemsV [1,2])@ instead of the
-- unparsable @Just elemsV [1,2]@. Top-level 'show' output is unchanged,
-- because list rendering ignores the precedence argument.
instance Show a => Show (Vec n a) where
  showsPrec p v =
    -- 10 is function-application precedence; the argument is shown at 11.
    showParen (p > 10) (showString "elemsV " . showsPrec 11 (toList v))
-- | Point-wise monoid. The identity is a vector of 'mempty's, and
-- combining merges corresponding elements.
instance (IsNat n, Monoid a) => Monoid (Vec n a) where
  mempty      = pure mempty
  mappend u v = mappend <$> u <*> v
-- | Applies a function to every element, preserving length.
instance Functor (Vec n) where
  fmap f v = case v of
    ZVec    -> ZVec
    x :< xs -> f x :< fmap f xs
-- | Zip-style applicative. 'pure' replicates to length @n@, and '<*>'
-- applies functions to arguments position by position.
instance IsNat n => Applicative (Vec n) where
  pure x    = pureV x
  fs <*> xs = applyV fs xs
-- | Replicates a value to fill a vector whose length comes from 'IsNat'.
pureV :: IsNat n => a -> Vec n a
pureV x = pureV' nat x
-- | Replicates a value, with the length given by an explicit 'Nat' witness.
pureV' :: Nat n -> a -> Vec n a
pureV' n x = case n of
  Zero   -> ZVec
  Succ m -> x :< pureV' m x
-- | Position-wise application of a vector of functions to a vector of
-- arguments of the same length.
applyV :: Vec n (a -> b) -> Vec n a -> Vec n b
applyV ZVec      ZVec      = ZVec
applyV (g :< gs) (y :< ys) = g y :< applyV gs ys
-- | Diagonal monad. Binding maps each element to a vector, then 'joinV'
-- takes the diagonal of the resulting square.
instance IsNat n => Monad (Vec n) where
  return x = pure x
  v >>= f  = joinV (fmap f v)
-- | Diagonal of a square vector-of-vectors.
--
-- Takes the head of the first row, then drops the first column of the
-- remaining rows and recurses.
joinV :: Vec n (Vec n a) -> Vec n a
joinV vv = case vv of
  ZVec        -> ZVec
  row :< rows -> case row of
    x :< _ -> x :< joinV (fmap tailV rows)
-- | Right fold over the elements, first to last. All other 'Foldable'
-- methods use their defaults in terms of 'foldr'.
instance Foldable (Vec n) where
  foldr step z v = case v of
    ZVec    -> z
    x :< xs -> step x (foldr step z xs)
-- | Runs an effectful function over the elements left to right,
-- rebuilding a vector of the same length.
instance Traversable (Vec n) where
  traverse f v = case v of
    ZVec    -> pure ZVec
    x :< xs -> liftA2 (:<) (f x) (traverse f xs)
-- | Point-wise vector addition. The zero vector is all @0@s, and negation
-- negates each element.
instance (IsNat n, Num a) => AdditiveGroup (Vec n a) where
  zeroV   = pure 0
  u ^+^ v = (+) <$> u <*> v
  negateV = fmap negate
-- | Scaling. Scalars are single-element vectors ('Vec1'), and scaling
-- multiplies each element on the left by the scalar's one element.
instance (IsNat n, Num a) => VectorSpace (Vec n a) where
  type Scalar (Vec n a) = Vec1 a
  (*^) (s :< ZVec) = fmap (\x -> s * x)
-- | Dot product: sum of point-wise products, wrapped as a 'Vec1' scalar.
instance (IsNat n, Num a) => InnerSpace (Vec n a) where
  u <.> v = vec1 (sum (liftA2 (*) u v))
-- | Stores a vector as @n@ contiguous elements.
--
-- Total size is @n * sizeOf element@, and alignment is the element's
-- alignment. Reading and writing go through 'peekV' / 'pokeV', which step
-- by @sizeOf element@. Relies on the instance-head type variables @n@ and
-- @a@ being in scope (ScopedTypeVariables).
instance (IsNat n, Storable a) => Storable (Vec n a) where
  sizeOf _    = natToZ (nat :: Nat n) * sizeOf (undefined :: a)
  alignment _ = alignment (undefined :: a)
  peek p      = peekV (castPtr p)
  poke p      = pokeV (castPtr p)
infixl 1 <+>

-- | Concatenates two vectors; the result length is the sum of the inputs.
(<+>) :: Vec m a -> Vec n a -> Vec (m :+: n) a
u <+> v = case u of
  ZVec    -> v
  -- The parentheses matter: '<+>' (infixl 1) binds looser than ':<'.
  x :< xs -> x :< (xs <+> v)
-- | Reads @n@ consecutive elements starting at the pointer.
peekV :: (IsNat n, Storable a) => Ptr a -> IO (Vec n a)
peekV p = peekV' nat p
-- | Writes the elements of a vector consecutively starting at the pointer.
pokeV :: (IsNat n, Storable a) => Ptr a -> Vec n a -> IO ()
pokeV p v = pokeV' nat p v
-- | Reads as many elements as the 'Nat' witness says.
--
-- Each element is read at the current pointer, which then advances by
-- @sizeOf@ that element.
peekV' :: Storable a => Nat n -> Ptr a -> IO (Vec n a)
peekV' n p = case n of
  Zero   -> return ZVec
  Succ m -> peek p >>= \x -> (x :<) <$> peekV' m (p `plusPtr` sizeOf x)
-- | Writes each element at the current pointer, then advances by @sizeOf@
-- that element. The 'Nat' witness drives the recursion.
pokeV' :: Storable a => Nat n -> Ptr a -> Vec n a -> IO ()
pokeV' Zero     _ ZVec      = return ()
pokeV' (Succ m) p (x :< xs) = poke p x >> pokeV' m (p `plusPtr` sizeOf x) xs
-- | Vector of all indices for length @n@, in order.
--
-- Starts from 'index0' and bumps each index of the shorter vector with
-- 'succI'.
indices :: Nat n -> Vec n (Index n)
indices n = case n of
  Zero   -> ZVec
  Succ m -> index0 :< (succI <$> indices m)
-- | Length-specific synonyms: @VecK@ is 'Vec' applied to the type-level
-- numeral @NK@ from "TypeUnary.Nat", for K = 0 .. 16.
type Vec0 = Vec N0
type Vec1 = Vec N1
type Vec2 = Vec N2
type Vec3 = Vec N3
type Vec4 = Vec N4
type Vec5 = Vec N5
type Vec6 = Vec N6
type Vec7 = Vec N7
type Vec8 = Vec N8
type Vec9 = Vec N9
type Vec10 = Vec N10
type Vec11 = Vec N11
type Vec12 = Vec N12
type Vec13 = Vec N13
type Vec14 = Vec N14
type Vec15 = Vec N15
type Vec16 = Vec N16
-- | Builds vectors of length 1 through 8 from their elements, in order.
vec1 :: a -> Vec1 a
vec1 a = a :< ZVec

vec2 :: a -> a -> Vec2 a
vec2 a b = a :< b :< ZVec

vec3 :: a -> a -> a -> Vec3 a
vec3 a b c = a :< b :< c :< ZVec

vec4 :: a -> a -> a -> a -> Vec4 a
vec4 a b c d = a :< b :< c :< d :< ZVec

vec5 :: a -> a -> a -> a -> a -> Vec5 a
vec5 a b c d e = a :< b :< c :< d :< e :< ZVec

vec6 :: a -> a -> a -> a -> a -> a -> Vec6 a
vec6 a b c d e f = a :< b :< c :< d :< e :< f :< ZVec

vec7 :: a -> a -> a -> a -> a -> a -> a -> Vec7 a
vec7 a b c d e f g = a :< b :< c :< d :< e :< f :< g :< ZVec

vec8 :: a -> a -> a -> a -> a -> a -> a -> a -> Vec8 a
vec8 a b c d e f g h = a :< b :< c :< d :< e :< f :< g :< h :< ZVec
-- | Unpacks short vectors into a bare value or a tuple.
--
-- The patterns match the whole spine (down to 'ZVec'), so the spine is
-- forced when the result is demanded.
un1 :: Vec1 a -> a
un1 v = case v of
  a :< ZVec -> a

un2 :: Vec2 a -> (a,a)
un2 v = case v of
  a :< b :< ZVec -> (a, b)

un3 :: Vec3 a -> (a,a,a)
un3 v = case v of
  a :< b :< c :< ZVec -> (a, b, c)

un4 :: Vec4 a -> (a,a,a,a)
un4 v = case v of
  a :< b :< c :< d :< ZVec -> (a, b, c, d)
-- | Element at a type-bounded index.
--
-- Peels one 'SLess' / 'Succ' layer off the index for each element
-- skipped, until it reaches the 'ZLess' / 'Zero' case.
get :: Index n -> Vec n a -> a
get (Index ZLess     Zero)     (x :< _)    = x
get (Index (SLess q) (Succ k)) (_ :< rest) = get (Index q k) rest
-- | Fixed-position accessors. Each accepts any vector long enough to
-- contain that position.
get0 :: Vec (N1 :+: n) a -> a
get0 v = get index0 v

get1 :: Vec (N2 :+: n) a -> a
get1 v = get index1 v

get2 :: Vec (N3 :+: n) a -> a
get2 v = get index2 v

get3 :: Vec (N4 :+: n) a -> a
get3 v = get index3 v
-- | Applies a function to the element at the given index; all other
-- elements are kept unchanged.
update :: Index n -> (a -> a) -> Vec n a -> Vec n a
update (Index ZLess     Zero)     f (x :< rest) = f x :< rest
update (Index (SLess q) (Succ k)) f (x :< rest) = x :< update (Index q k) f rest
-- | Replaces the element at the given index with a new value.
set :: Index n -> a -> Vec n a -> Vec n a
set i x = update i (\_ -> x)
-- | Fixed-position setters, the counterparts of 'get0' .. 'get3'.
set0 :: a -> Vec (N1 :+: n) a -> Vec (N1 :+: n) a
set0 x = set index0 x

set1 :: a -> Vec (N2 :+: n) a -> Vec (N2 :+: n) a
set1 x = set index1 x

set2 :: a -> Vec (N3 :+: n) a -> Vec (N3 :+: n) a
set2 x = set index2 x

set3 :: a -> Vec (N4 :+: n) a -> Vec (N4 :+: n) a
set3 x = set index3 x
-- | Like 'get', but takes an ordinary integral index and converts it with
-- 'coerceToIndex'. Its behaviour on out-of-range values is whatever
-- 'coerceToIndex' does (defined outside this module).
getI :: (IsNat n, Show i, Integral i) => i -> Vec n a -> a
getI i = get (coerceToIndex i)
-- | Like 'set', but takes an ordinary integral index and converts it with
-- 'coerceToIndex'.
setI :: (IsNat n, Show i, Integral i) => i -> a -> Vec n a -> Vec n a
setI i = set (coerceToIndex i)
-- | Concatenates @n@ rows of length @m@ into one vector of length @n*m@.
flattenV :: IsNat n => Vec n (Vec m a) -> Vec (n :*: m) a
flattenV rows = flattenV' nat rows
-- | Worker for 'flattenV', recursing on an explicit 'Nat' witness.
flattenV' :: Nat n -> Vec n (Vec m a) -> Vec (n :*: m) a
flattenV' Zero     _            = ZVec
flattenV' (Succ k) (row :< rows) = row <+> flattenV' k rows
-- Unreachable. It is kept so that older GHCs, which cannot see that
-- @Succ@ never meets 'ZVec' here, do not report an incomplete match.
flattenV' _        _            = error "flattenV': GHC doesn't know this case can't happen."
-- | Picks elements of a vector by a vector of indices. Each output
-- position holds the element at the corresponding index.
swizzle :: Vec n (Index m) -> Vec m a -> Vec n a
swizzle ixs v = (\ix -> get ix v) <$> ixs
-- | Splits off the first @n@ elements (with @n@ taken from 'IsNat') from
-- the remainder.
split :: IsNat n => Vec (n :+: m) a -> (Vec n a, Vec m a)
split v = split' nat v
-- | Worker for 'split'. The recursive result is bound lazily, so the
-- components of the pair are not forced early.
split' :: Nat n -> Vec (n :+: m) a -> (Vec n a, Vec m a)
split' Zero     v         = (ZVec, v)
split' (Succ k) (x :< xs) =
  let (front, back) = split' k xs
  in  (x :< front, back)
-- | Removes the first element @a@ for which @a == b@ holds, where @b@ is
-- the target. The result is one element shorter.
--
-- Calls 'error' if no element matches.
deleteV :: Eq a => a -> Vec (S n) a -> Vec n a
deleteV b (a :< rest)
  | a == b    = rest
  | otherwise = case rest of
      ZVec   -> error "deleteV: did not find element"
      _ :< _ -> a :< deleteV b rest
-- | Converts a list to a vector. The list length must equal @n@ exactly.
elemsV :: IsNat n => [a] -> Vec n a
elemsV xs = elemsV' nat xs
-- | Worker for 'elemsV'. Calls 'error' with a message saying whether the
-- list was too long or too short.
elemsV' :: Nat n -> [a] -> Vec n a
elemsV' n xs = case n of
  Zero -> case xs of
    [] -> ZVec
    _  -> error "elemsV: too many elements"
  Succ k -> case xs of
    []       -> error "elemsV: too few elements"
    y : rest -> y :< elemsV' k rest
-- | Splits a vector of pairs into a pair of vectors.
--
-- Each pair is matched strictly as it is visited; the recursive result
-- is bound lazily.
unzipV :: Vec n (a,b) -> (Vec n a, Vec n b)
unzipV ZVec              = (ZVec, ZVec)
unzipV ((x, y) :< rest) =
  let (xs, ys) = unzipV rest
  in  (x :< xs, y :< ys)
-- | Types that can be converted into a @Vec n a@.
class ToVec c n a where
  toVec :: c -> Vec n a
-- | A vector converts to itself.
instance ToVec (Vec n a) n a where
  toVec v = v
-- | A list converts via 'toVecL'. Its length must match @n@.
instance IsNat n => ToVec [a] n a where
  toVec xs = toVecL nat xs
-- | Converts a list to a vector of exactly the witnessed length. Calls
-- 'error' if the list is too long or too short.
toVecL :: Nat n -> [a] -> Vec n a
toVecL n xs = case n of
  Zero -> case xs of
    [] -> ZVec
    _  -> error "toVecL: length mismatch"
  Succ m -> case xs of
    y : ys -> y :< toVecL m ys
    []     -> error "toVecL: length mismatch"
-- | Post-processes a function's result; this is function composition.
-- @result f g x@ is @f (g x)@.
result :: (b -> b') -> ((a -> b) -> (a -> b'))
result f g x = f (g x)