{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{- |
Type classes for multiplying plain matrix arrays:
scaling ('Scale'), matrix-vector products ('MultiplyVector'),
products of a square-shaped matrix with a full matrix ('MultiplySquare'),
powers ('Power'), products of two matrices of the same shape ('MultiplySame')
and products of two matrices of possibly different shapes ('Multiply').
-}
module Numeric.LAPACK.Matrix.Plain.Multiply where

import qualified Numeric.LAPACK.Matrix.Plain.Class as Plain
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Extent.Private as ExtentPriv
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.BandedHermitian.Basic as BandedHermitian
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Triangular
import qualified Numeric.LAPACK.Matrix.Hermitian.Basic as Hermitian
import qualified Numeric.LAPACK.Matrix.Square.Basic as Square
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Shape.Private (Empty, Filled, Unit, NonUnit)
import Numeric.LAPACK.Matrix.Extent.Private (Small)
import Numeric.LAPACK.Matrix.Triangular.Basic (Triangular)
import Numeric.LAPACK.Matrix.Basic (swapMultiply, transpose)
import Numeric.LAPACK.Matrix.Modifier (Transposition(NonTransposed, Transposed))
import Numeric.LAPACK.Matrix.Private (Square, Full, mapExtent)
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Unary ((:+:))

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array)


{- |
Matrix shapes whose arrays can be multiplied by a scalar
without leaving the shape.
All instances in this module rely on the default method,
which is 'Vector.scale' applied to the underlying array.
-}
class (Box.Box shape) => Scale shape where
   scale :: (Class.Floating a) => a -> Array shape a -> Array shape a
   scale = Vector.scale

instance
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
      Scale (MatrixShape.Full vert horiz height width) where

-- The instance is restricted to 'NonUnit' diagonals.
-- NOTE(review): presumably because a scaled matrix
-- would no longer have a unit diagonal - confirm.
instance
   (MatrixShape.Content lo, MatrixShape.Content up,
    diag ~ NonUnit, Shape.C size) =>
      Scale (MatrixShape.Triangular lo diag up size) where

instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
      Scale (MatrixShape.Banded sub super vert horiz height width) where


{- |
Products of a matrix and a dense vector.

'matrixVector' takes the matrix as left and the vector as right operand
and returns a vector indexed by the matrix height.
'vectorMatrix' takes the vector as left and the matrix as right operand
and returns a vector indexed by the matrix width.
The instances below implement 'vectorMatrix'
via the matrix-vector product with the transposed matrix.
-}
class (Box.Box shape) => MultiplyVector shape where
   matrixVector ::
      (Box.WidthOf shape ~ width, Eq width, Class.Floating a) =>
      Array shape a -> Vector width a -> Vector (Box.HeightOf shape) a
   vectorMatrix ::
      (Box.HeightOf shape ~ height, Eq height, Class.Floating a) =>
      Vector height a -> Array shape a -> Vector (Box.WidthOf shape) a

instance
   (Extent.C vert, Extent.C horiz, Shape.C width, Shape.C height) =>
      MultiplyVector (MatrixShape.Full vert horiz height width) where
   matrixVector = Basic.multiplyVector
   vectorMatrix v m = Basic.multiplyVector (transpose m) v

instance (Shape.C shape) => MultiplyVector (MatrixShape.Hermitian shape) where
   matrixVector = Hermitian.multiplyVector NonTransposed
   vectorMatrix = flip $ Hermitian.multiplyVector Transposed

instance
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C shape) =>
      MultiplyVector (MatrixShape.Triangular lo diag up shape) where
   matrixVector m v = Triangular.multiplyVector m v
   vectorMatrix v m = Triangular.multiplyVector (Triangular.transpose m) v

instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
      MultiplyVector
         (MatrixShape.Banded sub super vert horiz height width) where
   matrixVector m v = Banded.multiplyVector m v
   vectorMatrix v m = Banded.multiplyVector (Banded.transpose m) v

instance
   (Unary.Natural offDiag, Shape.C size) =>
      MultiplyVector (MatrixShape.BandedHermitian offDiag size) where
   matrixVector = BandedHermitian.multiplyVector NonTransposed
   vectorMatrix = flip $ BandedHermitian.multiplyVector Transposed


