{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module NumHask.Array where
import Data.Distributive (Distributive(..))
import Data.Functor.Rep (Representable(..), liftR2, pureRep, fmapRep)
import Data.List ((!!))
import GHC.Exts (IsList(..))
import GHC.Show (Show(..))
import NumHask.Error (impossible)
import NumHask.Array.Constraints
(Fold, HeadModule, TailModule, IsValidConcat, Concatenate, Transpose, Squeeze)
import NumHask.Prelude as P
import NumHask.Shape (HasShape(..))
import Numeric.Dimensions as D
import qualified Data.Singletons.Prelude as S
import qualified Data.Vector as V
-- | An array whose shape is carried at the type level.
--
-- @c@ is the flat element store (see 'Container'), @ds@ is the list of
-- dimensions and @a@ is the element type.
data family Array (c :: Type -> Type) (ds :: [k]) (a :: Type)

-- | Instance for @[Nat]@ shapes: a thin wrapper around a flat store.
--
-- NOTE(review): nothing here checks that the store length equals the
-- product of @ds@; the rest of the module treats the store as row-major
-- (see 'ind' / 'unind').
newtype instance
  Array c (ds :: [Nat]) t =
    Array { _getContainer :: c t}
  deriving (Functor, Foldable)
-- | Shallow 'NFData': 'seq' forces the wrapped store to weak head normal
-- form only. Elements inside the store are left unevaluated, so this is
-- weaker than the deep evaluation 'NFData' usually promises.
instance NFData (Array c ds t) where
  rnf arr = arr `seq` ()
-- | The runtime shape is read from the type-level dimension list; the
-- argument itself is never inspected.
instance (Dimensions r) => HasShape (Array c (r :: [Nat])) where
  type Shape (Array c r) = [Int]
  shape _ = fromIntegral <$> listDims (dims @Nat @r)
-- | An array whose shape is only known at runtime: the list of dimensions
-- paired with the flat element store. Used here to implement 'Show'.
newtype AnyArray c a = AnyArray ([Int], c a)
-- | Forget the type-level shape, keeping it as a runtime list of
-- dimensions alongside the unchanged store.
anyArray :: (Dimensions ds) => Array c (ds :: [Nat]) a -> AnyArray c a
anyArray arr = AnyArray (shape arr, _getContainer arr)
-- | Flat, one-dimensional element store backing an 'Array'.
--
-- Multi-dimensional indices are mapped onto flat offsets elsewhere in this
-- module (see 'ind' and 'unind'); every method here works on offsets.
class (Functor f) => Container f where
  -- | @generate n g@ builds @n@ elements, the one at offset @i@ being @g i@.
  generate :: Int -> (Int -> a) -> f a
  -- | Element at a flat offset. Out-of-range offsets are not handled
  -- gracefully by the instances in this module.
  idx :: f a -> Int -> a
  -- | @cslice start len store@: the @len@ elements starting at @start@.
  cslice :: Int -> Int -> f a -> f a
  -- | Join a list of stores, keeping list order.
  cconcat :: [f a] -> f a
  -- | Combine two stores element by element.
  zipWith :: (a -> a -> a) -> f a -> f a -> f a
  -- | Left fold (strict in the instances of this module).
  cfoldl' :: (b -> a -> b) -> b -> f a -> b
  -- | Right fold.
  cfoldr :: (a -> b -> b) -> b -> f a -> b
  -- | @chunkItUp acc n store@ cuts @store@ into consecutive pieces of @n@
  -- elements and pushes each onto the front of @acc@.
  --
  -- The pieces therefore come back in reverse order. With @n <= 0@ and a
  -- non-empty store, the instances in this module never terminate.
  chunkItUp :: [f a] -> Int -> f a -> [f a]
-- | Boxed-vector store. Indexing and slicing use the unchecked
-- @V.unsafe*@ primitives, so callers must pass valid offsets.
instance Container V.Vector where
  generate n g = V.generate n g
  idx v i = V.unsafeIndex v i
  cslice start len v = V.unsafeSlice start len v
  zipWith f xs ys = V.zipWith f xs ys
  chunkItUp acc n v
    | null v = acc
    | otherwise = chunkItUp (piece : acc) n rest
    where
      (piece, rest) = V.splitAt n v
  cfoldl' f z v = V.foldl' f z v
  cfoldr f z v = V.foldr f z v
  cconcat vs = V.concat vs
-- | List store. 'idx' uses '(!!)', so access is linear in the offset and
-- fails on out-of-range offsets.
instance Container [] where
  generate n g = fmap g (take n [0 ..])
  idx xs i = xs !! i
  cslice start len xs = take len (drop start xs)
  zipWith = P.zipWith
  chunkItUp acc n xs
    | null xs = acc
    | otherwise = chunkItUp (piece : acc) n rest
    where
      (piece, rest) = splitAt n xs
  cfoldl' = foldl'
  cfoldr = foldr
  cconcat = concat
-- | Two arrays of the same type are equal when their stores are equal.
instance (Eq (c t), Dimensions ds) => Eq (Array c (ds :: [Nat]) t) where
  x == y = _getContainer x == _getContainer y
-- | Row-major flat offset of a multi-dimensional index.
--
-- @ind dims ix@ is the dot product of @ix@ with the suffix products of
-- @dims@ (the last dimension varies fastest). Extra entries on either side
-- are ignored, because the pairing truncates.
ind :: [Int] -> [Int] -> Int
ind ns xs = sum (P.zipWith (*) xs strides)
  where
    strides = drop 1 (scanr (*) 1 ns)
-- | Inverse of 'ind': split a flat row-major offset into per-dimension
-- indices.
--
-- Dimensions are peeled off from the right with 'divMod'. Any quotient
-- left over after the first dimension is dropped, so offsets past the end
-- wrap around.
unind :: [Int] -> Int -> [Int]
unind ns x = digits
  where
    (digits, _) = foldr step ([], x) ns
    step n (acc, rest) = (m : acc, q)
      where
        (q, m) = rest `divMod` n
-- | Turn a functor of arrays into an array of functors. Offset @i@ of the
-- result collects the element at offset @i@ of every input array.
instance forall r c. (Dimensions r, Container c) =>
  Data.Distributive.Distributive (Array c (r :: [Nat])) where
  distribute arrays = Array (generate size column)
    where
      size = fromIntegral (totalDim (dims @Nat @r))
      column i = fmap (\arr -> idx (_getContainer arr) i) arrays
-- | Arrays are indexed by a list of per-dimension indices in row-major
-- layout. 'index' flattens with 'ind'; 'tabulate' unflattens with 'unind'.
instance forall r c. (Dimensions r, Container c) =>
  Representable (Array c (r :: [Nat])) where
  type Rep (Array c r) = [Int]
  tabulate f = Array (generate total (f . unind extents))
    where
      sizes = listDims (dims @Nat @r)
      total = fromIntegral (product sizes)
      extents = fmap fromIntegral sizes
  index arr ix = idx (_getContainer arr) (ind extents ix)
    where
      extents = fmap fromIntegral (listDims (dims @Nat @r))
-- | List conversion.
--
-- 'fromList' truncates or pads with 'zero' so the store holds exactly as
-- many elements as the shape requires. 'toList' returns the store's
-- elements unchanged.
instance
  ( Item (Array c r a) ~ Item (c a)
  , Dimensions r
  , Additive a
  , IsList (c a)
  ) =>
  IsList (Array c (r :: [Nat]) a) where
  type Item (Array c r a) = a
  fromList xs = Array (fromList padded)
    where
      size = fromIntegral (totalDim (dims @_ @r))
      padded = take size (xs ++ repeat zero)
  toList = GHC.Exts.toList . _getContainer
-- | Nested-bracket rendering.
--
-- A rank-1 array is printed as a comma-separated list. Higher ranks are
-- printed as their sub-arrays along the first dimension, one per line and
-- indented by nesting depth. A rank-0 array is printed as @[]@.
instance (Show a, Show (Item (c a)), Container c, IsList (c a)) => Show (AnyArray c a) where
  show whole@(AnyArray (outerShape, _)) = render whole
    where
      outerRank = length outerShape
      wrap body = "[" ++ body ++ "]"
      render sub@(AnyArray (subShape, store)) =
        case length subShape of
          0 -> "[]"
          1 -> wrap (intercalate ", " (fmap GHC.Show.show (GHC.Exts.toList store)))
          rank ->
            let separator = ",\n" ++ replicate (outerRank - rank + 1) ' '
            in wrap (intercalate separator (fmap render (flatten1 sub)))
-- | Split an array along its first dimension into its sub-arrays.
--
-- Each sub-array is a contiguous slice of the store, and its shape is the
-- original shape minus the first dimension. A rank-0 array yields no
-- sub-arrays.
flatten1 :: (Container c) => AnyArray c a -> [AnyArray c a]
flatten1 (AnyArray (rep, v)) = fmap slab (take count [0 ..])
  where
    (count, width) =
      case rep of
        [] -> (0, 1)
        d : ds -> (d, product ds)
    slab k = AnyArray (drop 1 rep, cslice (k * width) width v)
-- | Show a statically-shaped array via its runtime-shaped 'AnyArray' view.
instance (Show a, Show (Item (c a)), IsList (c a), Container c, Dimensions ds)
  => Show (Array c (ds :: [Nat]) a) where
  show arr = GHC.Show.show (anyArray arr)
-- | A one-dimensional array of length @n@.
type Vector c n = Array c '[ n]
-- | A two-dimensional array of @m@ rows and @n@ columns, stored row-major
-- (row @i@ is the contiguous slice starting at @i * n@).
type Matrix c m n = Array c '[ m, n]
-- | Outer product. The result has shape @r ++ s@, and its element at index
-- @i ++ j@ (with @i@ of the rank of @m@) is @index m i * index n j@.
(><) :: forall c (r :: [Nat]) (s :: [Nat]) a.
  ( Container c
  , CommutativeRing a
  , Dimensions r
  , Dimensions s
  , Dimensions ((D.++) r s))
  => Array c r a
  -> Array c s a
  -> Array c ((D.++) r s) a
m >< n = tabulate entry
  where
    rankM = length (shape m)
    entry i =
      let (iM, iN) = splitAt rankM i
      in index m iM * index n iN
-- | Matrix product. Entry @[i, j]@ is the inner product of row @i@ of @x@
-- with column @j@ of @y@.
mmult :: forall c m n k a.
  ( Hilbert (Vector c k a)
  , Dimensions '[ m, k]
  , Dimensions '[ k, n]
  , Dimensions '[ m, n]
  , Container c
  )
  => Matrix c (m :: Nat) (k :: Nat) a
  -> Matrix c k n a
  -> Matrix c m n a
mmult x y =
  tabulate $ \ix ->
    case ix of
      [i, j] -> unsafeRow i x <.> unsafeCol j y
      _ -> impossible "mmult only typechecks on arrays"
-- | Row @i@ of a matrix. The row number comes from the proxy's type, and
-- @i < m@ is checked at compile time.
row :: forall c i a m n.
  ( Dimensions '[ m, n]
  , Container c
  , KnownNat i
  , ((S.<) i m) ~ 'True
  )
  => Proxy i
  -> Matrix c m n a
  -> Vector c n a
row proxy = unsafeRow (fromIntegral (S.fromSing (S.singByProxy proxy)))
-- | Runtime @(rows, columns)@ of a matrix.
rank2Shape
  :: Dimensions '[ m, n]
  => Matrix c (m :: Nat) (n :: Nat) a
  -> (Int, Int)
rank2Shape t
  | [rows, cols] <- shape t = (rows, cols)
  | otherwise = impossible "only typechecks for matricies"
-- | Row @i@ of a matrix as a vector. This is a contiguous slice of the
-- store, and @i@ is not range-checked.
unsafeRow :: forall c a m n.
  ( Container c
  , Dimensions '[ m, n])
  => Int
  -> Matrix c (m :: Nat) (n :: Nat) a
  -> Vector c n a
unsafeRow i t = Array (cslice (i * cols) cols (_getContainer t))
  where
    cols = snd (rank2Shape t)
-- | Column @j@ of a matrix as a vector. It gathers every @cols@-th element
-- starting at @j@, and @j@ is not range-checked.
unsafeCol ::
  forall c a m n. (Container c, Dimensions '[ m, n])
  => Int
  -> Matrix c (m :: Nat) (n :: Nat) a
  -> Vector c m a
unsafeCol j t = Array (generate rows (\k -> idx store (k * cols + j)))
  where
    store = _getContainer t
    (rows, cols) = rank2Shape t
-- | Column @j@ of a matrix. The column number comes from the proxy's type,
-- and @j < n@ is checked at compile time.
col :: forall c j a m n.
  ( Dimensions '[ m, n]
  , Container c
  , KnownNat j
  , ((S.<) j n) ~ 'True
  )
  => Proxy j
  -> Matrix c m n a
  -> Vector c m a
col proxy = unsafeCol (fromIntegral (S.fromSing (S.singByProxy proxy)))
-- | Element at a multi-dimensional index; the index is not range-checked.
unsafeIndex :: (Container c, Dimensions r) => Array c (r :: [Nat]) a -> [Int] -> a
unsafeIndex t ix = idx (_getContainer t) (ind (shape t) ix)
-- | Gather elements into a new array.
--
-- The elements are taken at every combination of the given per-dimension
-- index lists: their cartesian product, last list varying fastest. Neither
-- the indices nor the result shape are checked.
unsafeSlice ::
  (Container c, IsList (c a), Item (c a) ~ a, Dimensions r, Dimensions r0)
  => [[Int]]
  -> Array c (r :: [Nat]) a
  -> Array c (r0 :: [Nat]) a
unsafeSlice s t = Array (fromList (fmap (unsafeIndex t) (sequence s)))
-- | 'unsafeSlice' with the per-dimension index lists read from the
-- proxy's type.
--
-- NOTE(review): this binding has no type signature; its type is inferred
-- through the singletons machinery.
slice proxy = unsafeSlice indices
  where
    indices = (fmap . fmap) fromInteger (S.fromSing (S.singByProxy proxy))
-- | Apply @f@ to each consecutive block of the store and concatenate the
-- results in order.
--
-- Each block holds @product (drop s shape)@ elements and is presented to
-- @f@ as an array of shape @vw@.
--
-- NOTE(review): this presumes @TailModule s uvw@ is the trailing
-- dimensions from @s@ on and @Fold s uvw@ drops dimension @s@; confirm in
-- NumHask.Array.Constraints.
foldAlong ::
  forall c s vw uvw uw w a.
  ( Container c
  , KnownNat s
  , Dimensions uvw
  , uw ~ (Fold s uvw)
  , w ~ (S.Drop 1 vw)
  , vw ~ (TailModule s uvw)
  )
  => Proxy s
  -> (Array c vw a -> Array c w a)
  -> Array c uvw a
  -> Array c uw a
foldAlong s_ f a@(Array v) = Array (cconcat (fmap apply (reverse pieces)))
  where
    axis = (fromIntegral . S.fromSing . S.singByProxy) s_
    -- chunkItUp yields the pieces last-first, so restore their order.
    pieces = chunkItUp [] (product (drop axis (shape a))) v
    apply = _getContainer . f . Array
-- | Replace each consecutive block of the store by @f@ applied to it,
-- keeping block order.
--
-- Each block holds @product (drop s shape)@ elements and is viewed as an
-- array of shape @vw@.
--
-- NOTE(review): confirm that this block size matches @HeadModule s uvw@.
mapAlong ::
  forall c s uvw vw a.
  (Container c, KnownNat s, Dimensions uvw, vw ~ (HeadModule s uvw))
  => Proxy s
  -> (Array c vw a -> Array c vw a)
  -> Array c uvw a
  -> Array c uvw a
mapAlong s_ f a@(Array v) = Array (cconcat (fmap apply (reverse pieces)))
  where
    axis = (fromIntegral . S.fromSing . S.singByProxy) s_
    -- chunkItUp yields the pieces last-first, so restore their order.
    pieces = chunkItUp [] (product (drop axis (shape a))) v
    apply = _getContainer . f . Array
-- | Concatenate two arrays along dimension @s@, with @r@ first and @t@
-- second.
--
-- Each array is cut into row-major slabs covering dimensions @s@ onward.
-- The slabs are then interleaved so that each slab of @r@ is immediately
-- followed by the matching slab of @t@. Both arrays have the same number of
-- slabs because they agree on the dimensions before @s@ (enforced by
-- 'IsValidConcat').
concatenate ::
  forall c s r t a.
  ( Container c
  , S.SingI s
  , Dimensions r
  , Dimensions t
  , (IsValidConcat s t r) ~ 'True
  )
  => Proxy s
  -> Array c r a
  -> Array c t a
  -> Array c (Concatenate s t r) a
concatenate s_ r@(Array vr) t@(Array vt) =
  -- chunkItUp returns slabs last-first. Transposing pairs the k-th slabs,
  -- and reversing restores the front-to-back order, giving r1 t1 r2 t2 ...
  Array . cconcat $ (concat . reverse . P.transpose) [rChunks, tChunks]
  where
    s = (fromIntegral . S.fromSing . S.singByProxy) s_
    -- Each array's slabs come from its own store and its own shape, so r
    -- stays first in every interleaved pair.
    rChunks = chunkItUp [] (product $ drop s $ shape r) vr
    tChunks = chunkItUp [] (product $ drop s $ shape t) vt
-- | Transpose an array by reversing the order of its dimensions. The
-- element at index @i@ of the result is the input's element at
-- @reverse i@.
--
-- The flat store has to be permuted. Relabelling only the type-level shape
-- would reread the same row-major data under the new shape, which is not
-- the transpose.
--
-- NOTE(review): assumes @Transpose s@ is the reverse of @s@; confirm in
-- NumHask.Array.Constraints.
transpose ::
  forall c s t a. (t ~ Transpose s, Container c, Dimensions s, Dimensions t)
  => Array c (s :: [Nat]) a
  -> Array c (t :: [Nat]) a
transpose a = tabulate (index a . reverse)
-- | Change the shape to @Squeeze s@ while reusing the store unchanged.
--
-- NOTE(review): this is only valid if 'Squeeze' removes dimensions that do
-- not affect row-major order (presumably size-1 dimensions); confirm.
squeeze ::
  forall c s t a. (t ~ Squeeze s)
  => Array c s a
  -> Array c t a
squeeze = Array . _getContainer
-- | Element-wise addition; 'zero' is the array filled with the element
-- 'zero'.
instance (Dimensions r, Container c, Additive a) =>
  Additive (Array c (r :: [Nat]) a) where
  (+) = liftR2 (+)
  zero = tabulate (const zero)
-- | Element-wise negation.
instance (Dimensions r, Container c, Subtractive a) =>
  Subtractive (Array c (r :: [Nat]) a) where
  negate arr = tabulate (negate . index arr)
-- | Element-wise (Hadamard) multiplication; 'one' is the array filled with
-- the element 'one'.
instance (Dimensions r, Container c, Multiplicative a) =>
  Multiplicative (Array c (r :: [Nat]) a) where
  (*) = liftR2 (*)
  one = tabulate (const one)
-- | Element-wise reciprocal.
instance (Dimensions r, Container c, Divisive a) =>
  Divisive (Array c (r :: [Nat]) a) where
  recip arr = tabulate (recip . index arr)
-- | Marker instance: no methods are defined here. It relies on the
-- element-wise '+' and '*' of the Additive/Multiplicative array instances.
instance (Dimensions r, Container c, Multiplicative a, Additive a) =>
  P.Distributive (Array c (r :: [Nat]) a)
-- | Marker instance with no methods of its own.
instance (Dimensions r, Container c, IntegralDomain a) => IntegralDomain (Array c (r :: [Nat]) a)
-- | Marker instance with no methods of its own.
instance (Dimensions r, Container c, Field a) => Field (Array c (r :: [Nat]) a)
-- | Element-wise 'exp' and 'log'.
instance (Dimensions r, Container c, ExpField a) => ExpField (Array c (r :: [Nat]) a) where
  exp arr = tabulate (exp . index arr)
  log arr = tabulate (log . index arr)
-- | An array is NaN when any of its elements is. The strict left fold
-- checks every element.
instance (Foldable (Array c r), Dimensions r, Container c, UpperBoundedField a) =>
  UpperBoundedField (Array c (r :: [Nat]) a) where
  isNaN arr = foldl' (||) False (fmapRep isNaN arr)
-- | Marker instance: no methods are defined here.
instance (Foldable (Array c r), Dimensions r, Container c, LowerBoundedField a) =>
  LowerBoundedField (Array c (r :: [Nat]) a)
-- | Element-wise 'sign' and 'abs'.
instance (Dimensions r, Container c, Multiplicative a, Signed a)
  => Signed (Array c (r :: [Nat]) a) where
  sign arr = tabulate (sign . index arr)
  abs arr = tabulate (abs . index arr)
-- | Norms over all elements. 'normL1' sums each element's own L1 norm, and
-- 'normL2' is the square root of the sum of squared elements.
instance (Functor (Array c r), Foldable (Array c r), Additive (Array c r a), Normed a a, ExpField a) =>
  Normed (Array c (r :: [Nat]) a) a where
  normL1 r = foldr (+) zero $ normL1 <$> r
  -- Square by multiplication rather than with (** (one + one)). An
  -- ExpField whose (**) goes through exp/log gives NaN for negative
  -- elements, and multiplication is cheaper in any case.
  normL2 r = sqrt $ foldr (+) zero $ (\x -> x * x) <$> r
-- | Element-wise tolerance. 'epsilon' fills every slot. An array is near
-- zero, or about equal to another, only if every element is.
instance (Eq (c a), Foldable (Array c r), Dimensions r, Container c, Epsilon a) =>
  Epsilon (Array c (r :: [Nat]) a) where
  epsilon = pureRep epsilon
  nearZero arr = and (fmapRep nearZero arr)
  aboutEqual x y = and (liftR2 aboutEqual x y)
-- | Distances are the corresponding norms of the element-wise difference.
instance (Foldable (Array c r), Dimensions r, Container c, ExpField a, Subtractive a, Normed a a) =>
  Metric (Array c (r :: [Nat]) a) a where
  distanceL1 x = normL1 . (x -)
  distanceL2 x = normL2 . (x -)
-- | Element-wise 'divMod' and 'quotRem'. The per-element pairs are split
-- into an array of quotients and an array of remainders.
instance (Dimensions r, Container c, Integral a) => Integral (Array c (r :: [Nat]) a) where
  divMod a b = (fmap fst pairs, fmap snd pairs)
    where
      pairs = liftR2 divMod a b
  quotRem a b = (fmap fst pairs, fmap snd pairs)
    where
      pairs = liftR2 quotRem a b
-- | The scalar acting on an array (in the *Action classes below) is its
-- element type.
type instance Actor (Array c r a) = a
-- | Element-wise product.
instance (Dimensions r, Container c, Multiplicative a) =>
  HadamardMultiplication (Array c (r :: [Nat])) a where
  x .*. y = liftR2 (*) x y
-- | Element-wise quotient.
instance (Dimensions r, Container c, Divisive a) =>
  HadamardDivision (Array c (r :: [Nat])) a where
  x ./. y = liftR2 (/) x y
-- | Add a scalar to every element. In both directions the scalar is the
-- left operand of '+'.
instance (Dimensions r, Container c, Additive a) =>
  AdditiveAction (Array c (r::[Nat]) a) where
  r .+ s = fmap (s +) r
  s +. r = fmap (s +) r
-- | Scalar subtraction. @arr .- s@ subtracts @s@ from every element, and
-- @s -. arr@ subtracts every element from @s@. The scalar stays on the side
-- it is written, as with '+.' and '*.'.
instance (Dimensions r, Container c, Subtractive a) =>
  SubtractiveAction (Array c (r::[Nat]) a) where
  (.-) r s = fmap (\x -> x - s) r
  (-.) s = fmap (s -)
-- | Scale every element by a scalar. The scalar multiplies from the side
-- it is written on.
instance (Dimensions r, Container c, Multiplicative a) =>
  MultiplicativeAction (Array c (r :: [Nat]) a) where
  r .* s = fmap (* s) r
  s *. r = fmap (s *) r
-- | Scalar division. @arr ./ s@ divides every element by @s@, and
-- @s /. arr@ divides @s@ by every element. The scalar stays on the side it
-- is written, as with '+.' and '*.'.
instance (Dimensions r, Container c, Divisive a) =>
  DivisiveAction (Array c (r::[Nat]) a) where
  (./) r s = fmap (/ s) r
  (/.) s = fmap (s /)
-- | Inner product: the sum of the element-wise products.
instance forall a c r. (Actor (Array c r a) ~ a, Foldable (Array c r), P.Distributive a, CommutativeRing a, Semiring a, Dimensions r, Container c) =>
  Hilbert (Array c (r :: [Nat]) a) where
  (<.>) x = sum . liftR2 (*) x
-- | Outer product and vector-tensor contraction on nested arrays.
--
-- @(m >< n)@ holds, at outer index @i@, the array @m_i *. n@, so its
-- @(i, j)@ entry is @m_i * n_j@.
instance
  ( Foldable (Array c r)
  , Dimensions r
  , Container c
  , CommutativeRing a
  , Multiplicative a
  ) =>
  TensorProduct (Array c (r :: [Nat]) a) where
  (><) m n = tabulate (\i -> index m i *. n)
  -- A vector on the left contracts over the OUTER index:
  -- result_j = v <.> (the j-th entry of every inner array), i.e. v^T M.
  -- Using rows here would make timesleft identical to timesright.
  timesleft v m = tabulate (\j -> v <.> fmap (`index` j) m)
  -- A vector on the right contracts over the INNER index:
  -- result_i = v <.> (inner array i), i.e. M v.
  timesright m v = tabulate (\i -> v <.> index m i)
-- | Element-wise join.
instance (Eq (c a), Container c, Dimensions r, JoinSemiLattice a) => JoinSemiLattice (Array c (r :: [Nat]) a) where
  x \/ y = liftR2 (\/) x y
-- | Element-wise meet.
instance (Eq (c a), Container c, Dimensions r, MeetSemiLattice a) => MeetSemiLattice (Array c (r :: [Nat]) a) where
  x /\ y = liftR2 (/\) x y
-- | 'bottom' is the array filled with the element 'bottom'.
instance (Eq (c a), Container c, Dimensions r, BoundedJoinSemiLattice a) => BoundedJoinSemiLattice (Array c (r :: [Nat]) a) where
  bottom = tabulate (const bottom)
-- | 'top' is the array filled with the element 'top'.
instance (Eq (c a), Container c, Dimensions r, BoundedMeetSemiLattice a) => BoundedMeetSemiLattice (Array c (r :: [Nat]) a) where
  top = tabulate (const top)
-- | An array whose every element is the given value.
singleton :: (Dimensions r, Container c) => a -> Array c (r :: [Nat]) a
singleton = pureRep