{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
-- | Diagonal matrices: construction from lists and vectors,
-- lifting of vector functions to the diagonal, block-diagonal
-- stacking and splitting, multiplication, solving, inversion
-- and determinant.
module Numeric.LAPACK.Matrix.Diagonal (
   Diagonal,
   FlexDiagonal,
   fromList,
   autoFromList,
   fromVector,
   lift,
   stack, (%%%),
   split,
   multiply,
   solve,
   inverse,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Quadratic as Quad
import qualified Numeric.LAPACK.Matrix.Banded as Banded
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Basic as FullBasic
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Array.Banded (Diagonal, FlexDiagonal)
import Numeric.LAPACK.Matrix.Layout.Private (Order)
import Numeric.LAPACK.Matrix.Private (ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))

import qualified Numeric.Netlib.Class as Class

import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import Data.Array.Comfort.Shape ((::+))

import Foreign.Storable (Storable)


-- | Build a diagonal matrix of the given size and storage order from
-- a list of diagonal elements.
--
-- Delegates to 'Banded.squareFromList'; both bandwidths are passed as
-- 'Proxy', i.e. they are fixed by the 'Diagonal' result type.
fromList ::
   (Shape.C sh, Storable a) => Order -> sh -> [a] -> Diagonal sh a
fromList order sh = Banded.squareFromList (Proxy,Proxy) order sh

-- | Build a diagonal matrix whose diagonal is the given list.
-- The 'ShapeInt' size is derived from the list by 'Vector.autoFromList'.
autoFromList :: (Storable a) => Order -> [a] -> Diagonal ShapeInt a
autoFromList order = fromVector order . Vector.autoFromList

-- | Interpret a vector as the diagonal of a square matrix with the given
-- storage order.  Only the array shape is replaced (by 'Omni.quadratic');
-- the element buffer is reused as-is.
fromVector ::
   (Shape.C sh, Storable a) => Order -> Vector sh a -> Diagonal sh a
fromVector order = ArrMatrix.Array . Array.mapShape (Omni.quadratic order)

-- | View the stored elements of a diagonal matrix as a vector over the
-- matrix's square size.  No elements are copied; only the shape changes.
takeDiagonal ::
   (Omni.TriDiag diag) => FlexDiagonal diag sh a -> Vector sh a
takeDiagonal = Array.mapShape Omni.squareSize . ArrMatrix.unwrap


-- | Apply a vector function to the diagonal and rebuild a diagonal matrix
-- with the same storage order from the result.
--
-- NOTE(review): both 'ArrMatrix.packTag' branches are textually identical.
-- The match presumably refines @pack@ to a concrete packing so that the
-- constraints of 'Quad.diagonal' and 'Quad.takeDiagonal' can be resolved;
-- do not merge the branches without checking that it still type-checks.
lift ::
   (Layout.Packing pack,
    Shape.C sha, Shape.C shb, Class.Floating a, Class.Floating b) =>
   (Array sha a -> Array shb b) ->
   FlexDiagonalP pack Omni.Arbitrary sha a ->
   FlexDiagonalP pack Omni.Arbitrary shb b
lift f a =
   case ArrMatrix.packTag a of
      Layout.Packed ->
         Quad.diagonal (ArrMatrix.order a) $ f $ Quad.takeDiagonal a
      Layout.Unpacked ->
         Quad.diagonal (ArrMatrix.order a) $ f $ Quad.takeDiagonal a

-- | Square matrix over @sh@ with packing @pack@ and diagonal kind @diag@,
-- using 'Layout.Empty' for both off-diagonal parts, i.e. a diagonal
-- matrix whose packing is left as a type parameter.
type FlexDiagonalP pack diag sh =
   ArrMatrix.Quadratic pack diag Layout.Empty Layout.Empty sh


infixr 2 %%%

-- | Block-diagonal composition: @stack a b@ (equivalently @a %%% b@)
-- puts @a@ in the top-left and @b@ in the bottom-right block of a matrix
-- over @sh0::+sh1@.
--
-- * Packed storage: the two stored diagonals are appended.
--
-- * Unpacked storage: a full block matrix is assembled from @a@, @b@ and
--   two zero off-diagonal blocks.
--
-- The storage order of the result is taken from @b@.
(%%%), stack ::
   (Layout.Packing pack) =>
   (Omni.TriDiag diag, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexDiagonalP pack diag sh0 a ->
   FlexDiagonalP pack diag sh1 a ->
   FlexDiagonalP pack diag (sh0::+sh1) a
(%%%) = stack

stack a b =
   let order = Omni.order $ ArrMatrix.shape b in
   case ArrMatrix.packTag a of
      Layout.Packed ->
         ArrMatrix.Array $ Array.mapShape (Omni.uncheckedDiagonal order) $
         Vector.append (takeDiagonal a) (takeDiagonal b)
      Layout.Unpacked ->
         let shc =
               Layout.general order
                  (Unchecked $ Quad.size a) (Unchecked $ Quad.size b)
         in ArrMatrix.liftUnpacked2
               (\a_ b_ ->
                  FullBasic.mapExtent Extent.recheckAppend $
                  FullBasic.stack
                     (FullBasic.uncheck a_) (Vector.zero shc)
                     (Vector.zero $ Layout.inverse shc) (FullBasic.uncheck b_))
               a b

-- | Split a matrix over @sh0::+sh1@ into its top-left (@sh0@) and
-- bottom-right (@sh1@) diagonal blocks.
split ::
   (Omni.TriDiag diag, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexDiagonalP pack diag (sh0::+sh1) a ->
   (FlexDiagonalP pack diag sh0 a, FlexDiagonalP pack diag sh1 a)
split a = (Quad.takeTopLeft a, Quad.takeBottomRight a)


-- | Product of two diagonal matrices, computed as the element-wise
-- product ('Vector.mul') of their stored diagonals via
-- 'ArrMatrix.liftOmni2'.
--
-- NOTE(review): the @Eq sh@ constraint is presumably used by
-- 'ArrMatrix.liftOmni2' to check that both shapes agree.
multiply ::
   (Omni.TriDiag diag, Shape.C sh, Eq sh, Class.Floating a) =>
   FlexDiagonal diag sh a -> FlexDiagonal diag sh a -> FlexDiagonal diag sh a
multiply = ArrMatrix.liftOmni2 Vector.mul

-- | Solve @a * x = b@ for @x@, where @a@ is diagonal.
--
-- * 'Omni.Arbitrary' diagonal: delegates to 'Banded.solve'.
--
-- * 'Omni.Unit' diagonal (the identity): returns @b@ unchanged after
--   checking that the height of @b@ equals the size of @a@; a mismatch
--   raises an 'error'.
solve ::
   (Omni.TriDiag diag,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   FlexDiagonal diag height a ->
   ArrMatrix.Full meas vert horiz height width a ->
   ArrMatrix.Full meas vert horiz height width a
solve a =
   case ArrMatrix.diagTag a of
      Omni.Arbitrary -> Banded.solve a
      Omni.Unit -> \b ->
         if Omni.squareSize (ArrMatrix.shape a) == Omni.height (ArrMatrix.shape b)
            then b
            else error "Diagonal.solve: height shapes mismatch"

-- | Inverse of a diagonal matrix.  A unit diagonal is returned unchanged;
-- otherwise every diagonal element is replaced by its reciprocal
-- ('Vector.recip').  No check for zero entries is made here.
inverse ::
   (Omni.TriDiag diag, Shape.C sh, Class.Floating a) =>
   FlexDiagonal diag sh a -> FlexDiagonal diag sh a
inverse a =
   case ArrMatrix.diagTag a of
      Omni.Unit -> a
      Omni.Arbitrary -> ArrMatrix.liftOmni1 Vector.recip a

-- | Determinant of a diagonal matrix: 'Scalar.one' for a unit diagonal,
-- otherwise the product of the diagonal elements.
determinant ::
   (Omni.TriDiag diag, Shape.C sh, Class.Floating a) =>
   FlexDiagonal diag sh a -> a
determinant a =
   case ArrMatrix.diagTag a of
      Omni.Unit -> Scalar.one
      -- Read the diagonal through 'takeDiagonal' (as 'stack' does)
      -- rather than folding over the raw matrix array.
      Omni.Arbitrary -> Vector.product $ takeDiagonal a