{- |
Products of a square-shaped matrix with a full (dense) matrix.
The result always has the shape of the full operand.

An instance must define either 'transposableSquare'
or both of 'fullSquare' and 'squareFull';
the remaining methods are derived from each other by the defaults:

* @transposableSquare NonTransposed a b@ is @squareFull a b@,

* @transposableSquare Transposed a b@ is
  @transpose (fullSquare (transpose b) a)@,

* @squareFull@ is @transposableSquare NonTransposed@,

* @fullSquare@ is @swapMultiply (transposableSquare Transposed)@.
-}
class (Plain.SquareShape shape) => MultiplySquare shape where
   {-# MINIMAL transposableSquare | fullSquare,squareFull #-}
   -- | Square-shaped matrix, transposed or not as requested,
   -- as left operand and a full matrix as right operand.
   transposableSquare ::
      (Box.HeightOf shape ~ height, Eq height, Shape.C width,
       Extent.C vert, Extent.C horiz, Class.Floating a) =>
      Transposition -> Array shape a ->
      Full vert horiz height width a -> Full vert horiz height width a
   transposableSquare NonTransposed a b = squareFull a b
   transposableSquare Transposed a b =
      transpose $ fullSquare (transpose b) a

   -- | Square-shaped matrix as left operand, full matrix as right operand.
   squareFull ::
      (Box.HeightOf shape ~ height, Eq height, Shape.C width,
       Extent.C vert, Extent.C horiz, Class.Floating a) =>
      Array shape a ->
      Full vert horiz height width a -> Full vert horiz height width a
   squareFull = transposableSquare NonTransposed

   -- | Full matrix as left operand, square-shaped matrix as right operand.
   fullSquare ::
      (Box.WidthOf shape ~ width, Eq width, Shape.C height,
       Extent.C vert, Extent.C horiz, Class.Floating a) =>
      Full vert horiz height width a ->
      Array shape a -> Full vert horiz height width a
   fullSquare = swapMultiply $ transposableSquare Transposed

instance
   (vert ~ Small, horiz ~ Small, Shape.C height, height ~ width) =>
      MultiplySquare (MatrixShape.Full vert horiz height width) where
   transposableSquare NonTransposed = squareFull
   transposableSquare Transposed = squareFull . transpose
   -- The square extent of the left operand is adapted
   -- to the extent kind of the right operand before multiplying.
   squareFull a b = Basic.multiply (mapExtent Extent.fromSquare a) b

instance (Shape.C shape) => MultiplySquare (MatrixShape.Hermitian shape) where
   transposableSquare = Hermitian.multiplyFull

instance
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C shape) =>
      MultiplySquare (MatrixShape.Triangular lo diag up shape) where
   squareFull = Triangular.multiplyFull
   fullSquare = swapMultiply $ Triangular.multiplyFull . Triangular.transpose

instance
   (Unary.Natural sub, Unary.Natural super, vert ~ Small, horiz ~ Small,
    Shape.C height, height ~ width) =>
      MultiplySquare
         (MatrixShape.Banded sub super vert horiz height width) where
   squareFull = Banded.multiplyFull . bandedGenSquare
   fullSquare =
      swapMultiply $
         Banded.multiplyFull . bandedGenSquare . Banded.transpose

-- | Reinterpret a square banded matrix
-- as a banded matrix of an arbitrary extent kind
-- by mapping its extent with 'Extent.fromSquare'.
bandedGenSquare ::
   (Extent.C vert, Extent.C horiz) =>
   Banded.Square sub super size a ->
   Banded.Banded sub super vert horiz size size a
bandedGenSquare = Banded.mapExtent Extent.fromSquare

instance
   (Unary.Natural offDiag, Shape.C size) =>
      MultiplySquare (MatrixShape.BandedHermitian offDiag size) where
   transposableSquare = BandedHermitian.multiplyFull


{- |
Square-shaped matrices whose square and integer powers
have the same shape as the matrix itself.
The instances delegate to the @square@ and @power@ functions
of the respective shape-specific modules.
-}
class (Plain.SquareShape shape) => Power shape where
   square :: (Class.Floating a) => Array shape a -> Array shape a
   power :: (Class.Floating a) => Int -> Array shape a -> Array shape a

