module NumHask.Array
( Array(..)
, SomeArray(..)
, row
, col
, unsafeRow
, unsafeCol
, slice
, unsafeSlice
, index
, unsafeIndex
, foldAlong
, mapAlong
, concatenate
, zipWith
, transpose
, squeeze
, (><)
, mmult
, fromList
) where
import Data.Distributive
import Data.Functor.Rep
import Data.Promotion.Prelude
import Data.Singletons
import Data.Singletons.Prelude
import Data.Singletons.TypeLits
import GHC.Exts
import GHC.Show
import GHC.Generics (Generic1)
import NumHask.Array.Constraints
import NumHask.Prelude hiding (All, Map, (><), mmult, show, row, col, zipWith, transpose)
import qualified Data.Vector as V
import qualified NumHask.Prelude as P
import Data.Kind
-- | A multi-dimensional array whose shape is the type-level list @r@.
-- The elements live in a single flat boxed vector; the visible indexing
-- code ('ind' / 'unind') treats it as row-major.
newtype Array (r :: [Nat]) a =
  Array (V.Vector a)
  deriving (Functor, Eq, Foldable, Generic, Generic1, NFData)
-- | An array whose shape is carried at the value level: the dimension
-- sizes as a list of 'Int's, plus the flat element vector.
data SomeArray a = SomeArray [Int] (V.Vector a)
  deriving (Functor, Eq, Foldable)
-- | Forget the type-level shape: record it as a value-level dimension list.
-- The element vector is shared, not copied.
someArray :: (SingI r) => Array (r :: [Nat]) a -> SomeArray a
someArray arr =
  case arr of
    Array elems -> SomeArray (shape arr) elems
-- | The shape of an 'Array' is read from its type-level dimension list,
-- converted to 'Int's. The array value itself is never inspected.
instance forall (r :: [Nat]). (SingI r) => HasShape (Array r) where
  type Shape (Array r) = [Int]
  shape _ = fromIntegral <$> fromSing (sing :: Sing r)
-- | Row-major flat offset of the multi-index @xs@ within dimensions @ns@.
-- This is the sum of @x_k * stride_k@, where @stride_k@ is the product of
-- all dimension sizes after axis @k@. Extra elements on either list are
-- ignored, following zip semantics.
ind :: [Int] -> [Int] -> Int
ind ns xs = sum (P.zipWith (*) xs strides)
  where
    -- [d1*d2*..., d2*..., ..., 1]: the leading total size is dropped
    strides = drop 1 (scanr (*) 1 ns)
-- | Convert a row-major flat offset back into a multi-index for the
-- dimensions @ns@. The last axis is peeled off first with 'divMod'. Any
-- quotient left over after the first axis is discarded.
unind :: [Int] -> Int -> [Int]
unind ns x = digits
  where
    (digits, _) = foldr step ([], x) ns
    step dim (acc, remaining) =
      let (carry, digit) = remaining `divMod` dim
       in (digit : acc, carry)
