module TypeUnary.Vec
(
module TypeUnary.Nat
, Vec(..), unConsV, headV, tailV, joinV, (<+>), indices, iota
, Vec0,Vec1,Vec2,Vec3,Vec4,Vec5,Vec6,Vec7,Vec8,Vec9
, Vec10,Vec11,Vec12,Vec13,Vec14,Vec15,Vec16
, vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8
, un1, un2, un3, un4
, get, get0, get1, get2, get3
, update
, set, set0, set1, set2, set3
, getI, setI
, flattenV, chunkV, swizzle, split, deleteV, elemsV
, zipV , zipWithV , unzipV
, zipV3, zipWithV3, unzipV3
, transpose, cross
, ToVec(..)
) where
import Prelude hiding (foldr,sum,and)
import Data.Monoid (Monoid(..),(<>))
import Control.Applicative (Applicative(..),liftA2,(<$>))
import Data.Foldable (Foldable(..),toList,sum)
import Data.Traversable (Traversable(..))
import Data.Typeable (Typeable)
import Foreign.Storable
import Foreign.Ptr (Ptr,plusPtr,castPtr)
import Control.Newtype (Newtype(..))
import Data.VectorSpace
import TypeUnary.Nat
infixr 5 :<

-- | A vector whose length is tracked in its type. The first type
-- argument is a type-level natural built from Z and S; the second is
-- the element type.
data Vec :: * -> * -> * where
    -- The empty vector, of length Z.
    ZVec :: Vec Z elt
    -- Prepend one element, increasing the length index by one (S).
    (:<) :: elt -> Vec len elt -> Vec (S len) elt
  deriving Typeable
-- | Split a non-empty vector into its first element and the remaining
-- elements. Forces the vector to its outermost constructor.
unConsV :: Vec (S n) a -> (a, Vec n a)
unConsV v = case v of
  x :< xs -> (x, xs)
-- | The first element of a non-empty vector.
headV :: Vec (S n) a -> a
headV = fst . unConsV
-- | Everything but the first element of a non-empty vector.
tailV :: Vec (S n) a -> Vec n a
tailV = snd . unConsV
-- | Abort on a fallthrough pattern that the types make unreachable but
-- that the compiler cannot prove unreachable. The argument names the
-- function in which the impossible case was hit.
cant :: String -> a
cant what = error (what ++ ": GHC doesn't know this case can't happen.")
-- | Variant of cant whose message is tagged as coming from a Vec
-- operation (the suffix " on Vec" is appended to the name).
cantV :: String -> a
cantV = cant . (++ " on Vec")
-- | Element-wise equality: two vectors are equal when every pair of
-- corresponding elements is equal (checked from the front, stopping at
-- the first mismatch via (&&)).
instance Eq a => Eq (Vec n a) where
  u == v = case (u, v) of
    (ZVec, ZVec) -> True
    (x :< xs, y :< ys) -> x == y && xs == ys
    _ -> cantV "(==)"
-- | Lexicographic ordering: elements are compared front to back and the
-- first non-EQ comparison decides (via the Ordering monoid).
instance Ord a => Ord (Vec n a) where
  compare ZVec ZVec = EQ
  compare (x :< xs) (y :< ys) = mappend (compare x y) (compare xs ys)
  compare _ _ = cantV "compare"
-- | Show a vector as the expression that rebuilds it, for example
-- @elemsV [1,2,3]@. When the vector appears in argument position
-- (precedence above 10) the text is parenthesised, so nested values
-- render as valid expressions, e.g. @Just (elemsV [1,2])@.
instance Show a => Show (Vec n a) where
  showsPrec d v =
    showParen (d > 10) (showString "elemsV " . shows (toList v))
-- | Element-wise monoid: the identity holds mempty in every slot, and
-- mappend combines corresponding elements.
-- NOTE(review): base >= 4.11 requires a Semigroup superclass instance;
-- none is visible in this file.
instance (IsNat n, Monoid a) => Monoid (Vec n a) where
  mempty = pureV nat mempty
  mappend u v = mappend <$> u <*> v