instance
   (Extent.Small ~ vert, Extent.Small ~ horiz,
    Shape.C height, height ~ width) =>
      Power (MatrixShape.Full vert horiz height width) where
   square = Square.square
   power = Square.power . fromIntegral

instance (Shape.C size) => Power (MatrixShape.Hermitian size) where
   square = Hermitian.square
   power = Hermitian.power . fromIntegral

instance
   (Triangular.PowerContentDiag lo diag up, Shape.C size) =>
      Power (MatrixShape.Triangular lo diag up size) where
   square = Triangular.square
   power = Triangular.power


{- |
Products of two matrices of the same shape type
where the result has that shape type, too.
-}
class (Box.Box shape) => MultiplySame shape where
   same ::
      (Class.Floating a) =>
      Array shape a -> Array shape a -> Array shape a

instance
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, height ~ width) =>
      MultiplySame (MatrixShape.Full vert horiz height width) where
   same = Basic.multiply

instance
   (MatrixShape.DiagUpLo lo up, MatrixShape.TriDiag diag,
    Shape.C size, Eq size) =>
      MultiplySame (MatrixShape.Triangular lo diag up size) where
   same = Triangular.multiply


{- |
This class allows multiplying two matrices of arbitrary special features
and returns the most special matrix type possible.
At first glance, this is handy.
At second glance, this has some problems.
First of all, we may refine the types in the future
and then multiplication may return a different,
more special type than before.
Second, if you write code with polymorphic matrix types,
then 'matrixMatrix' may leave you with constraints like
@ExtentPriv.Multiply vert vert ~ vert@.
That constraint is always fulfilled but the compiler cannot infer that.

Because of these problems
you may instead consider using the specialised @multiply@ functions
from the various modules for production use.
Btw. 'MultiplyVector' and 'MultiplySquare'
are much less problematic,
because the input and output are always dense vectors or dense matrices.
-}
class (Shape.C shapeA, Shape.C shapeB) => Multiply shapeA shapeB where
   -- | Shape of the product of matrices with shapes @shapeA@ and @shapeB@.
   type Multiplied shapeA shapeB
   matrixMatrix ::
      (Class.Floating a) =>
      Array shapeA a -> Array shapeB a -> Array (Multiplied shapeA shapeB) a

instance
   (Shape.C heightA, Shape.C widthA, Shape.C widthB,
    widthA ~ heightB, Eq heightB,
    Extent.C vertA, Extent.C horizA,
    Extent.C vertB, Extent.C horizB) =>
      Multiply
         (MatrixShape.Full vertA horizA heightA widthA)
         (MatrixShape.Full vertB horizB heightB widthB) where
   type Multiplied
         (MatrixShape.Full vertA horizA heightA widthA)
         (MatrixShape.Full vertB horizB heightB widthB) =
            MatrixShape.Full
               (ExtentPriv.Multiply vertA vertB)
               (ExtentPriv.Multiply horizA horizB)
               heightA widthB
   matrixMatrix a b =
      -- Matching on the 'ExtentPriv.TagFact's brings the facts
      -- about the combined extent kinds into scope;
      -- both operands are then mapped to that common extent kind.
      case unifyFactors (fullExtent a) (fullExtent b) of
         ((ExtentPriv.TagFact, ExtentPriv.TagFact),
          (unifyLeft, unifyRight)) ->
            Basic.multiply (mapExtent unifyLeft a) (mapExtent unifyRight b)

-- | Extract the extent from the shape of a full matrix.
fullExtent ::
   Full vert horiz height width a -> Extent.Extent vert horiz height width
fullExtent = MatrixShape.fullExtent . Array.shape

{- |
For the extents of a left and a right factor
return facts about the combined vertical and horizontal extent kinds
together with maps that convert the extent of either factor
to the combined extent kinds.
-}
unifyFactors ::
   (Extent.C vertA, Extent.C horizA, Extent.C vertB, Extent.C horizB) =>
   (ExtentPriv.Multiply vertA vertB ~ vertC) =>
   (ExtentPriv.Multiply horizA horizB ~ horizC) =>
   Extent.Extent vertA horizA height fuse ->
   Extent.Extent vertB horizB fuse width ->
   ((ExtentPriv.TagFact vertC, ExtentPriv.TagFact horizC),
    (Extent.Map vertA horizA vertC horizC height fuse,
     Extent.Map vertB horizB vertC horizC fuse width))
