{-# LANGUAGE TypeFamilies #-}
-- | QuickCheck properties for the LU decomposition provided by
-- "Numeric.LAPACK.Linear.LowerUpper".
--
-- Every property decomposes a generated matrix with 'LU.fromMatrix' and
-- checks an operation on the decomposition against a reference result that
-- is computed either from the original matrix or from extracted factors.
module Test.LowerUpper (testsVar) where

import qualified Test.Divide as Divide
import qualified Test.Generator as Gen
import qualified Test.Utility as Util
import Test.Permutation (genPermMatrix)
import Test.Generator ((<#\#>), (<#*#>))
import Test.Utility (Tagged, approx, approxMatrix, maybeConjugate)

import qualified Numeric.LAPACK.Linear.LowerUpper as LU
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Permutation as Perm
import Numeric.LAPACK.Matrix.Square (Square)
import Numeric.LAPACK.Matrix (ShapeInt, (#*#), (#*##))
import Numeric.LAPACK.Scalar (RealOf, selectReal)

import qualified Numeric.Netlib.Class as Class

import Control.Applicative (liftA2, (<$>))
import Data.Semigroup ((<>))

import qualified Test.QuickCheck as QC


-- | Tolerance shared by every property below that compares the LU-based
-- result with a reference computed along a different numerical path.
--
-- Keeping the value in one place ensures all these properties use the same
-- precision-dependent bound.  The first argument of 'selectReal' is
-- presumably the single-precision bound and the second the double-precision
-- bound (TODO: confirm against "Numeric.LAPACK.Scalar").
tolerance :: (Class.Real ar) => ar
tolerance = selectReal 1e-1 1e-5


-- | Decomposing a tall matrix and converting the decomposition back with
-- 'LU.toMatrix' reproduces the original matrix.
--
-- Uses a fixed tolerance of @1e-5@ rather than 'tolerance'.
toFromTallMatrix ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix.Tall ShapeInt ShapeInt a -> Bool
toFromTallMatrix a =
   approxMatrix 1e-5 a (LU.toMatrix $ LU.fromMatrix a)

{-
Strictly wide matrices are problematic: a full-rank wide matrix may have
a leading column that consists entirely of zeros.  Handling such input
would require column pivoting in the LU decomposition.  For now we
therefore restrict the round-trip test to square matrices.
-}

-- | Square counterpart of 'toFromTallMatrix'.  Decomposing and then
-- converting back with 'LU.toMatrix' reproduces the input.  Uses a fixed
-- tolerance of @1e-5@.
toFromSquareMatrix ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Square ShapeInt a -> Bool
toFromSquareMatrix a =
   approxMatrix 1e-5 a (LU.toMatrix $ LU.fromMatrix a)

-- | Decomposing a permutation matrix yields exactly that permutation as the
-- non-inverted P factor.  The comparison is exact ('==').
permutation ::
   (Class.Floating a) => PermMatrix.Permutation ShapeInt a -> Bool
permutation perm =
   perm ==
      LU.extractP Perm.NonInverted (LU.fromMatrix (PermMatrix.toSquare perm))

-- | 'LU.multiplyP' with the combined inversion @inv0 <> inv1@ agrees with
-- two steps: extracting P with inversion @inv1@, then applying it with
-- 'Perm.apply' and inversion @inv0@.
multiplyPApply ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (LU.Inversion, LU.Inversion) ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplyPApply (inv0,inv1) (a,b) =
   let lu = LU.fromMatrix a
       p = PermMatrix.toPermutation $ LU.extractP inv1 lu
   in approxMatrix tolerance
         (LU.multiplyP (inv0<>inv1) lu b)
         (Perm.apply inv0 p b)

-- | 'LU.multiplyP' agrees with explicitly multiplying by the P factor
-- returned by 'LU.extractP'.
multiplyP ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Inversion ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplyP inv (a,b) =
   let lu = LU.fromMatrix a
   in approxMatrix tolerance
         (LU.multiplyP inv lu b)
         (LU.extractP inv lu #*## b)

-- | For a square decomposition, 'LU.wideMultiplyL' agrees with
-- 'Matrix.multiplySquare' applied to the L factor from 'LU.extractL'.
multiplyL ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Transposition ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplyL trans (a,b) =
   let lu = LU.fromMatrix a
   in approxMatrix tolerance
         (LU.wideMultiplyL trans lu b)
         (Matrix.multiplySquare trans (LU.extractL lu) b)

-- | Like 'multiplyL', but the reference L factor comes from
-- 'LU.wideExtractL'.
wideMultiplyL ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Transposition ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
wideMultiplyL trans (a,b) =
   let lu = LU.fromMatrix a
   in approxMatrix tolerance
         (LU.wideMultiplyL trans lu b)
         (Matrix.multiplySquare trans (LU.wideExtractL lu) b)

-- | For a square decomposition, 'LU.tallMultiplyU' agrees with
-- 'Matrix.multiplySquare' applied to the U factor from 'LU.extractU'.
multiplyU ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Transposition ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplyU trans (a,b) =
   let lu = LU.fromMatrix a
   in approxMatrix tolerance
         (LU.tallMultiplyU trans lu b)
         (Matrix.multiplySquare trans (LU.extractU lu) b)

-- | For a tall decomposition, 'LU.tallMultiplyU' agrees with
-- 'Matrix.multiplySquare' applied to the U factor from 'LU.tallExtractU'.
tallMultiplyU ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Transposition ->
   (Matrix.Tall ShapeInt ShapeInt a, Matrix.General ShapeInt ShapeInt a) ->
   Bool
tallMultiplyU trans (a,b) =
   let lu = LU.fromMatrix a
   in approxMatrix tolerance
         (LU.tallMultiplyU trans lu b)
         (Matrix.multiplySquare trans (LU.tallExtractU lu) b)

-- | 'LU.multiplyFull' on a square decomposition agrees with the direct
-- product @a #*## b@.  The decomposition's extent is first generalised with
-- 'Extent.fromSquare'.
multiplySquareFull ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
multiplySquareFull (a,b) =
   approxMatrix tolerance
      (a #*## b)
      (LU.multiplyFull (LU.mapExtent Extent.fromSquare $ LU.fromMatrix a) b)

-- | 'LU.multiplyFull' on a tall decomposition agrees with the direct product
-- @a #*# b@.  The decomposition's extent is first generalised with
-- 'Extent.generalizeTall'.
multiplyTallFull ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Matrix.Tall ShapeInt ShapeInt a, Matrix.General ShapeInt ShapeInt a) ->
   Bool
multiplyTallFull (a,b) =
   approxMatrix tolerance
      (a #*# b)
      (LU.multiplyFull (LU.mapExtent Extent.generalizeTall $ LU.fromMatrix a) b)

-- | 'LU.determinant' of the decomposition agrees with 'Square.determinant'
-- of the original matrix.
determinant ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Square ShapeInt a -> Bool
determinant a =
   approx tolerance
      (Square.determinant a)
      (LU.determinant $ LU.fromMatrix a)

-- | 'LU.wideSolveL' agrees with 'Matrix.solve' applied to the L factor from
-- 'LU.wideExtractL'.  That factor is conjugated via 'maybeConjugate'
-- according to @conj@.
wideSolveL ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (LU.Transposition, LU.Conjugation) ->
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
wideSolveL (trans,conj) (a,b) =
   let lu = LU.fromMatrix a
       l = maybeConjugate conj $ LU.wideExtractL lu
   in approxMatrix tolerance
         (LU.wideSolveL trans conj lu b)
         (Matrix.solve trans l b)

-- | 'LU.tallSolveU' agrees with 'Matrix.solve' applied to the U factor from
-- 'LU.tallExtractU'.  That factor is conjugated via 'maybeConjugate'
-- according to @conj@.
tallSolveU ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (LU.Transposition, LU.Conjugation) ->
   (Matrix.Tall ShapeInt ShapeInt a, Matrix.General ShapeInt ShapeInt a) ->
   Bool
tallSolveU (trans,conj) (a,b) =
   let lu = LU.fromMatrix a
       u = maybeConjugate conj $ LU.tallExtractU lu
   in approxMatrix tolerance
         (LU.tallSolveU trans conj lu b)
         (Matrix.solve trans u b)

-- | 'LU.solve' with the decomposition agrees with 'Square.solve' on the
-- original matrix.
solve ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Square ShapeInt a, Matrix.General ShapeInt ShapeInt a) -> Bool
solve (a,b) =
   approxMatrix tolerance
      (Square.solve a b)
      (LU.solve (LU.fromMatrix a) b)


-- | Check a property for values drawn from a 'Gen.T' generator.
--
-- The generator is run with @Gen.run gen 3 5@ and passed on to
-- 'Util.checkForAll'.  NOTE(review): 3 and 5 are presumably size parameters
-- for the generated dimensions; confirm in "Test.Generator".
checkForAll ::
   (Show a, QC.Testable test) =>
   Gen.T dim tag a -> (a -> test) -> Tagged tag QC.Property
checkForAll gen = Util.checkForAll (Gen.run gen 3 5)

-- | Like 'checkForAll', but the property additionally receives a value from
-- a plain QuickCheck generator.  The two sources are combined through
-- 'Gen.withExtra'.
checkForAllExtra ::
   (Show a, Show b, QC.Testable test) =>
   QC.Gen a -> Gen.T dim tag b ->
   (a -> b -> test) -> Tagged tag QC.Property
checkForAllExtra = Gen.withExtra checkForAll


-- | All named LU properties for element type @a@.  The list ends with the
-- generic division tests from 'Divide.testsVar', instantiated with LU
-- decompositions of invertible matrices.
--
-- Enumeration-typed parameters (inversion, transposition, conjugation) are
-- drawn with 'QC.arbitraryBoundedEnum'.  NOTE(review): '<#*#>' presumably
-- links the first matrix's width to the second's height, as for a product.
-- '<#\#>' presumably links their heights, as for a linear solve.  This is
-- why 'tallSolveU', whose U factor is width-by-width, uses '<#*#>'.
testsVar ::
   (Show a, Class.Floating a, Eq a, RealOf a ~ ar, Class.Real ar) =>
   [(String, Tagged a QC.Property)]
testsVar =
   ("toFromTallMatrix",
      checkForAll Gen.fullRankTall toFromTallMatrix) :
   ("toFromSquareMatrix",
      checkForAll Gen.invertible toFromSquareMatrix) :
   ("permutation",
      checkForAll genPermMatrix permutation) :
   ("multiplyPApply",
      checkForAllExtra
         (liftA2 (,) QC.arbitraryBoundedEnum QC.arbitraryBoundedEnum)
         ((,) <$> Gen.invertible <#*#> Gen.matrix)
         multiplyPApply) :
   ("multiplyP",
      checkForAllExtra QC.arbitraryBoundedEnum
         ((,) <$> Gen.invertible <#*#> Gen.matrix)
         multiplyP) :
   ("multiplyL",
      checkForAllExtra QC.arbitraryBoundedEnum
         ((,) <$> Gen.invertible <#*#> Gen.matrix)
         multiplyL) :
   ("wideMultiplyL",
      checkForAllExtra QC.arbitraryBoundedEnum
         ((,) <$> Gen.invertible <#*#> Gen.matrix)
         wideMultiplyL) :
   ("multiplyU",
      checkForAllExtra QC.arbitraryBoundedEnum
         ((,) <$> Gen.invertible <#*#> Gen.matrix)
         multiplyU) :
   ("tallMultiplyU",
      checkForAllExtra QC.arbitraryBoundedEnum
         ((,) <$> Gen.fullRankTall <#*#> Gen.matrix)
         tallMultiplyU) :
   ("multiplySquareFull",
      checkForAll
         ((,) <$> Gen.invertible <#*#> Gen.matrix)
         multiplySquareFull) :
   ("multiplyTallFull",
      checkForAll
         ((,) <$> Gen.fullRankTall <#*#> Gen.matrix)
         multiplyTallFull) :
   ("determinant",
      checkForAll Gen.invertible determinant) :
   ("wideSolveL",
      checkForAllExtra
         (liftA2 (,) QC.arbitraryBoundedEnum QC.arbitraryBoundedEnum)
         ((,) <$> Gen.invertible <#\#> Gen.matrix)
         wideSolveL) :
   ("tallSolveU",
      checkForAllExtra
         (liftA2 (,) QC.arbitraryBoundedEnum QC.arbitraryBoundedEnum)
         ((,) <$> Gen.fullRankTall <#*#> Gen.matrix)
         tallSolveU) :
   ("solve",
      checkForAll ((,) <$> Gen.invertible <#\#> Gen.matrix) solve) :
   Divide.testsVar (LU.fromMatrix <$> Gen.invertible) ++
   []