{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures             #-}
{-# LANGUAGE MagicHash                  #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE StandaloneDeriving         #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE TypeFamilyDependencies     #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE UnboxedTuples              #-}
{-# LANGUAGE UndecidableInstances       #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.EasyTensor
-- Copyright   :  (c) Artem Chirkin
-- License     :  MIT
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
-- This module generalizes matrices and vectors.
-- It is limited to rank 2, which keeps the interface simple and lets the
-- type checker verify dimensions.
--
-- A @'Tensor' t n m@ has @n@ rows and @m@ columns. Its runtime
-- representation is chosen by the 'TT' type family: a scalar (1x1),
-- a column vector (nx1), a row vector (1xm) or a full matrix.
--
-----------------------------------------------------------------------------

module Numeric.EasyTensor
  ( Tensor ()
    -- * Common operations
  , fill
  , prod, (%*)
  , inverse, transpose
  , (<:>), (//), (\\)
  , index, indexCol, indexRow, dimN, dimM
  , (V..*.), dot, (·), normL1, normL2, normLPInf, normLNInf, normLP
  , eye, diag, det, trace
  , toDiag, toDiag', fromDiag, fromDiag'
    -- * Type abbreviations
  , Mat, Vec, Vec'
  , Vec2f, Vec3f, Vec4f
  , Vec2f', Vec3f', Vec4f'
  , Mat22f, Mat23f, Mat24f
  , Mat32f, Mat33f, Mat34f
  , Mat42f, Mat43f, Mat44f
    -- * Simplified type constructors
  , scalar, vec2, vec3, vec4, vec2', vec3', vec4'
  , mat22, mat33, mat44
    -- * Some low-dimensional operations
  , det2, det2', cross, (×)
  ) where

import GHC.Base (runRW#)
import GHC.Prim
import GHC.Types
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

import qualified Numeric.Vector as V
import qualified Numeric.Matrix as M
import qualified Numeric.Matrix.Class as M
import Numeric.Commons

-----------------------------------------------------------------------------
-- The tensor type
-----------------------------------------------------------------------------

-- | A rank-2 tensor with element type @t@, @n@ rows and @m@ columns.
--   It is a thin wrapper over the representation picked by 'TT'.
--   All class instances are forwarded to that representation.
newtype Tensor t n m = Tensor { _unT :: TT t n m }

instance Show (TT t n m) => Show (Tensor t n m) where
  show (Tensor t) = show t

deriving instance Eq (TT t n m) => Eq (Tensor t n m)
deriving instance Ord (TT t n m) => Ord (Tensor t n m)
deriving instance Num (TT t n m) => Num (Tensor t n m)
deriving instance Fractional (TT t n m) => Fractional (Tensor t n m)
deriving instance Floating (TT t n m) => Floating (Tensor t n m)
deriving instance V.VectorCalculus t n (TT t n 1) => V.VectorCalculus t n (Tensor t n 1)
deriving instance V.VectorCalculus t m (TT t 1 m) => V.VectorCalculus t m (Tensor t 1 m)
deriving instance M.MatrixCalculus t n m (TT t n m) => M.MatrixCalculus t n m (Tensor t n m)
deriving instance M.SquareMatrixCalculus t n (TT t n n) => M.SquareMatrixCalculus t n (Tensor t n n)
deriving instance M.MatrixInverse (TT t n n) => M.MatrixInverse (Tensor t n n)
deriving instance PrimBytes (TT t n m) => PrimBytes (Tensor t n m)
deriving instance FloatBytes (TT t n m) => FloatBytes (Tensor t n m)
deriving instance DoubleBytes (TT t n m) => DoubleBytes (Tensor t n m)
deriving instance IntBytes (TT t n m) => IntBytes (Tensor t n m)
deriving instance WordBytes (TT t n m) => WordBytes (Tensor t n m)

-- | Representation of a 1x1 tensor: a plain value in a newtype.
newtype Scalar t = Scalar { _unScalar :: t }
  deriving ( Bounded, Enum, Eq, Integral, Num, Fractional, Floating, Ord, Read, Real
           , RealFrac, RealFloat
           , PrimBytes, FloatBytes, DoubleBytes, IntBytes, WordBytes)

instance Show t => Show (Scalar t) where
  show (Scalar t) = "{ " ++ show t ++ " }"

-- | Representation of a 1xm tensor (row vector).
newtype CoVector t n = CoVector {_unCoVec :: V.Vector t n}

instance Show (V.Vector t n) => Show (CoVector t n) where
  show (CoVector t) = show t

deriving instance Eq (V.Vector t n) => Eq (CoVector t n)
deriving instance Ord (V.Vector t n) => Ord (CoVector t n)
deriving instance Num (V.Vector t n) => Num (CoVector t n)
deriving instance Fractional (V.Vector t n) => Fractional (CoVector t n)
deriving instance Floating (V.Vector t n) => Floating (CoVector t n)
deriving instance PrimBytes (V.Vector t n) => PrimBytes (CoVector t n)
deriving instance V.VectorCalculus t n (V.Vector t n) => V.VectorCalculus t n (CoVector t n)
deriving instance FloatBytes (V.Vector t n) => FloatBytes (CoVector t n)
deriving instance DoubleBytes (V.Vector t n) => DoubleBytes (CoVector t n)
deriving instance IntBytes (V.Vector t n) => IntBytes (CoVector t n)
deriving instance WordBytes (V.Vector t n) => WordBytes (CoVector t n)

-- | Representation of an nx1 tensor (column vector).
--   It is shown with a trailing @'@.
newtype ContraVector t n = ContraVector {_unContraVec :: V.Vector t n}

instance Show (V.Vector t n) => Show (ContraVector t n) where
  show (ContraVector t) = show t ++ "'"

deriving instance Eq (V.Vector t n) => Eq (ContraVector t n)
deriving instance Ord (V.Vector t n) => Ord (ContraVector t n)
deriving instance Num (V.Vector t n) => Num (ContraVector t n)
deriving instance Fractional (V.Vector t n) => Fractional (ContraVector t n)
deriving instance Floating (V.Vector t n) => Floating (ContraVector t n)
deriving instance PrimBytes (V.Vector t n) => PrimBytes (ContraVector t n)
deriving instance V.VectorCalculus t n (V.Vector t n) => V.VectorCalculus t n (ContraVector t n)
deriving instance FloatBytes (V.Vector t n) => FloatBytes (ContraVector t n)
deriving instance DoubleBytes (V.Vector t n) => DoubleBytes (ContraVector t n)
deriving instance IntBytes (V.Vector t n) => IntBytes (ContraVector t n)
deriving instance WordBytes (V.Vector t n) => WordBytes (ContraVector t n)

-- | Representation of a general nxm tensor.
newtype Matrix t n m = Matrix (M.Matrix t n m)

instance Show (M.Matrix t n m) => Show (Matrix t n m) where
  show (Matrix t) = show t

deriving instance Eq (M.Matrix t n m) => Eq (Matrix t n m)
deriving instance Ord (M.Matrix t n m) => Ord (Matrix t n m)
deriving instance Num (M.Matrix t n m) => Num (Matrix t n m)
deriving instance Fractional (M.Matrix t n m) => Fractional (Matrix t n m)
deriving instance Floating (M.Matrix t n m) => Floating (Matrix t n m)
deriving instance PrimBytes (M.Matrix t n m) => PrimBytes (Matrix t n m)
deriving instance M.MatrixCalculus t n m (M.Matrix t n m) => M.MatrixCalculus t n m (Matrix t n m)
deriving instance M.SquareMatrixCalculus t n (M.Matrix t n n) => M.SquareMatrixCalculus t n (Matrix t n n)
deriving instance M.MatrixInverse (M.Matrix t n n) => M.MatrixInverse (Matrix t n n)
deriving instance FloatBytes (M.Matrix t n m) => FloatBytes (Matrix t n m)
deriving instance DoubleBytes (M.Matrix t n m) => DoubleBytes (Matrix t n m)
deriving instance IntBytes (M.Matrix t n m) => IntBytes (Matrix t n m)
deriving instance WordBytes (M.Matrix t n m) => WordBytes (Matrix t n m)

-- | Selects the runtime representation of a tensor from its dimensions.
--   The equations of this closed family are tried top to bottom, so 1x1
--   maps to 'Scalar' before the vector cases are considered. The family is
--   injective (@v -> t n m@), so the dimensions can be recovered from the
--   representation.
type family TT t (n :: Nat) (m :: Nat) = v | v -> t n m where
  TT t 1 1 = Scalar t
  TT t n 1 = ContraVector t n
  TT t 1 m = CoVector t m
  TT t n m = Matrix t n m

-----------------------------------------------------------------------------
-- Common operations
-----------------------------------------------------------------------------

-- | Fill whole tensor with a single value
fill :: M.MatrixCalculus t n m (Tensor t n m) => Tensor t 1 1 -> Tensor t n m
fill = M.broadcastMat . _unScalar . _unT
{-# INLINE fill #-}

-- | Get an element of a tensor by row index and column index.
--   Indexing appears to be 1-based: the instances for singleton dimensions
--   at the bottom of this module accept only index 1.
index :: M.MatrixCalculus t n m (Tensor t n m) => Int -> Int -> Tensor t n m -> Tensor t 1 1
index i j = Tensor . Scalar . M.indexMat i j
{-# INLINE index #-}

-- | Get a column vector of a matrix
indexCol :: ( M.MatrixCalculus t n m (Tensor t n m)
            , V.VectorCalculus t n (Tensor t n 1)
            , PrimBytes (Tensor t n 1)
            ) => Int -> Tensor t n m -> Tensor t n 1
indexCol = M.indexCol
{-# INLINE indexCol #-}

-- | Get a row vector of a matrix
indexRow :: ( M.MatrixCalculus t n m (Tensor t n m)
            , V.VectorCalculus t m (Tensor t 1 m)
            , PrimBytes (Tensor t 1 m)
            ) => Int -> Tensor t n m -> Tensor t 1 m
indexRow = M.indexRow
{-# INLINE indexRow #-}

-- | Transpose a tensor, swapping its row and column dimensions.
transpose :: ( M.MatrixCalculus t n m (Tensor t n m)
             , M.MatrixCalculus t m n (Tensor t m n)
             , PrimBytes (Tensor t m n)
             ) => Tensor t n m -> Tensor t m n
transpose = M.transpose
{-# INLINE transpose #-}

-- | Number of rows (@n@).
dimN :: ( M.MatrixCalculus t n m (Tensor t n m)) => Tensor t n m -> Int
dimN = M.dimN
{-# INLINE dimN #-}

-- | Number of columns (@m@).
dimM :: ( M.MatrixCalculus t n m (Tensor t n m)) => Tensor t n m -> Int
dimM = M.dimM
{-# INLINE dimM #-}

infixl 7 %*
-- | Matrix product for rank-2 tensors; also covers matrix-vector and
--   vector-matrix products. Synonym for 'prod'.
(%*) :: (M.MatrixProduct (Tensor t n m) (Tensor t m k) (Tensor t n k))
     => Tensor t n m -> Tensor t m k -> Tensor t n k
(%*) = prod
{-# INLINE (%*) #-}

-- | Matrix product for rank-2 tensors; also covers matrix-vector and
--   vector-matrix products.
prod :: (M.MatrixProduct (Tensor t n m) (Tensor t m k) (Tensor t n k))
     => Tensor t n m -> Tensor t m k -> Tensor t n k
prod = M.prod
{-# INLINE prod #-}

-- | Divide on the right: R = A * B^(-1)
(//) :: ( M.MatrixProduct (Tensor t n m) (Tensor t m m) (Tensor t n m)
        , M.MatrixInverse (TT t m m))
     => Tensor t n m -> Tensor t m m -> Tensor t n m
(//) a b = prod a (inverse b)
{-# INLINE (//) #-}

-- | Divide on the left: R = A^(-1) * B
(\\) :: ( M.MatrixProduct (Tensor t n n) (Tensor t n m) (Tensor t n m)
        , M.MatrixInverse (TT t n n))
     => Tensor t n n -> Tensor t n m -> Tensor t n m
(\\) a b = prod (inverse a) b
{-# INLINE (\\) #-}

-- | Matrix inverse
inverse :: M.MatrixInverse (Tensor t n n) => Tensor t n n -> Tensor t n n
inverse = M.inverse
{-# INLINE inverse #-}

-- | 'Float' matrix product: unboxes the three dimensions and hands the
--   raw tensors to 'M.prodF'.
instance ( FloatBytes (Tensor Float n m)
         , FloatBytes (Tensor Float m k)
         , PrimBytes (Tensor Float n k)
         , M.MatrixCalculus Float n m (Tensor Float n m)
         , M.MatrixCalculus Float m k (Tensor Float m k)
         ) => M.MatrixProduct (Tensor Float n m) (Tensor Float m k) (Tensor Float n k) where
  prod x y = case (dimN x, dimM x, dimM y) of
    (I# n, I# m, I# k) -> M.prodF n m k x y
  {-# INLINE prod #-}

-- | 'Double' matrix product: unboxes the three dimensions and hands the
--   raw tensors to 'M.prodD'.
instance ( DoubleBytes (Tensor Double n m)
         , DoubleBytes (Tensor Double m k)
         , PrimBytes (Tensor Double n k)
         , M.MatrixCalculus Double n m (Tensor Double n m)
         , M.MatrixCalculus Double m k (Tensor Double m k)
         ) => M.MatrixProduct (Tensor Double n m) (Tensor Double m k) (Tensor Double n k) where
  prod x y = case (dimN x, dimM x, dimM y) of
    (I# n, I# m, I# k) -> M.prodD n m k x y
  {-# INLINE prod #-}

infixl 5 <:>
-- | Append the columns of one tensor to another with the same number of rows
--   @k@; the result has @n + m@ columns. For column vectors (@n = m = 1@)
--   this builds a @k x 2@ matrix, and chains of '<:>' build larger matrices
--   (see 'mat33', 'mat44').
--
--   It works by concatenating the raw byte arrays of both operands into a
--   fresh array.
--   NOTE(review): this yields column concatenation only if every
--   representation stores whole columns contiguously (column-major);
--   confirm against "Numeric.Matrix".
(<:>) :: ( PrimBytes (Tensor t k n)
         , PrimBytes (Tensor t k m)
         , PrimBytes (Tensor t k (n + m))
         )
      => Tensor t k n -> Tensor t k m -> Tensor t k (n + m)
a <:> b = case (# toBytes a, toBytes b, byteSize a, byteSize b #) of
  (# arr1, arr2, n, m #) -> case runRW#
     ( \s0 -> case newByteArray# (n +# m) s0 of
         (# s1, marr #) -> case copyByteArray# arr1 0# marr 0# n s1 of
           s2 -> case copyByteArray# arr2 0# marr n m s2 of
             s3 -> unsafeFreezeByteArray# marr s3
     ) of (# _, r #) -> fromBytes r

-----------------------------------------------------------------------------
-- Type abbreviations
-----------------------------------------------------------------------------

-- Simple types
type Mat t n m = Tensor t n m
type Vec t n = Tensor t n 1
type Vec' t m = Tensor t 1 m

-- Even simpler types: aliases specialised to Float
type Vec2f = Tensor Float 2 1
type Vec3f = Tensor Float 3 1
type Vec4f = Tensor Float 4 1
type Vec2f' = Tensor Float 1 2
type Vec3f' = Tensor Float 1 3
type Vec4f' = Tensor Float 1 4
type Mat22f = Tensor Float 2 2
type Mat32f = Tensor Float 3 2
type Mat42f = Tensor Float 4 2
type Mat23f = Tensor Float 2 3
type Mat33f = Tensor Float 3 3
type Mat43f = Tensor Float 4 3
type Mat24f = Tensor Float 2 4
type Mat34f = Tensor Float 3 4
type Mat44f = Tensor Float 4 4

-----------------------------------------------------------------------------
-- Constructors
-----------------------------------------------------------------------------

-- | Wrap a single value as a 1x1 tensor.
scalar :: t -> Tensor t 1 1
scalar = Tensor . Scalar

-- | Build a 2D column vector.
vec2 :: V.Vector2D t => t -> t -> Tensor t 2 1
vec2 a b = Tensor . ContraVector $ V.vec2 a b

-- | Build a 2D row vector.
vec2' :: V.Vector2D t => t -> t -> Tensor t 1 2
vec2' a b = Tensor . CoVector $ V.vec2 a b

-- | Build a 3D column vector.
vec3 :: V.Vector3D t => t -> t -> t -> Tensor t 3 1
vec3 a b c = Tensor . ContraVector $ V.vec3 a b c

-- | Build a 3D row vector.
vec3' :: V.Vector3D t => t -> t -> t -> Tensor t 1 3
vec3' a b c = Tensor . CoVector $ V.vec3 a b c

-- | Build a 4D column vector.
vec4 :: V.Vector4D t => t -> t -> t -> t -> Tensor t 4 1
vec4 a b c d = Tensor . ContraVector $ V.vec4 a b c d

-- | Build a 4D row vector.
vec4' :: V.Vector4D t => t -> t -> t -> t -> Tensor t 1 4
vec4' a b c d = Tensor . CoVector $ V.vec4 a b c d

-- | Compose a 2x2D matrix
mat22 :: M.Matrix2x2 t => Tensor t 2 1 -> Tensor t 2 1 -> Tensor t 2 2
mat22 (Tensor (ContraVector a)) (Tensor (ContraVector b)) = Tensor . Matrix $ M.mat22 a b

-- | Compose a 3x3D matrix from three column vectors, in order, via '<:>'.
mat33 :: ( PrimBytes (Tensor t 3 3)
         , PrimBytes (Tensor t 3 2)
         , PrimBytes (Tensor t 3 1)
         )
      => Tensor t 3 1 -> Tensor t 3 1 -> Tensor t 3 1 -> Tensor t 3 3
mat33 a b c = a <:> b <:> c

-- | Compose a 4x4D matrix from four column vectors, in order, via '<:>'.
mat44 :: ( PrimBytes (Tensor t 4 4)
         , PrimBytes (Tensor t 4 3)
         , PrimBytes (Tensor t 4 2)
         , PrimBytes (Tensor t 4 1)
         )
      => Tensor t 4 1 -> Tensor t 4 1 -> Tensor t 4 1 -> Tensor t 4 1 -> Tensor t 4 4
mat44 a b c d = a <:> b <:> c <:> d

-----------------------------------------------------------------------------
-- Useful low-dimensional functions
-----------------------------------------------------------------------------

-- | Determinant of the 2x2 matrix formed by two column vectors (delegates to 'V.det2').
det2 :: V.Vector2D t => Tensor t 2 1 -> Tensor t 2 1 -> Tensor t 1 1
det2 (Tensor (ContraVector a)) (Tensor (ContraVector b)) = Tensor . Scalar $ V.det2 a b

-- | Determinant of the 2x2 matrix formed by two row vectors (delegates to 'V.det2').
det2' :: V.Vector2D t => Tensor t 1 2 -> Tensor t 1 2 -> Tensor t 1 1
det2' (Tensor (CoVector a)) (Tensor (CoVector b)) = Tensor . Scalar $ V.det2 a b

infixl 7 ×
-- | Cross product for two vectors in 3D (synonym for 'cross').
(×) :: V.Vector3D t => Tensor t 3 1 -> Tensor t 3 1 -> Tensor t 3 1
(×) = cross
{-# INLINE (×) #-}

-- | Cross product for two vectors in 3D
cross :: V.Vector3D t => Tensor t 3 1 -> Tensor t 3 1 -> Tensor t 3 1
cross (Tensor (ContraVector a)) (Tensor (ContraVector b)) = Tensor . ContraVector $ V.cross a b
{-# INLINE cross #-}

-----------------------------------------------------------------------------
-- Re-use functions provided by Vector and Matrix Calculus
-----------------------------------------------------------------------------

infixl 7 ·
-- | Dot product of two vectors (synonym for 'dot').
(·) :: V.VectorCalculus t n v => v -> v -> Tensor t 1 1
(·) = dot
{-# INLINE (·) #-}

-- | Dot product of two vectors
dot :: V.VectorCalculus t n v => v -> v -> Tensor t 1 1
dot a b = Tensor . Scalar $ V.dot a b
{-# INLINE dot #-}

-- | Sum of absolute values
normL1 :: V.VectorCalculus t n v => v -> Tensor t 1 1
normL1 = Tensor . Scalar . V.normL1
{-# INLINE normL1 #-}

-- | hypot function (square root of squares)
normL2 :: V.VectorCalculus t n v => v -> Tensor t 1 1
normL2 = Tensor . Scalar . V.normL2
{-# INLINE normL2 #-}

-- | Maximum of absolute values
normLPInf :: V.VectorCalculus t n v => v -> Tensor t 1 1
normLPInf = Tensor . Scalar . V.normLPInf
{-# INLINE normLPInf #-}

-- | Minimum of absolute values
normLNInf :: V.VectorCalculus t n v => v -> Tensor t 1 1
normLNInf = Tensor . Scalar . V.normLNInf
{-# INLINE normLNInf #-}

-- | Norm in Lp space
normLP :: V.VectorCalculus t n v => Int -> v -> Tensor t 1 1
normLP p = Tensor . Scalar . V.normLP p
{-# INLINE normLP #-}

-- | Identity matrix. Mat with 1 on diagonal and 0 elsewhere
eye :: M.SquareMatrixCalculus t n (Tensor t n n) => Tensor t n n
eye = M.eye
{-# INLINE eye #-}

-- | Put the same value on the Mat diagonal, 0 otherwise
diag :: M.SquareMatrixCalculus t n (Tensor t n n) => Tensor t 1 1 -> Tensor t n n
diag = M.diag . _unScalar . _unT
{-# INLINE diag #-}

-- | Determinant of Mat
det :: M.SquareMatrixCalculus t n (Tensor t n n) => Tensor t n n -> Tensor t 1 1
det = Tensor . Scalar . M.det
{-# INLINE det #-}

-- | Sum of diagonal elements
trace :: M.SquareMatrixCalculus t n (Tensor t n n) => Tensor t n n -> Tensor t 1 1
trace = Tensor . Scalar . M.trace
{-# INLINE trace #-}

-- | Get the diagonal elements from Mat into a column Vec
fromDiag :: ( M.SquareMatrixCalculus t n (Tensor t n n)
            , V.VectorCalculus t n (Tensor t n 1)
            , PrimBytes (Tensor t n 1)
            )
         => Tensor t n n -> Tensor t n 1
fromDiag = M.fromDiag
{-# INLINE fromDiag #-}

-- | Get the diagonal elements from Mat into a row Vec'
fromDiag' :: ( M.SquareMatrixCalculus t n (Tensor t n n)
             , V.VectorCalculus t n (Tensor t 1 n)
             , PrimBytes (Tensor t 1 n)
             )
          => Tensor t n n -> Tensor t 1 n
fromDiag' = M.fromDiag
{-# INLINE fromDiag' #-}

-- | Set the values of a column Vec into the diagonal elements of Mat
toDiag :: ( M.SquareMatrixCalculus t n (Tensor t n n)
          , V.VectorCalculus t n (Tensor t n 1)
          , PrimBytes (Tensor t n 1)
          )
       => Tensor t n 1 -> Tensor t n n
toDiag = M.toDiag
{-# INLINE toDiag #-}

-- | Set the values of a row Vec' into the diagonal elements of Mat
toDiag' :: ( M.SquareMatrixCalculus t n (Tensor t n n)
           , V.VectorCalculus t n (Tensor t 1 n)
           , PrimBytes (Tensor t 1 n)
           )
        => Tensor t 1 n -> Tensor t n n
toDiag' = M.toDiag
{-# INLINE toDiag' #-}

-----------------------------------------------------------------------------
-- Missing instances
-----------------------------------------------------------------------------

-- | A scalar viewed as a one-element vector. The only valid index is 1.
instance Num t => V.VectorCalculus t 1 (Scalar t) where
  broadcastVec = Scalar
  {-# INLINE broadcastVec #-}
  (.*.) = (*)
  {-# INLINE (.*.) #-}
  dot a = _unScalar . (a *)
  {-# INLINE dot #-}
  indexVec 1 = _unScalar
  indexVec i = const . error $ "Bad index " ++ show i ++ " for a scalar"
  {-# INLINE indexVec #-}
  normL1 = _unScalar . abs
  {-# INLINE normL1 #-}
  normL2 = _unScalar . abs
  {-# INLINE normL2 #-}
  -- Maximum / minimum of absolute values: for a single element both reduce
  -- to |x|. The previous version returned x itself, which is wrong for x < 0.
  normLPInf = _unScalar . abs
  {-# INLINE normLPInf #-}
  normLNInf = _unScalar . abs
  {-# INLINE normLNInf #-}
  normLP _ = _unScalar . abs
  {-# INLINE normLP #-}
  dim _ = 1
  {-# INLINE dim #-}

-- | A scalar viewed as a 1x1 matrix. The only valid index is (1, 1).
instance Num t => M.MatrixCalculus t 1 1 (Scalar t) where
  broadcastMat = Scalar
  {-# INLINE broadcastMat #-}
  indexMat 1 1 = _unScalar
  indexMat i j = const . error $ "Bad index (" ++ show i ++ ", " ++ show j ++ ") for a scalar"
  {-# INLINE indexMat #-}
  -- A 1x1 matrix is its own transpose; unsafeCoerce only changes the static type.
  -- NOTE(review): relies on the target type being represented like Scalar t.
  transpose = unsafeCoerce
  {-# INLINE transpose #-}
  dimN _ = 1
  {-# INLINE dimN #-}
  dimM _ = 1
  {-# INLINE dimM #-}
  -- The single column/row of a 1x1 matrix is the scalar itself (same
  -- representation assumption as 'transpose' above).
  indexCol 1 = unsafeCoerce
  indexCol j = const . error $ "Bad column index " ++ show j ++ " for a scalar"
  {-# INLINE indexCol #-}
  indexRow 1 = unsafeCoerce
  indexRow i = const . error $ "Bad row index " ++ show i ++ " for a scalar"
  {-# INLINE indexRow #-}

-- | A scalar viewed as a 1x1 square matrix.
instance Num t => M.SquareMatrixCalculus t 1 (Scalar t) where
  eye = Scalar 1
  {-# INLINE eye #-}
  diag = Scalar
  {-# INLINE diag #-}
  det = _unScalar
  {-# INLINE det #-}
  trace = _unScalar
  {-# INLINE trace #-}
  -- The diagonal of a 1x1 matrix is the scalar itself; re-typed via unsafeCoerce.
  fromDiag = unsafeCoerce
  {-# INLINE fromDiag #-}
  toDiag = unsafeCoerce
  {-# INLINE toDiag #-}

-- | Inverse of a 1x1 matrix: the reciprocal of its only element.
instance Fractional t => M.MatrixInverse (Scalar t) where
  inverse = Scalar . recip . _unScalar
  {-# INLINE inverse #-}

-- | A column vector viewed as an nx1 matrix.
instance (KnownNat n, V.VectorCalculus t n (V.Vector t n)) => M.MatrixCalculus t n 1 (ContraVector t n) where
  broadcastMat = ContraVector . V.broadcastVec
  {-# INLINE broadcastMat #-}
  indexMat i 1 (ContraVector v) = V.indexVec i v
  indexMat i j (ContraVector v) = error $ "Bad index (" ++ show i ++ ", " ++ show j ++ ") for a " ++ show (V.dim v) ++ "x1D matrix"
  {-# INLINE indexMat #-}
  -- Transposing a column vector re-wraps the same storage as a row vector.
  transpose (ContraVector v) = unsafeCoerce $ CoVector v
  {-# INLINE transpose #-}
  dimN = V.dim . _unContraVec
  {-# INLINE dimN #-}
  dimM _ = 1
  {-# INLINE dimM #-}
  -- The only column is the vector itself.
  indexCol 1 (ContraVector v) = unsafeCoerce $ ContraVector v
  indexCol j (ContraVector v) = error $ "Bad column index " ++ show j ++ " for a " ++ show (V.dim v) ++ "x1D matrix"
  {-# INLINE indexCol #-}
  -- Row i of an nx1 matrix is a single element, built as a Scalar.
  -- NOTE(review): assumes the requested row type is represented like Scalar t.
  indexRow i = unsafeCoerce . Scalar . V.indexVec i . _unContraVec
  {-# INLINE indexRow #-}

-- | A row vector viewed as a 1xm matrix.
instance (KnownNat m, V.VectorCalculus t m (V.Vector t m)) => M.MatrixCalculus t 1 m (CoVector t m) where
  broadcastMat = CoVector . V.broadcastVec
  {-# INLINE broadcastMat #-}
  indexMat 1 i (CoVector v) = V.indexVec i v
  indexMat j i (CoVector v) = error $ "Bad index (" ++ show j ++ ", " ++ show i ++ ") for a 1x" ++ show (V.dim v) ++ "D matrix"
  {-# INLINE indexMat #-}
  -- Transposing a row vector re-wraps the same storage as a column vector.
  transpose (CoVector v) = unsafeCoerce $ ContraVector v
  {-# INLINE transpose #-}
  dimN _ = 1
  {-# INLINE dimN #-}
  dimM = V.dim . _unCoVec
  {-# INLINE dimM #-}
  -- Column i of a 1xm matrix is a single element, built as a Scalar.
  -- NOTE(review): assumes the requested column type is represented like Scalar t.
  indexCol i = unsafeCoerce . Scalar . V.indexVec i . _unCoVec
  {-# INLINE indexCol #-}
  -- The only row is the vector itself. The argument is a row index, so the
  -- error reports a bad *row* index (it previously said "column").
  indexRow 1 x = unsafeCoerce x
  indexRow j (CoVector v) = error $ "Bad row index " ++ show j ++ " for a 1x" ++ show (V.dim v) ++ "D matrix"
  {-# INLINE indexRow #-}