unifyFactors a b =
   ((ExtentPriv.multiplyTagLaw
      (ExtentPriv.heightFact a) (ExtentPriv.heightFact b),
     ExtentPriv.multiplyTagLaw
      (ExtentPriv.widthFact a) (ExtentPriv.widthFact b)),
    (ExtentPriv.Map $ flip ExtentPriv.unifyLeft b,
     ExtentPriv.Map $ ExtentPriv.unifyRight a))


instance
   (Extent.C vert, Extent.C horiz,
    Shape.C size, size ~ width, Eq width, Shape.C height) =>
      Multiply
         (MatrixShape.Full vert horiz height width)
         (MatrixShape.Hermitian size) where
   type Multiplied
         (MatrixShape.Full vert horiz height width)
         (MatrixShape.Hermitian size) =
            MatrixShape.Full vert horiz height width
   matrixMatrix = fullSquare

instance
   (Extent.C vert, Extent.C horiz,
    Shape.C size, size ~ height, Eq height, Shape.C width) =>
      Multiply
         (MatrixShape.Hermitian size)
         (MatrixShape.Full vert horiz height width) where
   type Multiplied
         (MatrixShape.Hermitian size)
         (MatrixShape.Full vert horiz height width) =
            MatrixShape.Full vert horiz height width
   matrixMatrix = squareFull

-- The product of two Hermitian matrices is returned as a general square
-- matrix: the right factor is converted with 'Hermitian.toSquare' first.
instance
   (Shape.C shapeA, shapeA ~ shapeB, Eq shapeB) =>
      Multiply
         (MatrixShape.Hermitian shapeA)
         (MatrixShape.Hermitian shapeB) where
   type Multiplied
         (MatrixShape.Hermitian shapeA)
         (MatrixShape.Hermitian shapeB) =
            MatrixShape.Square shapeA
   matrixMatrix a = squareFull a . Hermitian.toSquare


instance
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Extent.C vert, Extent.C horiz,
    Shape.C size, size ~ width, Eq width, Shape.C height) =>
      Multiply
         (MatrixShape.Full vert horiz height width)
         (MatrixShape.Triangular lo diag up size) where
   type Multiplied
         (MatrixShape.Full vert horiz height width)
         (MatrixShape.Triangular lo diag up size) =
            MatrixShape.Full vert horiz height width
   matrixMatrix = fullSquare

instance
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Extent.C vert, Extent.C horiz,
    Shape.C size, size ~ height, Eq height, Shape.C width) =>
      Multiply
         (MatrixShape.Triangular lo diag up size)
         (MatrixShape.Full vert horiz height width) where
   type Multiplied
         (MatrixShape.Triangular lo diag up size)
         (MatrixShape.Full vert horiz height width) =
            MatrixShape.Full vert horiz height width
   matrixMatrix = squareFull

instance
   (Shape.C sizeA, sizeA ~ sizeB, Eq sizeB,
    MultiplyTriangular loA upA loB upB,
    MatrixShape.TriDiag diagA, MatrixShape.TriDiag diagB) =>
      Multiply
         (MatrixShape.Triangular loA diagA upA sizeA)
         (MatrixShape.Triangular loB diagB upB sizeB) where
   type Multiplied
         (MatrixShape.Triangular loA diagA upA sizeA)
         (MatrixShape.Triangular loB diagB upB sizeB) =
            -- requires UndecidableInstances
            MultipliedTriangular loA diagA upA loB diagB upB sizeB
   matrixMatrix = triangularTriangular


