module Numeric.EasyTensor
( Tensor ()
, fill
, prod, (%*)
, inverse, transpose
, (<:>), (//), (\\)
, index, indexCol, indexRow, dimN, dimM
, (V..*.), dot, (·), normL1, normL2, normLPInf, normLNInf, normLP
, eye, diag, det, trace
, toDiag, toDiag', fromDiag, fromDiag'
, Mat, Vec, Vec'
, Vec2f, Vec3f, Vec4f
, Vec2f', Vec3f', Vec4f'
, Mat22f, Mat23f, Mat24f
, Mat32f, Mat33f, Mat34f
, Mat42f, Mat43f, Mat44f
, scalar, vec2, vec3, vec4, vec2', vec3', vec4'
, mat22, mat33, mat44
, det2, det2', cross, (×)
) where
import GHC.Base (runRW#)
import GHC.Prim
import GHC.Types
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
import qualified Numeric.Vector as V
import qualified Numeric.Matrix as M
import qualified Numeric.Matrix.Class as M
import Numeric.Commons
-- | User-facing tensor with element type @t@, @n@ rows and @m@ columns.
--
--   The runtime representation is selected by the closed type family 'TT':
--   'Scalar' for 1×1, 'ContraVector' for n×1, 'CoVector' for 1×m and
--   'Matrix' for everything else. Every instance below simply forwards to
--   the corresponding instance of that representation.
newtype Tensor t n m = Tensor { _unT :: TT t n m }

instance Show (TT t n m) => Show (Tensor t n m) where
  show = show . _unT

-- Comparison and the numeric tower.
deriving instance Eq (TT t n m) => Eq (Tensor t n m)
deriving instance Ord (TT t n m) => Ord (Tensor t n m)
deriving instance Num (TT t n m) => Num (Tensor t n m)
deriving instance Fractional (TT t n m) => Fractional (Tensor t n m)
deriving instance Floating (TT t n m) => Floating (Tensor t n m)

-- Linear-algebra classes. Vector operations are provided for both column
-- (n×1) and row (1×m) tensors.
-- NOTE(review): both VectorCalculus instance heads match Tensor t 1 1
-- (n = 1 in the first, m = 1 in the second) and neither is more specific,
-- so GHC will likely reject VectorCalculus at a 1×1 tensor as overlapping —
-- confirm whether that use is intended to be supported.
deriving instance V.VectorCalculus t n (TT t n 1) => V.VectorCalculus t n (Tensor t n 1)
deriving instance V.VectorCalculus t m (TT t 1 m) => V.VectorCalculus t m (Tensor t 1 m)
deriving instance M.MatrixCalculus t n m (TT t n m) => M.MatrixCalculus t n m (Tensor t n m)
deriving instance M.SquareMatrixCalculus t n (TT t n n) => M.SquareMatrixCalculus t n (Tensor t n n)
deriving instance M.MatrixInverse (TT t n n) => M.MatrixInverse (Tensor t n n)

-- Raw byte-level access.
deriving instance PrimBytes (TT t n m) => PrimBytes (Tensor t n m)
deriving instance FloatBytes (TT t n m) => FloatBytes (Tensor t n m)
deriving instance DoubleBytes (TT t n m) => DoubleBytes (Tensor t n m)
deriving instance IntBytes (TT t n m) => IntBytes (Tensor t n m)
deriving instance WordBytes (TT t n m) => WordBytes (Tensor t n m)
-- | Representation of a 1×1 tensor: a single element of type @t@.
--   Numeric and byte-level instances are inherited from @t@ via newtype
--   deriving.
newtype Scalar t = Scalar { _unScalar :: t }
  deriving ( Eq, Ord, Bounded, Enum, Read
           , Num, Real, Integral, Fractional, Floating, RealFrac, RealFloat
           , PrimBytes, FloatBytes, DoubleBytes, IntBytes, WordBytes
           )

-- | Scalars are rendered wrapped in braces, e.g. @{ 3.0 }@.
instance Show t => Show (Scalar t) where
  show s = "{ " ++ show (_unScalar s) ++ " }"
-- | Representation of a row vector (1×n tensor), backed by 'V.Vector'.
--   Every instance forwards to the wrapped vector.
newtype CoVector t n = CoVector {_unCoVec :: V.Vector t n}

-- | Shown exactly like the wrapped vector.
instance Show (V.Vector t n) => Show (CoVector t n) where
  show = show . _unCoVec

-- Comparison and the numeric tower.
deriving instance Eq (V.Vector t n) => Eq (CoVector t n)
deriving instance Ord (V.Vector t n) => Ord (CoVector t n)
deriving instance Num (V.Vector t n) => Num (CoVector t n)
deriving instance Fractional (V.Vector t n) => Fractional (CoVector t n)
deriving instance Floating (V.Vector t n) => Floating (CoVector t n)

-- Vector operations.
deriving instance V.VectorCalculus t n (V.Vector t n) => V.VectorCalculus t n (CoVector t n)

-- Raw byte-level access.
deriving instance PrimBytes (V.Vector t n) => PrimBytes (CoVector t n)
deriving instance FloatBytes (V.Vector t n) => FloatBytes (CoVector t n)
deriving instance DoubleBytes (V.Vector t n) => DoubleBytes (CoVector t n)
deriving instance IntBytes (V.Vector t n) => IntBytes (CoVector t n)
deriving instance WordBytes (V.Vector t n) => WordBytes (CoVector t n)
-- | Representation of a column vector (n×1 tensor), backed by 'V.Vector'.
--   Every instance forwards to the wrapped vector.
newtype ContraVector t n = ContraVector {_unContraVec :: V.Vector t n}

-- | Shown as the wrapped vector followed by a prime character.
--
--   NOTE(review): elsewhere in this module the prime marks the 1×m row form
--   (@Vec'@ and @vec2'@ produce 'CoVector' values), so the suffix may have
--   been meant for 'CoVector' instead — confirm the intended notation.
instance Show (V.Vector t n) => Show (ContraVector t n) where
  show = (++ "'") . show . _unContraVec

-- Comparison and the numeric tower.
deriving instance Eq (V.Vector t n) => Eq (ContraVector t n)
deriving instance Ord (V.Vector t n) => Ord (ContraVector t n)
deriving instance Num (V.Vector t n) => Num (ContraVector t n)
deriving instance Fractional (V.Vector t n) => Fractional (ContraVector t n)
deriving instance Floating (V.Vector t n) => Floating (ContraVector t n)

-- Vector operations.
deriving instance V.VectorCalculus t n (V.Vector t n) => V.VectorCalculus t n (ContraVector t n)

-- Raw byte-level access.
deriving instance PrimBytes (V.Vector t n) => PrimBytes (ContraVector t n)
deriving instance FloatBytes (V.Vector t n) => FloatBytes (ContraVector t n)
deriving instance DoubleBytes (V.Vector t n) => DoubleBytes (ContraVector t n)
deriving instance IntBytes (V.Vector t n) => IntBytes (ContraVector t n)
deriving instance WordBytes (V.Vector t n) => WordBytes (ContraVector t n)
-- | Representation of a general n×m tensor, backed by 'M.Matrix'.
--   Every instance forwards to the wrapped matrix.
newtype Matrix t n m = Matrix (M.Matrix t n m)

-- | Shown exactly like the wrapped matrix.
instance Show (M.Matrix t n m) => Show (Matrix t n m) where
  show wrapped = case wrapped of
    Matrix inner -> show inner

-- Comparison and the numeric tower.
deriving instance Eq (M.Matrix t n m) => Eq (Matrix t n m)
deriving instance Ord (M.Matrix t n m) => Ord (Matrix t n m)
deriving instance Num (M.Matrix t n m) => Num (Matrix t n m)
deriving instance Fractional (M.Matrix t n m) => Fractional (Matrix t n m)
deriving instance Floating (M.Matrix t n m) => Floating (Matrix t n m)

-- Matrix operations.
deriving instance M.MatrixCalculus t n m (M.Matrix t n m) => M.MatrixCalculus t n m (Matrix t n m)
deriving instance M.SquareMatrixCalculus t n (M.Matrix t n n) => M.SquareMatrixCalculus t n (Matrix t n n)
deriving instance M.MatrixInverse (M.Matrix t n n) => M.MatrixInverse (Matrix t n n)

-- Raw byte-level access.
deriving instance PrimBytes (M.Matrix t n m) => PrimBytes (Matrix t n m)
deriving instance FloatBytes (M.Matrix t n m) => FloatBytes (Matrix t n m)
deriving instance DoubleBytes (M.Matrix t n m) => DoubleBytes (Matrix t n m)
deriving instance IntBytes (M.Matrix t n m) => IntBytes (Matrix t n m)
deriving instance WordBytes (M.Matrix t n m) => WordBytes (Matrix t n m)
-- | Chooses the concrete representation of an @n × m@ tensor of @t@.
--
--   This is a closed type family, so the equations are tried top to bottom:
--   the 1×1 case must come before the two vector cases (which would both
--   match it), and the general 'Matrix' case must come last.
--   The injectivity annotation (@v -> t n m@) declares that @t@, @n@ and @m@
--   can be recovered from the representation type.
type family TT t (n :: Nat) (m :: Nat) = v | v -> t n m where
  TT t 1 1 = Scalar t
  TT t n 1 = ContraVector t n
  TT t 1 m = CoVector t m
  TT t n m = Matrix t n m
-- | Build an @n × m@ tensor in which every element is the given scalar,
--   using 'M.broadcastMat'.
fill :: M.MatrixCalculus t n m (Tensor t n m) => Tensor t 1 1 -> Tensor t n m
fill (Tensor (Scalar x)) = M.broadcastMat x
-- | Read the element at the given (row, column) position as a 1×1 tensor.
--   Indexing follows 'M.indexMat'; the 1×1 instance in this module accepts
--   only index 1, which suggests 1-based indices (confirm).
index :: M.MatrixCalculus t n m (Tensor t n m) => Int -> Int -> Tensor t n m -> Tensor t 1 1
index row col x = Tensor (Scalar (M.indexMat row col x))
-- | Extract one column of a tensor as an @n × 1@ column vector
--   (delegates to 'M.indexCol').
indexCol :: ( M.MatrixCalculus t n m (Tensor t n m)
            , V.VectorCalculus t n (Tensor t n 1)
            , PrimBytes (Tensor t n 1)
            )
         => Int -> Tensor t n m -> Tensor t n 1
indexCol col x = M.indexCol col x

-- | Extract one row of a tensor as a @1 × m@ row vector
--   (delegates to 'M.indexRow').
indexRow :: ( M.MatrixCalculus t n m (Tensor t n m)
            , V.VectorCalculus t m (Tensor t 1 m)
            , PrimBytes (Tensor t 1 m)
            )
         => Int -> Tensor t n m -> Tensor t 1 m
indexRow row x = M.indexRow row x
-- | Transpose an @n × m@ tensor into an @m × n@ tensor
--   (delegates to 'M.transpose').
transpose :: ( M.MatrixCalculus t n m (Tensor t n m)
             , M.MatrixCalculus t m n (Tensor t m n)
             , PrimBytes (Tensor t m n)
             )
          => Tensor t n m -> Tensor t m n
transpose x = M.transpose x
-- | Number of rows of a tensor, as reported by 'M.dimN'.
dimN :: M.MatrixCalculus t n m (Tensor t n m) => Tensor t n m -> Int
dimN x = M.dimN x

-- | Number of columns of a tensor, as reported by 'M.dimM'.
dimM :: M.MatrixCalculus t n m (Tensor t n m) => Tensor t n m -> Int
dimM x = M.dimM x
infixl 7 %*
-- | Operator form of 'prod', with the same precedence as '*'.
(%*) :: (M.MatrixProduct (Tensor t n m) (Tensor t m k) (Tensor t n k))
     => Tensor t n m -> Tensor t m k -> Tensor t n k
a %* b = prod a b

-- | Matrix product of an @n × m@ and an @m × k@ tensor
--   (delegates to 'M.prod').
prod :: (M.MatrixProduct (Tensor t n m) (Tensor t m k) (Tensor t n k))
     => Tensor t n m -> Tensor t m k -> Tensor t n k
prod a b = M.prod a b
-- | Right division: multiply the first argument by the inverse of the
--   second (square) argument.
(//) :: ( M.MatrixProduct (Tensor t n m) (Tensor t m m) (Tensor t n m)
        , M.MatrixInverse (TT t m m))
     => Tensor t n m -> Tensor t m m -> Tensor t n m
a // b = a `prod` inverse b

-- | Left division: multiply the inverse of the first (square) argument by
--   the second argument.
(\\) :: ( M.MatrixProduct (Tensor t n n) (Tensor t n m) (Tensor t n m)
        , M.MatrixInverse (TT t n n))
     => Tensor t n n -> Tensor t n m -> Tensor t n m
a \\ b = inverse a `prod` b
-- | Inverse of a square tensor (delegates to 'M.inverse').
inverse :: M.MatrixInverse (Tensor t n n) => Tensor t n n -> Tensor t n n
inverse x = M.inverse x
-- | Single-precision matrix product. Delegates to 'M.prodF', passing the
--   rows of @x@, the shared inner dimension and the columns of @y@ unboxed.
instance ( FloatBytes (Tensor Float n m)
         , FloatBytes (Tensor Float m k)
         , PrimBytes (Tensor Float n k)
         , M.MatrixCalculus Float n m (Tensor Float n m)
         , M.MatrixCalculus Float m k (Tensor Float m k)
         )
      => M.MatrixProduct (Tensor Float n m) (Tensor Float m k) (Tensor Float n k) where
  prod x y = case dimN x of
    I# rows -> case dimM x of
      I# inner -> case dimM y of
        I# cols -> M.prodF rows inner cols x y

-- | Double-precision matrix product. Delegates to 'M.prodD', passing the
--   rows of @x@, the shared inner dimension and the columns of @y@ unboxed.
instance ( DoubleBytes (Tensor Double n m)
         , DoubleBytes (Tensor Double m k)
         , PrimBytes (Tensor Double n k)
         , M.MatrixCalculus Double n m (Tensor Double n m)
         , M.MatrixCalculus Double m k (Tensor Double m k)
         )
      => M.MatrixProduct (Tensor Double n m) (Tensor Double m k) (Tensor Double n k) where
  prod x y = case dimN x of
    I# rows -> case dimM x of
      I# inner -> case dimM y of
        I# cols -> M.prodD rows inner cols x y
-- | Horizontal concatenation: a @k × n@ and a @k × m@ tensor become a
--   @k × (n + m)@ tensor.
--
--   Implemented by copying the raw bytes of @a@ followed by the raw bytes of
--   @b@ into a freshly allocated byte array. This yields a column-wise
--   concatenation only if tensors are stored column-major — presumably the
--   case here, TODO confirm. It also assumes each 'toBytes' array holds the
--   tensor's contents starting at offset 0 — NOTE(review): confirm.
(<:>) :: ( PrimBytes (Tensor t k n)
         , PrimBytes (Tensor t k m)
         , PrimBytes (Tensor t k (n + m))
         )
      => Tensor t k n -> Tensor t k m -> Tensor t k (n + m)
a <:> b = case (# toBytes a, toBytes b, byteSize a, byteSize b #) of
  -- Here n and m are byte counts (Int#), not the type-level dimensions.
  (# arr1, arr2, n, m #) -> case runRW#
      ( \s0 -> case newByteArray# (n +# m) s0 of
          -- Copy a's bytes to the front, then b's bytes right after them.
          (# s1, marr #) -> case copyByteArray# arr1 0# marr 0# n s1 of
            s2 -> case copyByteArray# arr2 0# marr n m s2 of
              -- The buffer is never written again, so freezing in place is safe.
              s3 -> unsafeFreezeByteArray# marr s3
      ) of (# _, r #) -> fromBytes r
infixl 5 <:>
-- | General tensor alias: @n@ rows, @m@ columns.
type Mat t n m = Tensor t n m
-- | Column vector (@n × 1@).
type Vec t n = Tensor t n 1
-- | Row vector (@1 × m@); in this module a trailing prime marks the row form.
type Vec' t m = Tensor t 1 m
-- Single-precision column vectors.
type Vec2f = Tensor Float 2 1
type Vec3f = Tensor Float 3 1
type Vec4f = Tensor Float 4 1
-- Single-precision row vectors.
type Vec2f' = Tensor Float 1 2
type Vec3f' = Tensor Float 1 3
type Vec4f' = Tensor Float 1 4
-- Single-precision matrices, named Mat, then rows, then columns, then f
-- (e.g. Mat32f is 3 rows by 2 columns).
type Mat22f = Tensor Float 2 2
type Mat32f = Tensor Float 3 2
type Mat42f = Tensor Float 4 2
type Mat23f = Tensor Float 2 3
type Mat33f = Tensor Float 3 3
type Mat43f = Tensor Float 4 3
type Mat24f = Tensor Float 2 4
type Mat34f = Tensor Float 3 4
type Mat44f = Tensor Float 4 4
-- | Wrap a single element as a 1×1 tensor.
scalar :: t -> Tensor t 1 1
scalar x = Tensor (Scalar x)
-- | 2-element column vector (2×1).
vec2 :: V.Vector2D t => t -> t -> Tensor t 2 1
vec2 x y = Tensor (ContraVector (V.vec2 x y))

-- | 2-element row vector (1×2).
vec2' :: V.Vector2D t => t -> t -> Tensor t 1 2
vec2' x y = Tensor (CoVector (V.vec2 x y))

-- | 3-element column vector (3×1).
vec3 :: V.Vector3D t => t -> t -> t -> Tensor t 3 1
vec3 x y z = Tensor (ContraVector (V.vec3 x y z))

-- | 3-element row vector (1×3).
vec3' :: V.Vector3D t => t -> t -> t -> Tensor t 1 3
vec3' x y z = Tensor (CoVector (V.vec3 x y z))

-- | 4-element column vector (4×1).
vec4 :: V.Vector4D t => t -> t -> t -> t -> Tensor t 4 1
vec4 x y z w = Tensor (ContraVector (V.vec4 x y z w))

-- | 4-element row vector (1×4).
vec4' :: V.Vector4D t => t -> t -> t -> t -> Tensor t 1 4
vec4' x y z w = Tensor (CoVector (V.vec4 x y z w))
-- | Build a 2×2 tensor from two 2×1 vectors via 'M.mat22'
--   (presumably as columns — confirm against 'M.mat22').
mat22 :: M.Matrix2x2 t => Tensor t 2 1 -> Tensor t 2 1 -> Tensor t 2 2
mat22 (Tensor (ContraVector first)) (Tensor (ContraVector second)) =
  Tensor (Matrix (M.mat22 first second))
-- | Build a 3×3 tensor by concatenating three 3×1 vectors with '<:>',
--   left to right (so @a@ presumably becomes the first column).
mat33 :: ( PrimBytes (Tensor t 3 3)
         , PrimBytes (Tensor t 3 2)
         , PrimBytes (Tensor t 3 1)
         )
      => Tensor t 3 1 -> Tensor t 3 1 -> Tensor t 3 1 -> Tensor t 3 3
mat33 a b c = (a <:> b) <:> c

-- | Build a 4×4 tensor by concatenating four 4×1 vectors with '<:>',
--   left to right (so @a@ presumably becomes the first column).
mat44 :: ( PrimBytes (Tensor t 4 4)
         , PrimBytes (Tensor t 4 3)
         , PrimBytes (Tensor t 4 2)
         , PrimBytes (Tensor t 4 1)
         )
      => Tensor t 4 1 -> Tensor t 4 1 -> Tensor t 4 1 -> Tensor t 4 1 -> Tensor t 4 4
mat44 a b c d = ((a <:> b) <:> c) <:> d
-- | Determinant formed by two 2-element column vectors (via 'V.det2').
det2 :: V.Vector2D t => Tensor t 2 1 -> Tensor t 2 1 -> Tensor t 1 1
det2 (Tensor (ContraVector u)) (Tensor (ContraVector v)) =
  Tensor (Scalar (V.det2 u v))

-- | Determinant formed by two 2-element row vectors (via 'V.det2').
det2' :: V.Vector2D t => Tensor t 1 2 -> Tensor t 1 2 -> Tensor t 1 1
det2' (Tensor (CoVector u)) (Tensor (CoVector v)) =
  Tensor (Scalar (V.det2 u v))
infixl 7 ×
-- | Operator form of 'cross'.
(×) :: V.Vector3D t => Tensor t 3 1 -> Tensor t 3 1 -> Tensor t 3 1
u × v = cross u v

-- | Cross product of two 3-element column vectors (via 'V.cross').
cross :: V.Vector3D t => Tensor t 3 1 -> Tensor t 3 1 -> Tensor t 3 1
cross (Tensor (ContraVector u)) (Tensor (ContraVector v)) =
  Tensor (ContraVector (V.cross u v))
infixl 7 ·
-- | Operator form of 'dot'.
(·) :: V.VectorCalculus t n v => v -> v -> Tensor t 1 1
u · v = dot u v

-- | Dot product of two vectors (via 'V.dot'), returned as a 1×1 tensor.
dot :: V.VectorCalculus t n v => v -> v -> Tensor t 1 1
dot u v = Tensor (Scalar (V.dot u v))
-- | L1 norm of a vector (via 'V.normL1'), as a 1×1 tensor.
normL1 :: V.VectorCalculus t n v => v -> Tensor t 1 1
normL1 v = Tensor (Scalar (V.normL1 v))

-- | L2 norm of a vector (via 'V.normL2'), as a 1×1 tensor.
normL2 :: V.VectorCalculus t n v => v -> Tensor t 1 1
normL2 v = Tensor (Scalar (V.normL2 v))

-- | Norm in the limit p → +∞ (via 'V.normLPInf'), as a 1×1 tensor.
normLPInf :: V.VectorCalculus t n v => v -> Tensor t 1 1
normLPInf v = Tensor (Scalar (V.normLPInf v))

-- | Norm in the limit p → −∞ (via 'V.normLNInf'), as a 1×1 tensor.
normLNInf :: V.VectorCalculus t n v => v -> Tensor t 1 1
normLNInf v = Tensor (Scalar (V.normLNInf v))

-- | Lp norm for the given @p@ (via 'V.normLP'), as a 1×1 tensor.
normLP :: V.VectorCalculus t n v => Int -> v -> Tensor t 1 1
normLP p v = Tensor (Scalar (V.normLP p v))
-- | Identity-like square tensor as provided by 'M.eye'.
eye :: M.SquareMatrixCalculus t n (Tensor t n n) => Tensor t n n
eye = M.eye

-- | Square tensor built from a scalar via 'M.diag'.
diag :: M.SquareMatrixCalculus t n (Tensor t n n) => Tensor t 1 1 -> Tensor t n n
diag (Tensor (Scalar x)) = M.diag x

-- | Determinant of a square tensor (via 'M.det'), as a 1×1 tensor.
det :: M.SquareMatrixCalculus t n (Tensor t n n) => Tensor t n n -> Tensor t 1 1
det x = Tensor (Scalar (M.det x))

-- | Trace of a square tensor (via 'M.trace'), as a 1×1 tensor.
trace :: M.SquareMatrixCalculus t n (Tensor t n n) => Tensor t n n -> Tensor t 1 1
trace x = Tensor (Scalar (M.trace x))
-- | Convert a square tensor to an @n × 1@ column vector via 'M.fromDiag'
--   (by its name, the diagonal).
fromDiag :: ( M.SquareMatrixCalculus t n (Tensor t n n)
            , V.VectorCalculus t n (Tensor t n 1)
            , PrimBytes (Tensor t n 1)
            )
         => Tensor t n n -> Tensor t n 1
fromDiag x = M.fromDiag x

-- | Like 'fromDiag', but produces a @1 × n@ row vector.
fromDiag' :: ( M.SquareMatrixCalculus t n (Tensor t n n)
             , V.VectorCalculus t n (Tensor t 1 n)
             , PrimBytes (Tensor t 1 n)
             )
          => Tensor t n n -> Tensor t 1 n
fromDiag' x = M.fromDiag x

-- | Build a square tensor from an @n × 1@ column vector via 'M.toDiag'
--   (by its name, placing the entries on the diagonal).
toDiag :: ( M.SquareMatrixCalculus t n (Tensor t n n)
          , V.VectorCalculus t n (Tensor t n 1)
          , PrimBytes (Tensor t n 1)
          )
       => Tensor t n 1 -> Tensor t n n
toDiag v = M.toDiag v

-- | Like 'toDiag', but takes a @1 × n@ row vector.
toDiag' :: ( M.SquareMatrixCalculus t n (Tensor t n n)
           , V.VectorCalculus t n (Tensor t 1 n)
           , PrimBytes (Tensor t 1 n)
           )
        => Tensor t 1 n -> Tensor t n n
toDiag' v = M.toDiag v
-- | A 'Scalar' viewed as a one-element vector.
--
--   Every norm of a one-element vector reduces to the absolute value of
--   that element, so all norm methods apply 'abs'.
instance Num t => V.VectorCalculus t 1 (Scalar t) where
  broadcastVec = Scalar
  (.*.) = (*)
  dot a = _unScalar . (a *)
  -- Only index 1 exists in a one-element vector.
  indexVec 1 = _unScalar
  indexVec i = const . error $ "Bad index " ++ show i ++ " for a scalar"
  normL1 = _unScalar . abs
  normL2 = _unScalar . abs
  -- The +∞ and −∞ limit norms of a single element are both |x|. Taking
  -- 'abs' keeps them non-negative and consistent with the other norms;
  -- returning the raw value would give a negative "norm" for negative x.
  normLPInf = _unScalar . abs
  normLNInf = _unScalar . abs
  normLP _ = _unScalar . abs
  dim _ = 1
-- | A 'Scalar' viewed as a 1×1 matrix. Only position (1, 1) is valid; any
--   other index raises an error. Transposing or taking the single row or
--   column reinterprets the value in place via 'unsafeCoerce'.
instance Num t => M.MatrixCalculus t 1 1 (Scalar t) where
  broadcastMat = Scalar
  dimN _ = 1
  dimM _ = 1
  indexMat i j s
    | i == 1 && j == 1 = _unScalar s
    | otherwise = error $ "Bad index (" ++ show i ++ ", " ++ show j ++ ") for a scalar"
  transpose = unsafeCoerce
  indexCol j s
    | j == 1 = unsafeCoerce s
    | otherwise = error $ "Bad column index " ++ show j ++ " for a scalar"
  indexRow i s
    | i == 1 = unsafeCoerce s
    | otherwise = error $ "Bad row index " ++ show i ++ " for a scalar"
-- | A 'Scalar' as a 1×1 square matrix: the identity is 1, and both the
--   determinant and the trace are the element itself. Diagonal conversions
--   reinterpret the value in place via 'unsafeCoerce'.
instance Num t => M.SquareMatrixCalculus t 1 (Scalar t) where
  eye = Scalar 1
  diag = Scalar
  det (Scalar x) = x
  trace (Scalar x) = x
  toDiag = unsafeCoerce
  fromDiag = unsafeCoerce
-- | The inverse of a 1×1 matrix is the reciprocal of its element.
instance Fractional t => M.MatrixInverse (Scalar t) where
  inverse (Scalar x) = Scalar (recip x)
-- | A 'ContraVector' viewed as an n×1 matrix: rows index the vector and the
--   only valid column index is 1. Transpose and row/column extraction
--   reinterpret the result via 'unsafeCoerce'.
instance (KnownNat n, V.VectorCalculus t n (V.Vector t n)) => M.MatrixCalculus t n 1 (ContraVector t n) where
  broadcastMat x = ContraVector (V.broadcastVec x)
  dimN (ContraVector v) = V.dim v
  dimM _ = 1
  indexMat i j (ContraVector v)
    | j == 1 = V.indexVec i v
    | otherwise = error $ "Bad index (" ++ show i ++ ", " ++ show j ++ ") for a " ++ show (V.dim v) ++ "x1D matrix"
  transpose (ContraVector v) = unsafeCoerce (CoVector v)
  indexCol j (ContraVector v)
    | j == 1 = unsafeCoerce (ContraVector v)
    | otherwise = error $ "Bad column index " ++ show j ++ " for a " ++ show (V.dim v) ++ "x1D matrix"
  -- Each row of an n×1 matrix is a single element.
  indexRow i (ContraVector v) = unsafeCoerce (Scalar (V.indexVec i v))
-- | A 'CoVector' viewed as a 1×m matrix: columns index the vector and the
--   only valid row index is 1. Transpose and row/column extraction
--   reinterpret the result via 'unsafeCoerce'.
instance (KnownNat m, V.VectorCalculus t m (V.Vector t m)) => M.MatrixCalculus t 1 m (CoVector t m) where
  broadcastMat = CoVector . V.broadcastVec
  indexMat 1 c (CoVector v) = V.indexVec c v
  indexMat r c (CoVector v) = error $ "Bad index (" ++ show r ++ ", " ++ show c ++ ") for a 1x" ++ show (V.dim v) ++ "D matrix"
  transpose (CoVector v) = unsafeCoerce $ ContraVector v
  dimN _ = 1
  dimM = V.dim . _unCoVec
  -- Each column of a 1×m matrix is a single element.
  indexCol c = unsafeCoerce . Scalar . V.indexVec c . _unCoVec
  indexRow 1 x = unsafeCoerce x
  -- This is the row accessor, so the failure must report a bad *row*
  -- index (previously it said "column", which misdirected debugging).
  indexRow r (CoVector v) = error $ "Bad row index " ++ show r ++ " for a 1x" ++ show (V.dim v) ++ "D matrix"