-- | Apply a function to every element, preserving length and order.
instance Functor (Vec n) where
  fmap f v = case v of
    ZVec -> ZVec
    x :< xs -> f x :< fmap f xs
-- | Zip-like applicative: pure replicates its argument to length n and
-- (<*>) applies each function to the element at the same position.
-- The length is obtained from the IsNat dictionary via nat.
instance IsNat n => Applicative (Vec n) where
  pure x = pureV nat x
  fs <*> xs = applyV nat fs xs
-- | A vector of the given length with every slot holding the same value.
pureV :: Nat n -> a -> Vec n a
pureV len x = case len of
  Zero -> ZVec
  Succ len' -> x :< pureV len' x
-- | Lift a transformation on (head, tail) pairs to non-empty vectors.
-- The result is rebuilt lazily from the pair produced by f.
inVecS :: ((a, Vec n a) -> (b, Vec n b)) -> (Vec (S n) a -> Vec (S n) b)
inVecS f v = let (y, ys) = f (unConsV v) in y :< ys
-- | Binary version of inVecS: lift a function of two (head, tail) pairs
-- to a function of two non-empty vectors of the same length.
inVecS2 :: ((a, Vec n a) -> (b, Vec n b) -> (c, Vec n c))
        -> (Vec (S n) a -> Vec (S n) b -> Vec (S n) c)
inVecS2 f u = inVecS (f (unConsV u))
-- | Position-wise application: element i of the result is function i
-- applied to argument i. The length witness drives the recursion.
applyV :: Nat n -> Vec n (a -> b) -> Vec n a -> Vec n b
applyV len = case len of
  Zero -> \ _ _ -> ZVec
  Succ len' -> inVecS2 (\ (g, gs) (y, ys) -> (g y, applyV len' gs ys))
-- | Monad whose bind maps every element to a vector and then keeps the
-- diagonal (see joinV).
instance IsNat n => Monad (Vec n) where
  return x = pure x
  m >>= k = joinV (fmap k m)
-- | Diagonal of a square vector of vectors: element i of the result is
-- element i of row i. Each step takes the head of the first row and
-- drops the first column of the remaining rows.
joinV :: Vec n (Vec n a) -> Vec n a
joinV rows = case rows of
  ZVec -> ZVec
  (x :< _) :< rest -> x :< joinV (fmap tailV rest)
  _ -> cant "joinV"
-- | Folds visit elements front to back; foldMap combines the mapped
-- elements with (<>), ending in mempty.
instance Foldable (Vec n) where
  foldMap h v = case v of
    ZVec -> mempty
    x :< xs -> h x <> foldMap h xs
-- | Traverse elements front to back, sequencing effects in that order
-- and rebuilding a vector of the same length.
instance Traversable (Vec n) where
  traverse f v = case v of
    ZVec -> pure ZVec
    x :< xs -> liftA2 (:<) (f x) (traverse f xs)
-- | The empty vector is isomorphic to unit. Both directions force their
-- argument.
instance Newtype (Vec Z a) () where
  pack unit = case unit of () -> ZVec
  unpack v = case v of ZVec -> ()
-- | A non-empty vector is isomorphic to a (head, tail) pair. Packing is
-- lazy in the pair (irrefutable pattern), matching uncurry.
instance Newtype (Vec (S n) a) (a,Vec n a) where
  pack ~(x, xs) = x :< xs
  unpack v = unConsV v
-- | Element-wise additive group: zero everywhere, position-wise sum and
-- position-wise negation.
instance (IsNat n, Num a) => AdditiveGroup (Vec n a) where
  zeroV = pure 0
  u ^+^ v = liftA2 (+) u v
  negateV v = fmap negate v
-- | Scalars are one-element vectors; scaling multiplies every element on
-- the left by the single scalar value.
instance (IsNat n, Num a) => VectorSpace (Vec n a) where
  type Scalar (Vec n a) = Vec1 a
  (*^) s = case s of
    k :< ZVec -> fmap (k *)
    _ -> cantV "(*^)"
-- | Dot product: the sum of position-wise products, returned as a
-- one-element vector because Scalar is Vec1 here.
instance (IsNat n, Num a) => InnerSpace (Vec n a) where
  u <.> v = vec1 (sum (liftA2 (*) u v))