{- |
Products of two triangular matrices,
selected by the content tags ('Empty' or 'Filled')
of the lower and upper parts of both factors.

NOTE(review): there are no instances for a factor with both parts 'Empty'
combined with a factor that has exactly one part 'Filled'
(e.g. @MultiplyTriangular Empty Empty Empty Filled@).
-}
class
   (MatrixShape.Content loA, MatrixShape.Content upA,
    MatrixShape.Content loB, MatrixShape.Content upB) =>
      MultiplyTriangular loA upA loB upB where
   triangularTriangular ::
      (Class.Floating a, Shape.C size, Eq size,
       MatrixShape.TriDiag diagA, MatrixShape.TriDiag diagB) =>
      Triangular loA diagA upA size a ->
      Triangular loB diagB upB size a ->
      Array (MultipliedTriangular loA diagA upA loB diagB upB size) a

-- | Shape of the product of two triangular matrices,
-- computed separately for lower part, diagonal and upper part.
type MultipliedTriangular loA diagA upA loB diagB upB size =
         ComposedTriangular
            (MultipliedPart loA loB)
            (MultipliedDiag diagA diagB)
            (MultipliedPart upA upB)
            size

-- | Content tag of an off-diagonal part of a product:
-- 'Filled' if the part is 'Filled' in the first factor,
-- otherwise the tag of the second factor.
type family MultipliedPart a b :: *
type instance MultipliedPart Empty b = b
type instance MultipliedPart Filled b = Filled

-- | Diagonal tag of a product:
-- 'NonUnit' if the first factor is 'NonUnit',
-- otherwise the tag of the second factor.
type family MultipliedDiag a b :: *
type instance MultipliedDiag Unit b = b
type instance MultipliedDiag NonUnit b = NonUnit

-- | Choose the matrix shape for given content tags:
-- a triangular shape if the lower or the upper part is 'Empty',
-- and a general square shape if both parts are 'Filled'.
type family ComposedTriangular lo diag up size :: *
type instance ComposedTriangular Empty diag up size =
         MatrixShape.Triangular Empty diag up size
type instance ComposedTriangular Filled diag Empty size =
         MatrixShape.LowerTriangular diag size
type instance ComposedTriangular Filled diag Filled size =
         MatrixShape.Square size

instance MultiplyTriangular Empty Empty Empty Empty where
   triangularTriangular = triangularTriangularConform

instance MultiplyTriangular Empty Empty Filled Filled where
   triangularTriangular a = Triangular.multiplyFull a . Triangular.toSquare

instance MultiplyTriangular Empty Filled Filled Filled where
   triangularTriangular a = Triangular.multiplyFull a . Triangular.toSquare

instance MultiplyTriangular Filled Empty Filled Filled where
   triangularTriangular a = Triangular.multiplyFull a . Triangular.toSquare

instance MultiplyTriangular Empty Filled Empty Filled where
   triangularTriangular = triangularTriangularConform

instance MultiplyTriangular Filled Empty Filled Empty where
   triangularTriangular = triangularTriangularConform

instance MultiplyTriangular Filled Empty Empty Filled where
   triangularTriangular a = Triangular.multiplyFull a . Triangular.toSquare

instance MultiplyTriangular Empty Filled Filled Empty where
   triangularTriangular a = Triangular.multiplyFull a . Triangular.toSquare

instance MultiplyTriangular Filled Filled Empty Empty where
   triangularTriangular = triangularTriangularToSquare

instance MultiplyTriangular Filled Filled Empty Filled where
   triangularTriangular = triangularTriangularToSquare

instance MultiplyTriangular Filled Filled Filled Empty where
   triangularTriangular = triangularTriangularToSquare

instance MultiplyTriangular Filled Filled Filled Filled where
   triangularTriangular = triangularTriangularToSquare

-- | Multiply two triangular matrices into a general square matrix
-- by converting the left factor with 'Triangular.toSquare'
-- and applying 'fullSquare' with the right factor.
triangularTriangularToSquare ::
   (MatrixShape.Content loA, MatrixShape.Content upA,
    MatrixShape.TriDiag diagA,
    MatrixShape.Content loB, MatrixShape.Content upB,
    MatrixShape.TriDiag diagB,
    Shape.C size, Eq size, Class.Floating a) =>
   Triangular loA diagA upA size a ->
   Triangular loB diagB upB size a ->
   Square size a
triangularTriangularToSquare = fullSquare . Triangular.toSquare

