{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
module NumHask.Array.HMatrix
(
Array (..),
index,
tabulate,
shape,
toDynamic,
toFixed,
fromFixed,
reshape,
transpose,
diag,
ident,
singleton,
selects,
selectsExcept,
folds,
concatenate,
insert,
append,
reorder,
expand,
slice,
squeeze,
Scalar,
fromScalar,
toScalar,
Vector,
Matrix,
col,
row,
safeCol,
safeRow,
mmult,
)
where
import Data.List ((!!))
import qualified Data.Vector as V
import GHC.Exts (IsList (..))
import GHC.TypeLits
import qualified NumHask.Array.Dynamic as D
import qualified NumHask.Array.Fixed as F
import NumHask.Array.Shape
import NumHask.Prelude as P hiding (transpose)
import qualified Numeric.LinearAlgebra as H
import qualified Numeric.LinearAlgebra.Devel as H
import qualified Prelude
-- | A multidimensional array with a type-level shape @s@, backed by an
-- hmatrix 'H.Matrix'.
--
-- Arrays built with 'fromList' (and hence 'tabulate') hold their elements in
-- row-major order, reshaped so that the last dimension of @s@ is the column
-- count. Element access ('index', 'toList') goes through 'H.flatten', so it
-- depends only on that row-major element order, not on the matrix dims.
newtype Array s a = Array {unArray :: H.Matrix a}
  deriving (Show, NFData, Generic)
-- | Element-wise addition, delegated to hmatrix's 'H.add'.
--
-- 'zero' is a constant matrix laid out the same way 'fromList' lays out an
-- array of shape @s@: the last dimension is the column count and all leading
-- dimensions are folded into the row count. A scalar (empty) shape is 1x1.
-- This works for every rank; matching the shape against exactly @[n, m]@
-- would fail at runtime for vectors, scalars and rank-3+ arrays.
instance
  ( Additive a,
    HasShape s,
    H.Container H.Vector a,
    Num a
  ) =>
  Additive (Array s a)
  where
  (+) (Array x1) (Array x2) = Array $ H.add x1 x2

  zero = Array $ H.konst zero dims
    where
      s = shapeVal (toShape @s)
      -- (rows, cols) of the backing matrix; total for any rank.
      dims = case reverse s of
        [] -> (1, 1)
        (c : rs) -> (Prelude.product rs, c)
-- | Element-wise (Hadamard) multiplication of the backing storage via
-- 'H.liftMatrix2'.
--
-- 'one' is a constant matrix laid out the same way 'fromList' lays out an
-- array of shape @s@: the last dimension is the column count and all leading
-- dimensions are folded into the row count. A scalar (empty) shape is 1x1,
-- so 'one' is defined for every rank, not just rank 2.
--
-- NOTE(review): for square shapes this head also matches the
-- @Multiplicative (Matrix m m a)@ instance in this module, and neither has
-- an overlap pragma. Confirm which semantics is intended.
instance
  ( Multiplicative a,
    HasShape s,
    H.Container H.Vector a,
    Num (H.Vector a),
    Num a
  ) =>
  Multiplicative (Array s a)
  where
  (*) (Array x1) (Array x2) = Array $ H.liftMatrix2 (Prelude.*) x1 x2

  one = Array $ H.konst one dims
    where
      s = shapeVal (toShape @s)
      -- (rows, cols) of the backing matrix; total for any rank.
      dims = case reverse s of
        [] -> (1, 1)
        (c : rs) -> (Prelude.product rs, c)
-- | Arrays are acted upon by scalars of their own element type.
type instance Actor (Array s a) = a

-- | Scale every element of an array by a scalar, either on the right ('.*')
-- or on the left ('*.'), by mapping over the backing matrix.
instance
  ( HasShape s,
    Multiplicative a,
    H.Container H.Vector a,
    Num a
  ) =>
  MultiplicativeAction (Array s a)
  where
  (.*) arr x = Array (H.cmap (\e -> e * x) (unArray arr))
  {-# INLINE (.*) #-}

  (*.) x arr = Array (H.cmap (\e -> x * e) (unArray arr))
  {-# INLINE (*.) #-}
-- | Conversion to and from a flat, row-major list of elements.
--
-- 'fromList' throws 'NumHaskException' @"shape mismatch"@ unless the list
-- length equals the size of @s@ (a scalar shape accepts exactly one
-- element). The elements are reshaped so that the last dimension of @s@ is
-- the column count. A scalar shape has no last dimension, so it is stored
-- as a single column (1x1).
instance
  ( HasShape s,
    H.Element a
  ) =>
  IsList (Array s a)
  where
  type Item (Array s a) = a

  fromList l =
    bool
      (throw (NumHaskException "shape mismatch"))
      (Array $ H.reshape n $ H.fromList l)
      ((length l == 1 && null s) || (length l == size s))
    where
      s = shapeVal (toShape @s)
      -- 'Prelude.last' is partial: the scalar shape @[]@ must not reach it.
      n = bool (Prelude.last s) 1 (null s)

  toList (Array v) = H.toList $ H.flatten v
-- | The runtime shape of an array, read off its type-level shape; the array
-- argument itself is never inspected.
shape :: forall a s. (HasShape s) => Array s a -> [Int]
shape = const (shapeVal (toShape @s))
{-# INLINE shape #-}
-- | Convert to a "NumHask.Array.Dynamic" array with the same shape, passing
-- the elements in row-major order.
toDynamic :: (HasShape s, H.Element a) => Array s a -> D.Array a
toDynamic a = D.fromFlatList (shape a) (H.toList (H.flatten (unArray a)))
-- | Convert to a "NumHask.Array.Fixed" array of the same shape, feeding its
-- 'fromList' the rows of the backing matrix concatenated in order.
toFixed :: (HasShape s, H.Element a) => Array s a -> F.Array s a
toFixed = fromList . mconcat . H.toLists . unArray
-- | Convert from a "NumHask.Array.Fixed" array of the same shape, preserving
-- the element order reported by its 'P.toList'.
fromFixed :: (HasShape s, H.Element a) => F.Array s a -> Array s a
fromFixed = fromList . P.toList
-- | Look up the element at a multidimensional index.
--
-- The index is turned into a flat row-major offset with the shape-level
-- 'flatten' and read from the flattened backing matrix. The index is not
-- validated here.
index ::
  forall s a.
  ( HasShape s,
    H.Element a,
    H.Container H.Vector a
  ) =>
  Array s a ->
  [Int] ->
  a
index (Array m) idx = H.atIndex (H.flatten m) offset
  where
    offset = flatten (shapeVal (toShape @s)) idx
-- | Build an array by evaluating a function at every index, visiting the
-- indices in row-major (flat offset) order.
tabulate ::
  forall s a.
  ( HasShape s,
    H.Element a
  ) =>
  ([Int] -> a) ->
  Array s a
tabulate f = fromList (V.toList elems)
  where
    dims = shapeVal (toShape @s)
    elems = V.generate (size dims) (\k -> f (shapen dims k))
-- | Reinterpret an array under a new shape of the same size, keeping the
-- row-major element order.
reshape ::
  forall a s s'.
  ( Size s ~ Size s',
    HasShape s,
    HasShape s',
    H.Container H.Vector a
  ) =>
  Array s a ->
  Array s' a
reshape a = tabulate (\i -> index a (shapen srcShape (flatten dstShape i)))
  where
    srcShape = shapeVal (toShape @s)
    dstShape = shapeVal (toShape @s')
-- | Reverse the order of the dimensions: the result at index @i@ is the input
-- at @reverse i@.
transpose :: forall a s. (H.Element a, H.Container H.Vector a, HasShape s, HasShape (Reverse s)) => Array s a -> Array (Reverse s) a
transpose a = tabulate (\i -> index a (reverse i))
-- | The identity-like array: 'one' wherever every component of the index is
-- equal, 'zero' elsewhere.
ident :: forall a s. (H.Element a, H.Container H.Vector a, HasShape s, Additive a, Multiplicative a) => Array s a
ident = tabulate (\i -> bool zero one (allEqual i))
  where
    allEqual (x : rest@(y : _)) = x == y && allEqual rest
    allEqual _ = True
-- | Extract the diagonal: element @k@ of the result is the input at index
-- @[k, k, ..]@ (one @k@ per dimension of the input).
--
-- Throws 'NumHaskException' @"Rank Underflow"@ if asked for an empty index.
diag ::
  forall a s.
  ( HasShape s,
    HasShape '[Minimum s],
    H.Element a,
    H.Container H.Vector a
  ) =>
  Array s a ->
  Array '[Minimum s] a
diag a = tabulate pick
  where
    r = length (shapeVal (toShape @s))
    pick (k : _) = index a (replicate r k)
    pick [] = throw (NumHaskException "Rank Underflow")
-- | An array whose every element is the given value.
singleton :: (H.Element a, H.Container H.Vector a, HasShape s) => a -> Array s a
singleton a = tabulate (\_ -> a)
-- | Fix the dimensions listed in @ds@ at the given positions and return the
-- sub-array spanned by the remaining dimensions.
selects ::
  forall ds s s' a.
  ( HasShape s,
    HasShape ds,
    HasShape s',
    s' ~ DropIndexes s ds,
    H.Element a,
    H.Container H.Vector a
  ) =>
  Proxy ds ->
  [Int] ->
  Array s a ->
  Array s' a
selects _ i a = tabulate (\js -> index a (addIndexes js dims i))
  where
    dims = shapeVal (toShape @ds)
-- | Fix every dimension except those listed in @ds@ at the given positions,
-- returning the sub-array spanned by the @ds@ dimensions.
selectsExcept ::
  forall ds s s' a.
  ( HasShape s,
    HasShape ds,
    HasShape s',
    s' ~ TakeIndexes s ds,
    H.Element a,
    H.Container H.Vector a
  ) =>
  Proxy ds ->
  [Int] ->
  Array s a ->
  Array s' a
selectsExcept _ i a = tabulate (\js -> index a (addIndexes i dims js))
  where
    dims = shapeVal (toShape @ds)
-- | Apply a reducing function to every sub-array obtained by 'selects' over
-- the @ds@ dimensions, collecting the results into an array indexed by
-- those dimensions.
folds ::
  forall ds st si so a b.
  ( HasShape st,
    HasShape ds,
    HasShape si,
    HasShape so,
    si ~ DropIndexes st ds,
    so ~ TakeIndexes st ds,
    H.Element a,
    H.Container H.Vector a,
    H.Element b,
    H.Container H.Vector b
  ) =>
  (Array si a -> b) ->
  Proxy ds ->
  Array st a ->
  Array so b
folds f d a = tabulate (\js -> f (selects d js a))
-- | Join two arrays along dimension @d@.
--
-- Indices whose @d@-th component is below the length of @s0@ along @d@ are
-- read from the first array. The rest are read from the second array,
-- shifted back by that length.
concatenate ::
  forall a s0 s1 d s.
  ( CheckConcatenate d s0 s1 s,
    Concatenate d s0 s1 ~ s,
    HasShape s0,
    HasShape s1,
    HasShape s,
    KnownNat d,
    H.Element a,
    H.Container H.Vector a
  ) =>
  Proxy d ->
  Array s0 a ->
  Array s1 a ->
  Array s a
concatenate _ s0 s1 = tabulate pick
  where
    d = fromIntegral $ natVal @d Proxy
    ds0 = shapeVal (toShape @s0)
    pick i
      | (i !! d) >= (ds0 !! d) =
          index s1 (addIndex (dropIndex i d) d ((i !! d) - (ds0 !! d)))
      | otherwise = index s0 i
-- | Insert the lower-rank array @b@ as slice @i@ of dimension @d@ of @a@.
--
-- Result positions before @i@ along @d@ come from @a@ unchanged, position
-- @i@ comes from @b@, and later positions come from @a@ shifted back by one.
insert ::
  forall a s s' d i.
  ( DropIndex s d ~ s',
    CheckInsert d i s,
    KnownNat i,
    KnownNat d,
    HasShape s,
    HasShape s',
    HasShape (Insert d s),
    H.Element a,
    H.Container H.Vector a
  ) =>
  Proxy d ->
  Proxy i ->
  Array s a ->
  Array s' a ->
  Array (Insert d s) a
insert _ _ a b = tabulate pick
  where
    d = fromIntegral $ natVal @d Proxy
    i = fromIntegral $ natVal @i Proxy
    pick js
      | js !! d < i = index a js
      | js !! d == i = index b (dropIndex js d)
      | otherwise = index a (decAt d js)
-- | 'insert' the slice @b@ at position @Dimension s d - 1@ along dimension
-- @d@.
append ::
  forall a d s s'.
  ( DropIndex s d ~ s',
    CheckInsert d (Dimension s d - 1) s,
    KnownNat (Dimension s d - 1),
    KnownNat d,
    HasShape s,
    HasShape s',
    HasShape (Insert d s),
    H.Element a,
    H.Container H.Vector a
  ) =>
  Proxy d ->
  Array s a ->
  Array s' a ->
  Array (Insert d s) a
append d a b = insert d (Proxy :: Proxy (Dimension s d - 1)) a b
-- | Permute the dimensions of an array according to @ds@; each result
-- index is mapped back with @addIndexes [] ds@.
reorder ::
  forall a ds s.
  ( HasShape ds,
    HasShape s,
    HasShape (Reorder s ds),
    CheckReorder ds s,
    H.Element a,
    H.Container H.Vector a
  ) =>
  Proxy ds ->
  Array s a ->
  Array (Reorder s ds) a
reorder _ a = tabulate (\i -> index a (addIndexes [] dims i))
  where
    dims = shapeVal (toShape @ds)
-- | Outer product under a binary function.
--
-- The result has shape @s ++ s'@. The leading @rank s@ index components
-- select from @a@ and the rest select from @b@.
expand ::
  forall s s' a b c.
  ( HasShape s,
    HasShape s',
    HasShape ((++) s s'),
    H.Element a,
    H.Container H.Vector a,
    H.Element b,
    H.Container H.Vector b,
    H.Element c
  ) =>
  (a -> b -> c) ->
  Array s a ->
  Array s' b ->
  Array ((++) s s') c
expand f a b = tabulate pick
  where
    r = rank (shape a)
    pick i = f (index a (take r i)) (index b (drop r i))
-- | Select, per dimension, the listed positions given by the type-level
-- lists in @pss@. Result index @k@ along a dimension maps to the @k@-th
-- listed position for that dimension.
slice ::
  forall (pss :: [[Nat]]) s s' a.
  ( HasShape s,
    HasShape s',
    KnownNatss pss,
    KnownNat (Rank pss),
    s' ~ Ranks pss,
    H.Element a,
    H.Container H.Vector a
  ) =>
  Proxy pss ->
  Array s a ->
  Array s' a
slice pss a = tabulate (\i -> index a (zipWith (!!) picks i))
  where
    picks = natValss pss
-- | Drop size-1 dimensions from the type-level shape, reusing the backing
-- matrix unchanged.
--
-- NOTE(review): the matrix keeps its original rows/columns, which may differ
-- from the layout 'fromList' would build for @t@ (e.g. @[3, 1]@ stays 3x1
-- while a @[3]@ built by 'fromList' is 1x3). 'index' is unaffected because it
-- reads the flattened storage. Element-wise hmatrix operations between
-- differently laid-out arrays may see mismatched dimensions; confirm.
squeeze ::
  forall s t a.
  (t ~ Squeeze s) =>
  Array s a ->
  Array t a
squeeze = Array . unArray
-- | A rank-0 array holding a single element.
type Scalar a = Array ('[] :: [Nat]) a

-- | Read the single element of a rank-0 array (the empty index).
fromScalar :: (H.Element a, H.Container H.Vector a, HasShape ('[] :: [Nat])) => Array ('[] :: [Nat]) a -> a
fromScalar = flip index []
-- | Wrap a value as a rank-0 array, stored as a 1x1 matrix.
--
-- The matrix is built directly rather than through 'fromList'. 'fromList'
-- picks its column count from the last dimension of the shape, and an empty
-- shape has none (previously this made 'toScalar' crash).
toScalar :: (H.Element a, H.Container H.Vector a, HasShape ('[] :: [Nat])) => a -> Array ('[] :: [Nat]) a
toScalar a = Array (H.fromLists [[a]])
-- | A rank-1 array of length @s@.
type Vector s a = Array '[s] a

-- | A rank-2 array with @m@ rows and @n@ columns.
type Matrix m n a = Array '[m, n] a
-- | Square matrices multiply by matrix product ('mmult'), with 'ident' as
-- the unit.
--
-- NOTE(review): this head, @Array '[m, m] a@, is also matched by the generic
-- element-wise @Multiplicative (Array s a)@ instance in this module, and
-- neither instance carries an OVERLAPPING/OVERLAPPABLE pragma. GHC will
-- likely reject uses of @(*)@ or 'one' at square-matrix (or shape-polymorphic)
-- types as overlapping. Decide which semantics is intended and confirm.
instance
  ( Multiplicative a,
    P.Distributive a,
    Subtractive a,
    H.Numeric a,
    KnownNat m,
    HasShape '[m, m],
    H.Element a,
    H.Container H.Vector a
  ) =>
  Multiplicative (Matrix m m a)
  where
  (*) = mmult

  one = ident
-- | Extract row @i@ of an @m x n@ matrix as a length-@n@ vector, taken from
-- the row-major flattening of the backing matrix. @i@ is not bounds-checked
-- here; see 'safeRow' for a type-checked index.
row :: forall m n a. (H.Element a, H.Container H.Vector a, KnownNat m, KnownNat n, HasShape '[m, n]) => Int -> Matrix m n a -> Vector n a
row i (Array a) = fromList (H.toList (H.subVector start n (H.flatten a)))
  where
    n = fromIntegral $ natVal @n Proxy
    start = i * n
-- | Extract row @j@ of a matrix, where @j@ is a type-level index checked
-- against the row count @m@ at compile time.
safeRow :: forall m n a j. (H.Element a, H.Container H.Vector a, 'True ~ CheckIndex j m, KnownNat j, KnownNat m, KnownNat n, HasShape '[m, n]) => Proxy j -> Matrix m n a -> Vector n a
safeRow _ = row (fromIntegral (natVal @j Proxy))
-- | Extract column @i@ of an @m x n@ matrix as a length-@m@ vector.
--
-- Elements are read from the row-major flattening of the backing matrix
-- with stride @n@, the same storage view 'row' and 'index' use. @i@ is not
-- validated against @n@ here; see 'safeCol' for a type-checked index.
-- (A column has @m@ elements, so the result is a @Vector m@. For square
-- matrices this is the same type as before.)
col :: forall m n a. (H.Element a, H.Container H.Vector a, KnownNat m, KnownNat n, HasShape '[m, n]) => Int -> Matrix m n a -> Vector m a
col i (Array a) = fromList [flat `H.atIndex` (r * n + i) | r <- [0 .. m - 1]]
  where
    flat = H.flatten a
    m = fromIntegral $ natVal @m Proxy
    n = fromIntegral $ natVal @n Proxy
-- | Extract column @j@ of an @m x n@ matrix as a length-@m@ vector, where
-- @j@ is a type-level index checked against the column count @n@ at compile
-- time.
--
-- Elements are read from the row-major flattening of the backing matrix
-- with stride @n@. (A column has @m@ elements, so the result is a
-- @Vector m@. For square matrices this is the same type as before.)
safeCol :: forall m n a j. (H.Element a, H.Container H.Vector a, 'True ~ CheckIndex j n, KnownNat j, KnownNat m, KnownNat n, HasShape '[m, n]) => Proxy j -> Matrix m n a -> Vector m a
safeCol _j (Array a) = fromList [flat `H.atIndex` (r * n + j) | r <- [0 .. m - 1]]
  where
    flat = H.flatten a
    m = fromIntegral $ natVal @m Proxy
    n = fromIntegral $ natVal @n Proxy
    j = fromIntegral $ natVal @j Proxy
-- | Matrix product of an @m x k@ matrix and a @k x n@ matrix, delegated to
-- hmatrix's 'H.<>' on the backing matrices.
mmult ::
  forall m n k a.
  ( KnownNat k,
    KnownNat m,
    KnownNat n,
    HasShape [m, n],
    Ring a,
    H.Numeric a
  ) =>
  Array [m, k] a ->
  Array [k, n] a ->
  Array [m, n] a
mmult a b = Array (unArray a H.<> unArray b)