-- | Stores the n elements contiguously, each taking sizeOf of the element
-- type with no extra padding; alignment is that of a single element.
instance (IsNat n, Storable a) => Storable (Vec n a) where
  sizeOf _ = natToZ (nat :: Nat n) * sizeOf (undefined :: a)
  alignment _ = alignment (undefined :: a)
  peek p = peekV (castPtr p)
  poke p v = pokeV (castPtr p) v
infixl 1 <+>

-- | Append two vectors; the result length is the type-level sum.
(<+>) :: Vec m a -> Vec n a -> Vec (m :+: n) a
u <+> v = case u of
  ZVec -> v
  x :< xs -> x :< (xs <+> v)
-- | Read n consecutive elements starting at the pointer, where n comes
-- from the IsNat dictionary.
peekV :: (IsNat n, Storable a) => Ptr a -> IO (Vec n a)
peekV p = peekV' nat p
-- | Write the n elements consecutively starting at the pointer, where n
-- comes from the IsNat dictionary.
pokeV :: (IsNat n, Storable a) => Ptr a -> Vec n a -> IO ()
pokeV p v = pokeV' nat p v
-- | Read as many elements as the length witness says, front to back,
-- advancing the pointer by sizeOf of one element after each read.
peekV' :: Storable a => Nat n -> Ptr a -> IO (Vec n a)
peekV' len p = case len of
  Zero -> return ZVec
  Succ len' ->
    peek p >>= \ x ->
      fmap (x :<) (peekV' len' (plusPtr p (sizeOf x)))