-- | Wrapper with @diagA@ as last type parameter,
-- such that 'MatrixShape.switchTriDiag' can distinguish
-- the diagonal tag of the left factor.
newtype TriangularTriangularConform lo up size a diagB diagA =
   TriangularTriangularConform {
      getTriangularTriangularConform ::
         Triangular lo diagA up size a ->
         Triangular lo diagB up size a ->
         Triangular lo (MultipliedDiag diagA diagB) up size a
   }

{- |
Multiply two triangular matrices with identical lower and upper content tags.
'Triangular.multiply' is applied after adapting the diagonal tags:
if the left factor has a 'Unit' diagonal,
it is relaxed to the diagonal tag of the right factor;
otherwise the right factor is converted to 'NonUnit'.
-}
triangularTriangularConform ::
   (Shape.C size, Eq size, Class.Floating a,
    MatrixShape.DiagUpLo lo up,
    MatrixShape.TriDiag diagA, MatrixShape.TriDiag diagB) =>
   (MultipliedDiag diagA diagB ~ diagC) =>
   Triangular lo diagA up size a ->
   Triangular lo diagB up size a ->
   Triangular lo diagC up size a
triangularTriangularConform =
   getTriangularTriangularConform $
   MatrixShape.switchTriDiag
      (TriangularTriangularConform $
         \a b -> Triangular.multiply (Triangular.relaxUnitDiagonal a) b)
      (TriangularTriangularConform $
         \a b -> Triangular.multiply a (Triangular.strictNonUnitDiagonal b))


instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.C vertA, Extent.C horizA,
    Extent.C vertB, Extent.C horizB,
    Shape.C heightA, Shape.C widthA, Shape.C widthB,
    widthA ~ heightB, Eq heightB) =>
      Multiply
         (MatrixShape.Full vertA horizA heightA widthA)
         (MatrixShape.Banded sub super vertB horizB heightB widthB) where
   type Multiplied
         (MatrixShape.Full vertA horizA heightA widthA)
         (MatrixShape.Banded sub super vertB horizB heightB widthB) =
            MatrixShape.Full
               (ExtentPriv.Multiply vertA vertB)
               (ExtentPriv.Multiply horizA horizB)
               heightA widthB
   matrixMatrix a b =
      case unifyFactors (fullExtent a) (bandedExtent b) of
         ((ExtentPriv.TagFact, ExtentPriv.TagFact),
          (unifyLeft, unifyRight)) ->
            swapMultiply (Banded.multiplyFull . Banded.transpose)
               (mapExtent unifyLeft a) (Banded.mapExtent unifyRight b)

instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.C vertA, Extent.C horizA,
    Extent.C vertB, Extent.C horizB,
    Shape.C heightA, Shape.C widthA, Shape.C widthB,
    widthA ~ heightB, Eq heightB) =>
      Multiply
         (MatrixShape.Banded sub super vertA horizA heightA widthA)
         (MatrixShape.Full vertB horizB heightB widthB) where
   type Multiplied
         (MatrixShape.Banded sub super vertA horizA heightA widthA)
         (MatrixShape.Full vertB horizB heightB widthB) =
            MatrixShape.Full
               (ExtentPriv.Multiply vertA vertB)
               (ExtentPriv.Multiply horizA horizB)
               heightA widthB
   matrixMatrix a b =
      case unifyFactors (bandedExtent a) (fullExtent b) of
         ((ExtentPriv.TagFact, ExtentPriv.TagFact),
          (unifyLeft, unifyRight)) ->
            Banded.multiplyFull
               (Banded.mapExtent unifyLeft a) (mapExtent unifyRight b)

-- The numbers of sub- and super-diagonals of the factors add up
-- in the result type.
instance
   (Unary.Natural subA, Unary.Natural superA,
    Unary.Natural subB, Unary.Natural superB,
    Extent.C vertA, Extent.C horizA,
    Extent.C vertB, Extent.C horizB,
    Shape.C heightA, Shape.C widthA, Shape.C widthB,
    widthA ~ heightB, Eq heightB) =>
      Multiply
         (MatrixShape.Banded subA superA vertA horizA heightA widthA)
         (MatrixShape.Banded subB superB vertB horizB heightB widthB) where
   type Multiplied
         (MatrixShape.Banded subA superA vertA horizA heightA widthA)
         (MatrixShape.Banded subB superB vertB horizB heightB widthB) =
            MatrixShape.Banded
               (subA :+: subB) (superA :+: superB)
               (ExtentPriv.Multiply vertA vertB)
               (ExtentPriv.Multiply horizA horizB)
               heightA widthB
   matrixMatrix a b =
      case unifyFactors (bandedExtent a) (bandedExtent b) of
         ((ExtentPriv.TagFact, ExtentPriv.TagFact),
          (unifyLeft, unifyRight)) ->
            Banded.multiply
               (Banded.mapExtent unifyLeft a) (Banded.mapExtent unifyRight b)