-- | Distribute a functor of arrays into an array of functors. Result
-- element @i@ is the functor obtained by reading flat element @i@ of every
-- inner array. That read uses 'V.unsafeIndex', so there is no bounds check.
-- For a scalar shape (@'[]@) the total size is 1.
instance forall r. (SingI r) => Distributive (Array r) where
  distribute fa = Array (V.generate total column)
    where
      column i = fmap (\(Array v) -> V.unsafeIndex v i) fa
      -- product of an empty dimension list is 1, matching the scalar case
      total = product (fromInteger <$> fromSing (sing :: Sing r))
-- | Arrays are represented by their multi-index (a list of 'Int's).
-- 'tabulate' fills every row-major position by decoding its offset with
-- 'unind'. 'index' encodes the multi-index with 'ind' and reads the vector
-- with the bounds-checked 'V.!'.
instance forall (r :: [Nat]). (SingI r) => Representable (Array r) where
  type Rep (Array r) = [Int]
  tabulate f = Array (V.generate (product dims) (f . unind dims))
    where
      dims = fromIntegral <$> fromSing (sing :: Sing r)
  index (Array elems) idx = elems V.! ind dims idx
    where
      dims = fromIntegral <$> fromSing (sing :: Sing r)
-- | List conversion in row-major order. 'fromList' always produces exactly
-- as many elements as the shape requires. Surplus input elements are
-- dropped, and a short input is padded with @0@.
instance (SingI r, Num a) => IsList (Array (r :: [Nat]) a) where
  type Item (Array r a) = a
  fromList xs = Array (V.fromListN len (xs ++ repeat 0))
    where
      -- product of an empty dimension list is 1 (a scalar holds one element)
      len = product (fromIntegral <$> fromSing (sing :: Sing r))
  toList arr =
    case arr of
      Array elems -> V.toList elems
-- | Render an array as nested bracketed lists.
--
-- A rank-0 array shows its single element. A rank-1 array shows as
-- @[a, b, ...]@. Higher ranks are split along the first axis with
-- 'flatten1' and rendered recursively. Sub-arrays are separated by @",\n"@
-- and indented by @totalRank - currentRank + 1@ spaces, so each one lines
-- up under its opening bracket.
instance (Show a) => Show (SomeArray a) where
  show r@(SomeArray l _) = go (length l) r
    where
      -- n is the rank of the outermost array and stays fixed through recursion
      go n r'@(SomeArray l' v') =
        case length l' of
          0 -> show $ V.head v'
          1 -> "[" ++ intercalate ", " (show <$> GHC.Exts.toList v') ++ "]"
          x ->
            "[" ++
            intercalate
              -- the indent grows by one for each bracket already opened
              (",\n" ++ replicate (n - x + 1) ' ')
              (go n <$> flatten1 r') ++
            "]"
-- | Split an array along its first axis into sub-arrays of rank one less.
-- Each piece is a zero-copy 'V.unsafeSlice' of the flat vector. A rank-0
-- input yields no pieces.
flatten1 :: SomeArray a -> [SomeArray a]
flatten1 (SomeArray dims v) =
  [ SomeArray (drop 1 dims) (V.unsafeSlice (k * stride) stride v)
  | k <- [0 .. count - 1]
  ]
  where
    -- (number of pieces, elements per piece)
    (count, stride) =
      case dims of
        [] -> (0, 1)
        d : rest -> (d, product rest)
-- | Show a typed array by converting it to 'SomeArray' and using that
-- instance's nested-bracket rendering.
instance (Show a, SingI r) => Show (Array (r :: [Nat]) a) where
  show arr = show (someArray arr)
-- | Outer (tensor) product. The result has shape @r ++ s@. The element at
-- multi-index @i@ is @m@ at the first @rank m@ coordinates of @i@,
-- multiplied by @n@ at the remaining coordinates.
(><) ::
     forall (r :: [Nat]) (s :: [Nat]) a.
     (CRing a, SingI r, SingI s, SingI (r :++ s))
  => Array r a
  -> Array s a
  -> Array (r :++ s) a
m >< n = tabulate cell
  where
    rankM = length (shape m)
    cell ix = index m (take rankM ix) * index n (drop rankM ix)
-- | Matrix multiplication. Entry @(i, j)@ of the @m x n@ result is the
-- inner product ('<.>') of row @i@ of @x@ with column @j@ of @y@.
mmult ::
     forall m n k a.
     (Semiring a, Num a, CRing a, KnownNat m, KnownNat n, KnownNat k)
  => Array '[ m, k] a
  -> Array '[ k, n] a
  -> Array '[ m, n] a
mmult x y = tabulate entry
  where
    -- tabulate over a rank-2 shape always supplies exactly two coordinates
    entry [i, j] = unsafeRow i x <.> unsafeCol j y
-- | Row extraction with a type-level index. The constraint @i :< m@
-- rejects out-of-range rows at compile time.
row ::
     forall i a m n. (KnownNat m, KnownNat n, KnownNat i, (i :< m) ~ 'True)
  => Proxy i
  -> Array '[ m, n] a
  -> Array '[ n] a
row p = unsafeRow (fromIntegral (fromSing (singByProxy p)))
-- | Row @i@ of a matrix, as a zero-copy slice of the row-major buffer.
-- The index is not checked: 'V.unsafeSlice' performs no bounds check.
unsafeRow ::
     forall a m n. (KnownNat m, KnownNat n)
  => Int
  -> Array '[ m, n] a
  -> Array '[ n] a
unsafeRow i t@(Array a) = Array (V.unsafeSlice start cols a)
  where
    start = i * cols
    [_, cols] = shape t
-- | Column extraction with a type-level index. The constraint @j :< n@
-- rejects out-of-range columns at compile time.
col ::
     forall j a m n. (KnownNat m, KnownNat n, KnownNat j, (j :< n) ~ 'True)
  => Proxy j
  -> Array '[ m, n] a
  -> Array '[ m] a
col p = unsafeCol (fromIntegral (fromSing (singByProxy p)))
-- | Column @j@ of a matrix, gathered into a new vector by reading every
-- @cols@-th element starting at @j@. Reads use the bounds-checked 'V.!',
-- but @j@ itself is not validated against the column count.
unsafeCol ::
     forall a m n. (KnownNat m, KnownNat n)
  => Int
  -> Array '[ m, n] a
  -> Array '[ m] a
unsafeCol j t@(Array a) = Array (V.generate rows pick)
  where
    pick r = a V.! (r * cols + j)
    [rows, cols] = shape t
-- | Look up an element by a value-level multi-index. The index is not
-- checked against the shape. The final vector read uses 'V.!', so an
-- offset beyond the buffer raises an error.
unsafeIndex :: SingI r => Array r a -> [Int] -> a
unsafeIndex t idx =
  case t of
    Array v -> v V.! ind (shape t) idx
-- | Select the Cartesian product of per-axis index lists. @s !! k@ lists
-- the indices kept along axis @k@, and the picks are enumerated with
-- 'sequence'. The output shape @r0@ is unconstrained, so the caller must
-- make it agree with the number of selected elements.
unsafeSlice :: (SingI r) => [[Int]] -> Array r a -> Array r0 a
unsafeSlice s t = Array . V.fromList $ unsafeIndex t <$> sequence s
-- | Result shape of a slice. For each axis it is the number of indices
-- selected along that axis, i.e. the 'Length' of each inner index list.
type family Slice (xss :: [[Nat]]) :: [Nat] where
  Slice xss = Map LengthSym0 xss
-- | Defunctionalisation symbol (singletons encoding) for a two-argument
-- type-level predicate over an index list @l@ and a dimension @n@.
-- 'slice' applies it with @ZipWith AllLTSym0 s r@, pairing each axis's
-- index list with that axis's size.
data AllLTSym0 (a :: TyFun [Nat] (TyFun Nat Bool -> Type))
-- | 'AllLTSym0' partially applied to the index list @l@.
data AllLTSym1 (l :: [Nat]) (a :: TyFun Nat Bool)
type instance Apply AllLTSym0 l = AllLTSym1 l
-- Fully applied: @All ((:>$$) n) l@. This presumably reads as "every
-- element x of l satisfies n :> x", i.e. each index is below the axis
-- size. NOTE(review): confirm that (:>$$) is partial application of (:>)
-- in the singletons version in use.
type instance Apply (AllLTSym1 l) n = All ((:>$$) n) l
-- | Type-checked slice. @s@ gives, for each axis of @r@, the list of
-- indices to keep. The constraint checks the indices against the axis
-- sizes, and the result shape is @Slice s@.
slice ::
     forall s r a. (SingI s, SingI r, And (ZipWith AllLTSym0 s r) ~ 'True)
  => Proxy s
  -> Array r a
  -> Array (Slice s) a
slice p = unsafeSlice picks
  where
    picks = fmap (fmap fromInteger) (fromSing (singByProxy p))
-- | Cut a vector into consecutive pieces of length @i@; the last piece may
-- be shorter. The pieces are consed onto @acc@, so they come back in
-- REVERSE order. Callers reverse the result to restore the original order.
chunkItUp :: [V.Vector a] -> Int -> V.Vector a -> [V.Vector a]
chunkItUp acc i v
  | null v = acc
  | otherwise = chunkItUp (piece : acc) i rest
  where
    (piece, rest) = V.splitAt i v
-- | Combine two arrays of the same shape element by element.
--
-- The two element types and the result type may all differ. Every call at
-- the former @(a -> a -> a)@ type still typechecks.
zipWith :: (a -> b -> c) -> Array s a -> Array s b -> Array s c
zipWith fn (Array a) (Array b) = Array $ V.zipWith fn a b
-- | Apply a reducing function to every sub-array that starts at axis @s@.
-- The flat buffer is cut into contiguous chunks of
-- @product (drop s (shape a))@ elements. @f@ is applied to each chunk, and
-- the results are concatenated in the original chunk order.
foldAlong ::
     forall s vw uvw uw w a.
     ( SingI s
     , SingI uvw
     , uw ~ (Fold s uvw)
     , w ~ (Drop 1 vw)
     , vw ~ (TailModule s uvw)
     )
  => Proxy s
  -> (Array vw a -> Array w a)
  -> Array uvw a
  -> Array uw a
foldAlong s_ f a@(Array v) =
  Array (V.concat (unwrap . f . Array <$> reverse chunks))
  where
    unwrap (Array x) = x
    axis = (fromInteger . fromSing . singByProxy) s_
    -- chunkItUp yields chunks last-first, hence the reverse above
    chunks = chunkItUp [] (product $ drop axis $ shape a) v
-- | Apply a shape-preserving function to every sub-array that starts at
-- axis @s@. The flat buffer is cut into chunks of
-- @product (drop s (shape a))@ elements. @f@ is applied to each chunk, and
-- the results are concatenated back in the original order.
mapAlong ::
     forall s uvw vw a. (SingI s, SingI uvw, vw ~ (HeadModule s uvw))
  => Proxy s
  -> (Array vw a -> Array vw a)
  -> Array uvw a
  -> Array uvw a
mapAlong s_ f a@(Array v) =
  Array (V.concat (unwrap . f . Array <$> reverse chunks))
  where
    unwrap (Array x) = x
    axis = (fromInteger . fromSing . singByProxy) s_
    -- chunkItUp yields chunks last-first, hence the reverse above
    chunks = chunkItUp [] (product $ drop axis $ shape a) v
-- | Concatenate two arrays along axis @s@ (0-based).
--
-- Each array is cut into contiguous row-major chunks that cover the axes
-- from @s@ onwards. There is one chunk per combination of the leading @s@
-- indices, so both arrays yield the same number of chunks. The chunks are
-- interleaved as @r0, t0, r1, t1, ...@, which puts the elements of @r@
-- before those of @t@ along axis @s@. For example, concatenating
-- @[1,2]@ and @[3,4]@ along axis 0 gives @[1,2,3,4]@.
concatenate ::
     forall s r t a. (SingI s, SingI r, SingI t, (IsValidConcat s t r) ~ 'True)
  => Proxy s
  -> Array r a
  -> Array t a
  -> Array (Concatenate s t r) a
concatenate s_ r@(Array vr) t@(Array vt) =
  Array . V.concat $ (concat . reverse . P.transpose) [rChunks, tChunks]
  where
    s = (fromInteger . fromSing . singByProxy) s_
    -- Each chunk list is built from its own array. chunkItUp returns chunks
    -- last-first, so P.transpose pairs them up as [r_i, t_i] and the outer
    -- reverse restores the original order.
    rChunks = chunkItUp [] (product $ drop s $ shape r) vr
    tChunks = chunkItUp [] (product $ drop s $ shape t) vt
-- | Retype an array from shape @s@ to shape @Transpose s@.
--
-- The underlying vector is passed through unchanged. No element is moved;
-- only the type-level shape is relabelled.
-- NOTE(review): for any shape with two or more non-unit dimensions, a real
-- transpose must permute the elements. If 'Transpose' reverses the
-- dimension list (confirm its definition), that would be e.g.
-- @tabulate (index a . reverse)@. As written, this reinterprets the
-- row-major buffer under the new shape, which looks like a bug. Fixing it
-- needs @SingI s@ / @SingI t@ constraints and therefore a signature change.
transpose ::
     forall s t a. (t ~ Transpose s)
  => Array s a
  -> Array t a
transpose (Array x) = Array x
-- | Retype an array from shape @s@ to shape @Squeeze s@, reusing the flat
-- element vector as-is. This is only a relabelling. Presumably 'Squeeze'
-- removes size-1 axes, which leaves the row-major order unchanged.
squeeze ::
     forall s t a. (t ~ Squeeze s)
  => Array s a
  -> Array t a
squeeze arr =
  case arr of
    Array elems -> Array elems
-- | Element-wise addition, lifted pointwise through 'Representable'.
instance (SingI r, AdditiveMagma a) => AdditiveMagma (Array r a) where
  plus x y = liftR2 plus x y

-- | The additive identity is the array filled with 'zero'.
instance (SingI r, AdditiveUnital a) => AdditiveUnital (Array r a) where
  zero = tabulate (const zero)

instance (SingI r, AdditiveAssociative a) => AdditiveAssociative (Array r a)

instance (SingI r, AdditiveCommutative a) => AdditiveCommutative (Array r a)

-- | Element-wise negation.
instance (SingI r, AdditiveInvertible a) => AdditiveInvertible (Array r a) where
  negate x = fmapRep negate x

instance (SingI r, Additive a) => Additive (Array r a)

instance (SingI r, AdditiveGroup a) => AdditiveGroup (Array r a)
-- | Element-wise (Hadamard) multiplication, lifted pointwise.
instance (SingI r, MultiplicativeMagma a) => MultiplicativeMagma (Array r a) where
  times x y = liftR2 times x y

-- | The multiplicative identity is the array filled with 'one'.
instance (SingI r, MultiplicativeUnital a) => MultiplicativeUnital (Array r a) where
  one = tabulate (const one)

instance (SingI r, MultiplicativeAssociative a) => MultiplicativeAssociative (Array r a)

instance (SingI r, MultiplicativeCommutative a) => MultiplicativeCommutative (Array r a)

-- | Element-wise reciprocal.
instance (SingI r, MultiplicativeInvertible a) => MultiplicativeInvertible (Array r a) where
  recip x = fmapRep recip x

instance (SingI r, Multiplicative a) => Multiplicative (Array r a)

instance (SingI r, MultiplicativeGroup a) => MultiplicativeGroup (Array r a)
-- | Ring-like structure classes. None of them adds methods here; each
-- relies on the element-wise instances above.
instance (Additive a, MultiplicativeMagma a, SingI r) => Distribution (Array r a)

instance (Semiring a, SingI r) => Semiring (Array r a)

instance (Ring a, SingI r) => Ring (Array r a)

instance (CRing a, SingI r) => CRing (Array r a)

instance (Field a, SingI r) => Field (Array r a)
-- | Element-wise 'exp' and 'log'.
instance (SingI r, ExpField a) => ExpField (Array r a) where
  exp x = fmapRep exp x
  log x = fmapRep log x
-- | An array is NaN when any of its elements is NaN.
instance (SingI r, BoundedField a) => BoundedField (Array r a) where
  isNaN = or . fmapRep isNaN
-- | Element-wise 'sign' and 'abs'.
instance (SingI r, Signed a) => Signed (Array r a) where
  sign x = fmapRep sign x
  abs x = fmapRep abs x
-- | Euclidean (L2) norm: the square root of the sum of squared elements.
--
-- Each element is squared with @x * x@ rather than @x ** 2@. A
-- power-based square depends on '**' handling negative bases, and an
-- @exp (log x * y)@-style power does not. Plain multiplication is exact
-- for every sign and is also cheaper.
instance (ExpField a) => Normed (Array r a) a where
  size r = sqrt (foldr (+) zero (square <$> r))
    where
      square x = x * x
-- | An array is near zero when every element is, and two arrays are about
-- equal when every pair of corresponding elements is.
instance (SingI r, Epsilon a) => Epsilon (Array r a) where
  nearZero = and . fmapRep nearZero
  aboutEqual x y = and $ liftR2 aboutEqual x y
-- | Euclidean distance: the L2 norm ('size') of the element-wise
-- difference. The original @size (a b)@ applied an array as if it were a
-- function, which does not typecheck. Subtraction comes from the
-- element-wise 'AdditiveGroup' instance, which is available because
-- 'ExpField' implies a ring.
instance (SingI r, ExpField a) => Metric (Array r a) a where
  distance a b = size (a - b)
-- | Element-wise 'divMod'. Each element pair is divided once, and the
-- resulting array of pairs is split into a quotient array and a remainder
-- array.
instance (SingI r, Integral a) => Integral (Array r a) where
  divMod a b = (fst <$> qr, snd <$> qr)
    where
      qr = liftR2 divMod a b
-- | Inner product: the sum of the element-wise products.
instance (CRing a, Num a, Semiring a, SingI r) => Hilbert (Array r) a where
  (<.>) x y = sum (liftR2 (*) x y)