-- | Write every element front to back, advancing the pointer by sizeOf
-- of one element after each write.
pokeV' :: Storable a => Nat n -> Ptr a -> Vec n a -> IO ()
pokeV' len p v = case (len, v) of
  (Zero, ZVec) -> return ()
  (Succ len', x :< xs) -> poke p x >> pokeV' len' (plusPtr p (sizeOf x)) xs
  _ -> cant "pokeV"
-- | Every index of a length-n vector, collected into a vector: entry i
-- is index0 advanced i times by succI (see indices'). The length is
-- taken from the IsNat dictionary via nat.
indices :: IsNat n => Vec n (Index n)
indices = indices' nat
-- | Worker for indices: index0 first, followed by the indices of the
-- shorter vector, each shifted up with succI.
indices' :: Nat n -> Vec n (Index n)
indices' len = case len of
  Zero -> ZVec
  Succ len' -> index0 :< (succI <$> indices' len')
-- | Each position index converted to a number via unIndex; presumably
-- 0, 1, .., n-1 (depends on unIndex, defined elsewhere; confirm).
iota :: (IsNat n, Num a, Enum a) => Vec n a
iota = fmap unIndex indices
-- Aliases fixing the length index of Vec to a specific type-level
-- natural (N0 .. N16). They are left unapplied in the element type, so
-- each alias is itself a type constructor expecting the element type.
type Vec0 = Vec N0
type Vec1 = Vec N1
type Vec2 = Vec N2
type Vec3 = Vec N3
type Vec4 = Vec N4
type Vec5 = Vec N5
type Vec6 = Vec N6
type Vec7 = Vec N7
type Vec8 = Vec N8
type Vec9 = Vec N9
type Vec10 = Vec N10
type Vec11 = Vec N11
type Vec12 = Vec N12
type Vec13 = Vec N13
type Vec14 = Vec N14
type Vec15 = Vec N15
type Vec16 = Vec N16
-- Fixed-arity constructors: vecK takes K elements, in order, and builds
-- a vector of exactly that length.

vec1 :: a -> Vec1 a
vec1 a = a :< ZVec

vec2 :: a -> a -> Vec2 a
vec2 a b = a :< b :< ZVec

vec3 :: a -> a -> a -> Vec3 a
vec3 a b c = a :< b :< c :< ZVec

vec4 :: a -> a -> a -> a -> Vec4 a
vec4 a b c d = a :< b :< c :< d :< ZVec

vec5 :: a -> a -> a -> a -> a -> Vec5 a
vec5 a b c d e = a :< b :< c :< d :< e :< ZVec

vec6 :: a -> a -> a -> a -> a -> a -> Vec6 a
vec6 a b c d e f = a :< b :< c :< d :< e :< f :< ZVec

vec7 :: a -> a -> a -> a -> a -> a -> a -> Vec7 a
vec7 a b c d e f g = a :< b :< c :< d :< e :< f :< g :< ZVec

vec8 :: a -> a -> a -> a -> a -> a -> a -> a -> Vec8 a
vec8 a b c d e f g h = a :< b :< c :< d :< e :< f :< g :< h :< ZVec
-- Fixed-arity destructors: unK turns a length-K vector into its single
-- element or a K-tuple. Each forces the whole spine of the vector.

un1 :: Vec1 a -> a
un1 v = case v of
  a :< ZVec -> a
  _ -> cant "un1"

un2 :: Vec2 a -> (a,a)
un2 v = case v of
  a :< b :< ZVec -> (a, b)
  _ -> cant "un2"

un3 :: Vec3 a -> (a,a,a)
un3 v = case v of
  a :< b :< c :< ZVec -> (a, b, c)
  _ -> cant "un3"

un4 :: Vec4 a -> (a,a,a,a)
un4 v = case v of
  a :< b :< c :< d :< ZVec -> (a, b, c, d)
  _ -> cant "un4"
-- | Element at a type-safe index. The index carries a proof that it is
-- below the length, so each step peels one element off the vector and
-- one layer off the index.
get :: Index n -> Vec n a -> a
get i v = case (i, v) of
  (Index ZLess Zero, x :< _) -> x
  (Index (SLess p) (Succ m), _ :< xs) -> get (Index p m) xs
  _ -> cant "get"
-- Fixed-position accessors; the types require the vector to be long
-- enough for the requested position.

get0 :: Vec (N1 :+: n) a -> a
get0 v = get index0 v

get1 :: Vec (N2 :+: n) a -> a
get1 v = get index1 v

get2 :: Vec (N3 :+: n) a -> a
get2 v = get index2 v

get3 :: Vec (N4 :+: n) a -> a
get3 v = get index3 v
-- | Apply a function to the element at a type-safe index, leaving every
-- other element unchanged.
update :: Index n -> (a -> a) -> Vec n a -> Vec n a
update i f v = case (i, v) of
  (Index ZLess Zero, x :< xs) -> f x :< xs
  (Index (SLess p) (Succ m), x :< xs) -> x :< update (Index p m) f xs
  _ -> cantV "update"
-- | Replace the element at a type-safe index with the given value.
set :: Index n -> a -> Vec n a -> Vec n a
set i x = update i (\ _ -> x)
-- Fixed-position setters; the types require the vector to be long
-- enough for the requested position.

set0 :: a -> Vec (N1 :+: n) a -> Vec (N1 :+: n) a
set0 x v = set index0 x v

set1 :: a -> Vec (N2 :+: n) a -> Vec (N2 :+: n) a
set1 x v = set index1 x v

set2 :: a -> Vec (N3 :+: n) a -> Vec (N3 :+: n) a
set2 x v = set index2 x v

set3 :: a -> Vec (N4 :+: n) a -> Vec (N4 :+: n) a
set3 x v = set index3 x v
-- | Element at an integral position, converted to an Index with
-- coerceToIndex (whose handling of out-of-range values is defined
-- elsewhere).
getI :: (IsNat n, Show i, Integral i) => i -> Vec n a -> a
getI i = get (coerceToIndex i)
-- | Replace the element at an integral position, converted to an Index
-- with coerceToIndex.
setI :: (IsNat n, Show i, Integral i) => i -> a -> Vec n a -> Vec n a
setI i = set (coerceToIndex i)
-- | Concatenate n inner vectors of length m into one of length n*m,
-- with n taken from the IsNat dictionary.
flattenV :: IsNat n => Vec n (Vec m a) -> Vec (n :*: m) a
flattenV rows = flattenV' nat rows
-- | Worker for flattenV: append each inner vector in order, recursing on
-- the length witness.
flattenV' :: Nat n -> Vec n (Vec m a) -> Vec (n :*: m) a
flattenV' Zero _ = ZVec
flattenV' (Succ n') (v :< vs') = v <+> flattenV' n' vs'
-- Unreachable; reported through the shared helper like every other
-- impossible case in this module (message text is unchanged).
flattenV' _ _ = cant "flattenV'"
-- | Split a vector of length n*m into n consecutive chunks of length m.
chunkV :: (IsNat n, IsNat m) => Vec (n :*: m) a -> Vec n (Vec m a)
chunkV xs = chunkV' nat xs
-- | Worker for chunkV: repeatedly split off a length-m prefix, once per
-- step of the length witness.
chunkV' :: IsNat m => Nat n -> Vec (n :*: m) a -> Vec n (Vec m a)
chunkV' len xs = case (len, xs) of
  (Zero, ZVec) -> ZVec
  (Succ len', _) ->
    let (row, rest) = split xs
    in  row :< chunkV' len' rest
  _ -> cant "chunkV"
-- | Select elements by a vector of indices: result entry i is the element
-- of v at index i of the index vector (repeats and reordering allowed).
swizzle :: Vec n (Index m) -> Vec m a -> Vec n a
swizzle is v = fmap (\ i -> get i v) is
-- | Split a vector into its first n elements and the remaining m, with n
-- taken from the IsNat dictionary.
split :: IsNat n => Vec (n :+: m) a -> (Vec n a, Vec m a)
split v = split' nat v
-- | Worker for split. The pair is built lazily: the recursive split of
-- the tail is only performed when one of its components is demanded.
split' :: Nat n -> Vec (n :+: m) a -> (Vec n a, Vec m a)
split' len v = case len of
  Zero -> (ZVec, v)
  Succ len' -> case v of
    x :< xs ->
      let (front, back) = split' len' xs
      in  (x :< front, back)
    _ -> cantV "split"
-- Disabled code: the preprocessor conditional below is always false, so
-- nothing in this region is compiled. It sketches Dict-based proofs of
-- equalities involving :+: (addZ, add1, addS) and a reverseV built on
-- top of them.
-- NOTE(review): one line below consists only of "..." (content appears
-- to be elided), so the region would not compile if simply re-enabled.
#if 0
addZ :: IsNat n => Dict (n ~ (n :+: Z))
addZ = addZ' nat
addZ' :: Nat n -> Dict (n ~ (n :+: Z))
addZ' Zero = Dict
addZ' (Succ m) | Dict <- addZ' m = Dict
add1 :: IsNat m => Dict ((m :+: S Z) ~ S m)
add1 = add1' nat
add1' :: Nat m -> Dict ((m :+: S Z) ~ S m)
add1' Zero = Dict
add1' (Succ m) | Dict <- add1' m = Dict
addS' :: IsNat m => Nat n -> Dict ((m :+: S n) ~ S (m :+: n))
addS' Zero | Dict <- add1 = Dict
...
reverseV :: forall n a. IsNat n => Vec n a -> Vec n a
reverseV | Dict <- (addZ :: Dict (n ~ (n :+: Z))) = reverse' nat ZVec
reverse' :: Nat n -> Vec m a -> Vec n a -> Vec (n :+: m) a
reverse' Zero ma ZVec = ma
reverse' (Succ n) ma (a :< as) = reverse' n (a :< ma) as
#endif
-- | Remove the first element (scanning from the front) for which
-- element == target holds. Calls error if no element matches; elements
-- before the match are kept in order.
deleteV :: Eq a => a -> Vec (S n) a -> Vec n a
deleteV target v = case v of
  x :< rest
    | x == target -> rest
    | otherwise ->
        case rest of
          ZVec -> error "deleteV: did not find element"
          _ :< _ -> x :< deleteV target rest
-- | Build a vector from a list that must have exactly n elements (see
-- elemsV' for the error cases).
elemsV :: IsNat n => [a] -> Vec n a
elemsV xs = elemsV' nat xs
-- | Worker for elemsV: consume one list element per step of the length
-- witness, calling error if the list is longer or shorter than that.
elemsV' :: Nat n -> [a] -> Vec n a
elemsV' len xs = case (len, xs) of
  (Zero, []) -> ZVec
  (Zero, _ : _) -> error "elemsV: too many elements"
  (Succ _, []) -> error "elemsV: too few elements"
  (Succ len', y : ys) -> y :< elemsV' len' ys
-- | Pair up corresponding elements of two equal-length vectors.
zipV :: Vec n a -> Vec n b -> Vec n (a,b)
zipV xs ys = zipWithV (,) xs ys
-- | Triple up corresponding elements of three equal-length vectors.
zipV3 :: Vec n a -> Vec n b -> Vec n c -> Vec n (a,b,c)
zipV3 xs ys zs = zipWithV3 (,,) xs ys zs
-- | Combine corresponding elements of two equal-length vectors.
zipWithV :: (a -> b -> c) -> Vec n a -> Vec n b -> Vec n c
zipWithV f u v = case (u, v) of
  (ZVec, ZVec) -> ZVec
  (x :< xs, y :< ys) -> f x y :< zipWithV f xs ys
  _ -> cant "zipWithV"
-- | Combine corresponding elements of three equal-length vectors.
zipWithV3 :: (a -> b -> c -> d) -> Vec n a -> Vec n b -> Vec n c -> Vec n d
zipWithV3 f u v w = case (u, v, w) of
  (ZVec, ZVec, ZVec) -> ZVec
  (x :< xs, y :< ys, z :< zs) -> f x y z :< zipWithV3 f xs ys zs
  _ -> cant "zipWithV3"
-- | Split a vector of pairs into a pair of vectors. Each pair element is
-- forced as the spine is walked; the recursive result is bound lazily.
unzipV :: Vec n (a,b) -> (Vec n a, Vec n b)
unzipV v = case v of
  ZVec -> (ZVec, ZVec)
  (x, y) :< rest ->
    let (xs, ys) = unzipV rest
    in  (x :< xs, y :< ys)
-- | Split a vector of triples into a triple of vectors. Each triple is
-- forced as the spine is walked; the recursive result is bound lazily.
unzipV3 :: Vec n (a,b,c) -> (Vec n a, Vec n b, Vec n c)
unzipV3 v = case v of
  ZVec -> (ZVec, ZVec, ZVec)
  (x, y, z) :< rest ->
    let (xs, ys, zs) = unzipV3 rest
    in  (x :< xs, y :< ys, z :< zs)
-- | Outer product: entry (i, j) pairs element i of the first vector with
-- element j of the second.
cross :: Vec m a -> Vec n b -> Vec m (Vec n (a,b))
cross xs ys = fmap (\ x -> fmap (\ y -> (x, y)) ys) xs
-- | Transpose a nested vector by traversing the outer vector in the
-- position-wise Applicative of the inner Vec n: row i of the result
-- collects element i of every input row.
transpose :: IsNat n => Vec m (Vec n a) -> Vec n (Vec m a)
transpose rows = traverse id rows
-- | Conversion from some container type to a vector of length n with
-- element type a.
class ToVec src n a where
  toVec :: src -> Vec n a
-- | A vector already is a vector: conversion is the identity.
instance ToVec (Vec n a) n a where
  toVec v = v
-- | Lists convert when their length matches n exactly (see toVecL).
instance IsNat n => ToVec [a] n a where
  toVec xs = toVecL nat xs
-- | Convert a list to a vector of exactly the witnessed length; calls
-- error on any length mismatch.
toVecL :: Nat n -> [a] -> Vec n a
toVecL len xs = case (len, xs) of
  (Zero, []) -> ZVec
  (Succ len', y : ys) -> y :< toVecL len' ys
  _ -> error "toVecL: length mismatch"
-- | Post-compose a function onto the result of another function.
result :: (b -> b') -> ((a -> b) -> (a -> b'))
result f g = \ x -> f (g x)
#define INSTANCE_Enum
#define CONSTRAINTS IsNat n,
#define APPLICATIVE Vec n
#include "ApplicativeNumeric-inc.hs"