-- | Extract the extent from the shape of a banded matrix.
bandedExtent ::
   Banded.Banded sub super vert horiz height width a ->
   Extent.Extent vert horiz height width
bandedExtent = MatrixShape.bandedExtent . Array.shape


instance
   (Unary.Natural offDiag, Extent.C vert, Extent.C horiz,
    Shape.C size, size ~ width, Eq width, Shape.C height, Eq height) =>
      Multiply
         (MatrixShape.Full vert horiz height width)
         (MatrixShape.BandedHermitian offDiag size) where
   type Multiplied
         (MatrixShape.Full vert horiz height width)
         (MatrixShape.BandedHermitian offDiag size) =
            MatrixShape.Full vert horiz height width
   matrixMatrix = fullSquare

instance
   (Unary.Natural offDiag, Extent.C vert, Extent.C horiz,
    Shape.C size, size ~ height, Eq height, Shape.C width, Eq width) =>
      Multiply
         (MatrixShape.BandedHermitian offDiag size)
         (MatrixShape.Full vert horiz height width) where
   type Multiplied
         (MatrixShape.BandedHermitian offDiag size)
         (MatrixShape.Full vert horiz height width) =
            MatrixShape.Full vert horiz height width
   matrixMatrix = squareFull

-- The banded Hermitian factor is converted to a general banded matrix
-- with 'BandedHermitian.toBanded' before multiplying.
instance
   (Unary.Natural offDiag, Unary.Natural sub, Unary.Natural super,
    Extent.C vert, Extent.C horiz,
    Shape.C size, size ~ width, Eq width, Shape.C height, Eq height) =>
      Multiply
         (MatrixShape.Banded sub super vert horiz height width)
         (MatrixShape.BandedHermitian offDiag size) where
   type Multiplied
         (MatrixShape.Banded sub super vert horiz height width)
         (MatrixShape.BandedHermitian offDiag size) =
            MatrixShape.Banded
               (sub:+:offDiag) (super:+:offDiag) vert horiz height width
   matrixMatrix a b =
      Banded.multiply a (bandedGenSquare $ BandedHermitian.toBanded b)

instance
   (Unary.Natural offDiag, Unary.Natural sub, Unary.Natural super,
    Extent.C vert, Extent.C horiz,
    Shape.C size, size ~ height, Eq height, Shape.C width, Eq width) =>
      Multiply
         (MatrixShape.BandedHermitian offDiag size)
         (MatrixShape.Banded sub super vert horiz height width) where
   type Multiplied
         (MatrixShape.BandedHermitian offDiag size)
         (MatrixShape.Banded sub super vert horiz height width) =
            MatrixShape.Banded
               (offDiag:+:sub) (offDiag:+:super) vert horiz height width
   matrixMatrix a b =
      Banded.multiply (bandedGenSquare $ BandedHermitian.toBanded a) b

instance
   (Unary.Natural offDiagA, Unary.Natural offDiagB,
    Shape.C sizeA, sizeA ~ sizeB, Shape.C sizeB, Eq sizeB) =>
      Multiply
         (MatrixShape.BandedHermitian offDiagA sizeA)
         (MatrixShape.BandedHermitian offDiagB sizeB) where
   type Multiplied
         (MatrixShape.BandedHermitian offDiagA sizeA)
         (MatrixShape.BandedHermitian offDiagB sizeB) =
            MatrixShape.Banded
               (offDiagA:+:offDiagB) (offDiagA:+:offDiagB)
               Small Small sizeA sizeB
   matrixMatrix a b =
      Banded.multiply
         (BandedHermitian.toBanded a) (BandedHermitian.toBanded b)