{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module NumHask.Array.HMatrix
( Array(..),
index,
mmult,
) where
import GHC.Exts (IsList (..))
import GHC.Show (Show (..))
import NumHask.Array.Shape
import NumHask.Prelude as P
import qualified Numeric.LinearAlgebra as H
import qualified Numeric.LinearAlgebra.Devel as H
import qualified Prelude
-- | An array whose shape @s@ appears only in the type: the runtime
-- payload is a plain hmatrix 'H.Matrix', and the shape index does not
-- occur on the right-hand side of the declaration.
newtype Array s a = Array
  { -- | The underlying hmatrix matrix.
    unArray :: H.Matrix a
  }
  deriving (Generic, NFData, Show)
-- | Look up a single element by its multi-dimensional index.
--
-- The matrix is flattened with 'H.flatten', and the index list is turned
-- into a flat offset by the shape-aware 'flatten' applied to the
-- type-level shape @s@. No explicit bounds check is made here.
index ::
  forall s a.
  ( HasShape s,
    H.Element a,
    H.Container H.Vector a
  ) =>
  Array s a ->
  [Int] ->
  a
index (Array mat) ix = H.atIndex flat offset
  where
    flat = H.flatten mat
    offset = flatten (shapeVal (toShape @s)) ix
-- | Element-wise addition of two arrays of the same shape.
--
-- 'zero' is a constant matrix of the element type's 'zero'. Its dimensions
-- follow the same row-major layout that 'fromList' uses:
--
-- * the last dimension of the shape is the column count;
-- * the product of the remaining dimensions is the row count;
-- * a rank-0 shape becomes a 1x1 matrix.
--
-- For a rank-2 shape @[n, m]@ this is exactly an @n@ x @m@ matrix.
instance
  ( Additive a,
    HasShape s,
    H.Container H.Vector a,
    Num a
  ) =>
  Additive (Array s a)
  where
  (+) (Array x1) (Array x2) = Array $ H.add x1 x2

  zero = Array $ H.konst zero (matrixDims (shapeVal (toShape @s)))
    where
      -- Total (rows, cols) for any shape. This avoids an incomplete
      -- pattern binding that would fail for any shape that is not rank 2.
      matrixDims :: [Int] -> (Int, Int)
      matrixDims [] = (1, 1)
      matrixDims [c] = (1, c)
      matrixDims (d : ds) = let (r, c) = matrixDims ds in (d * r, c)
-- | Element-wise multiplication of two arrays of the same shape.
--
-- The product lifts the vector @*@ over both matrices with
-- 'H.liftMatrix2'. 'one' is a constant matrix of the element type's 'one'.
-- Its dimensions follow the same row-major layout that 'fromList' uses:
--
-- * the last dimension of the shape is the column count;
-- * the product of the remaining dimensions is the row count;
-- * a rank-0 shape becomes a 1x1 matrix.
--
-- For a rank-2 shape @[n, m]@ this is exactly an @n@ x @m@ matrix.
instance
  ( Multiplicative a,
    HasShape s,
    H.Container H.Vector a,
    Num (H.Vector a),
    Num a
  ) =>
  Multiplicative (Array s a)
  where
  (*) (Array x1) (Array x2) = Array $ H.liftMatrix2 (Prelude.*) x1 x2

  one = Array $ H.konst one (matrixDims (shapeVal (toShape @s)))
    where
      -- Total (rows, cols) for any shape. This avoids an incomplete
      -- pattern binding that would fail for any shape that is not rank 2.
      matrixDims :: [Int] -> (Int, Int)
      matrixDims [] = (1, 1)
      matrixDims [c] = (1, c)
      matrixDims (d : ds) = let (r, c) = matrixDims ds in (d * r, c)
-- | The scalar type associated with an 'Array' (used by the action and
-- inner-product instances) is its element type.
type instance Actor (Array shape e) = e
-- | Inner product of two arrays: multiply them element-wise, then sum every
-- entry of the result. No conjugation is applied.
instance
  ( HasShape s,
    P.Distributive a,
    CommutativeRing a,
    Semiring a,
    H.Container H.Vector a,
    Num (H.Vector a),
    Num a
  ) =>
  Hilbert (Array s a)
  where
  Array lhs <.> Array rhs = H.sumElements (H.liftMatrix2 (Prelude.*) lhs rhs)
  {-# INLINE (<.>) #-}
-- | Scale every element of an 'Array' by a scalar, from either side.
--
-- Both operators map over the matrix with 'H.cmap'. They differ only in the
-- operand order given to the element type's @*@, which matters when that
-- multiplication is not commutative.
instance
  ( HasShape s,
    Multiplicative a,
    H.Container H.Vector a,
    Num a
  ) =>
  MultiplicativeAction (Array s a)
  where
  Array m .* k = Array (H.cmap (\e -> e * k) m)
  {-# INLINE (.*) #-}
  k *. Array m = Array (H.cmap (\e -> k * e) m)
  {-# INLINE (*.) #-}
-- | Conversion between flat element lists and 'Array's.
--
-- 'fromList' reshapes the list into a matrix whose column count is the last
-- dimension of the shape. The row count is the product of the other
-- dimensions, and a rank-0 shape is stored as 1x1. Short lists are padded
-- with 'zero' and long lists are truncated, so the matrix always holds
-- exactly as many elements as the shape requires.
--
-- 'toList' reads the elements back by flattening the matrix.
instance
  ( HasShape s,
    Additive a,
    H.Element a
  ) =>
  IsList (Array s a)
  where
  type Item (Array s a) = a

  fromList l =
    Array $ H.reshape cols $ H.fromList $ take (rows * cols) $ l ++ repeat zero
    where
      (rows, cols) = matrixDims (shapeVal (toShape @s))
      -- Total (rows, cols) for any shape. This replaces 'Prelude.last',
      -- which crashed on a rank-0 shape. rows * cols always equals the
      -- product of the shape.
      matrixDims :: [Int] -> (Int, Int)
      matrixDims [] = (1, 1)
      matrixDims [c] = (1, c)
      matrixDims (d : ds) = let (r, c) = matrixDims ds in (d * r, c)

  toList (Array v) = H.toList $ H.flatten v
-- | Matrix product of an @m@ x @k@ array with a @k@ x @n@ array.
--
-- Dimension agreement comes only from the phantom shape indices. The
-- body delegates to hmatrix's @<>@ and does not itself use the 'KnownNat'
-- or 'Ring' constraints.
mmult ::
  forall m n k a.
  ( KnownNat k,
    KnownNat m,
    KnownNat n,
    HasShape [m, n],
    Ring a,
    H.Numeric a
  ) =>
  Array [m, k] a ->
  Array [k, n] a ->
  Array [m, n] a
mmult (Array lhs) (Array rhs) = Array ((H.<>) lhs rhs)
-- 'Actor' for 'Array' is already declared earlier in this module, so an
-- identical re-declaration is not needed here.