{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Shared helpers for the test suite: exact and approximate comparisons
-- of vectors and matrices, QuickCheck generators for scalars, vectors
-- and matrices, and a small framework ('Tagged', 'Match', 'checkForAll')
-- for tests that may deliberately feed mismatching dimensions.
module Test.Utility where

import qualified Numeric.LAPACK.Matrix.Hermitian as Herm
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Orthogonal.Householder as HH
import qualified Numeric.LAPACK.Orthogonal as Ortho
import Numeric.LAPACK.Matrix.Array (ArrayMatrix)
import Numeric.LAPACK.Matrix.Shape (Order(RowMajor,ColumnMajor))
import Numeric.LAPACK.Matrix (Matrix, ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, absolute)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import Data.Array.Comfort.Shape ((:+:))

import qualified Control.Monad.Trans.State as MS
import Control.Monad (replicateM)
import Control.Applicative (Applicative, liftA2, pure, (<*>), (<$>))

import qualified Data.List.HT as ListHT
import qualified Data.Complex as Complex
import Data.Complex (Complex((:+)))
import Data.Traversable (traverse)
import Data.Monoid (Monoid(mempty,mappend))
import Data.Semigroup (Semigroup((<>)))
import Data.Eq.HT (equating)
import Data.List (foldl')

import qualified Test.QuickCheck as QC
import Test.ChasingBottoms.IsBottom (isBottom)


-- | Compare two lists element-wise with a custom equality.
-- Lists of different length are unequal.
-- The result is produced lazily, so evaluation stops at the first
-- mismatch (e.g. a finite list compared with an infinite one yields
-- 'False' once the finite list is exhausted).
equalListWith :: (a -> a -> Bool) -> [a] -> [a] -> Bool
equalListWith eq xs ys =
   and $ ListHT.takeWhileJust $
   zipWith
      (\mx my ->
         case (mx,my) of
            -- both lists exhausted at the same time: stop, all equal so far
            (Nothing,Nothing) -> Nothing
            (Just x, Just y) -> Just $ eq x y
            -- exactly one list exhausted: lengths differ
            _ -> Just False)
      (map Just xs ++ repeat Nothing)
      (map Just ys ++ repeat Nothing)

-- | Exact element-wise equality of two arrays.  Shapes are not checked.
--
-- The four identical branches of 'Class.switchFloating' instantiate
-- 'equating' at each concrete element type, where an 'Eq' instance exists.
equalVectorBody ::
   (Shape.C shape, Class.Floating a) => Array shape a -> Array shape a -> Bool
equalVectorBody =
   getEqualArray $
   Class.switchFloating
      (EqualArray $ equating Array.toList)
      (EqualArray $ equating Array.toList)
      (EqualArray $ equating Array.toList)
      (EqualArray $ equating Array.toList)

-- | Wrapper that lets a comparison function be selected per element type
-- by 'Class.switchFloating'.
newtype EqualArray f a = EqualArray {getEqualArray :: f a -> f a -> Bool}

-- | Exact equality of two arrays; calls 'error' if the shapes differ.
equalVector ::
   (Shape.C shape, Eq shape, Class.Floating a) =>
   Array shape a -> Array shape a -> Bool
equalVector x y =
   if Array.shape x == Array.shape y
      then equalVectorBody x y
      else error "equalVector: shapes mismatch"

-- | Exact equality of the underlying storage of two array matrices.
equalArray ::
   (Shape.C shape, Eq shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> Bool
equalArray x y = equalVector (ArrMatrix.toVector x) (ArrMatrix.toVector y)

-- | Exact equality of two matrices that may differ in storage order.
-- @'Matrix.adaptOrder' y x@ is expected to bring @x@ into the storage
-- order of @y@ before the storage is compared.
equalMatrix ::
   (ArrMatrix.ShapeOrder shape, Eq shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> Bool
equalMatrix x y = equalArray (Matrix.adaptOrder y x) y

-- | Absolute-difference comparison of two (possibly complex) scalars.
approx ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) => ar -> a -> a -> Bool
approx tol x y = absolute (x-y) <= tol

-- | Absolute-difference comparison of two real scalars.
approxReal :: (Class.Real a) => a -> a -> a -> Bool
approxReal tol x y = abs (x-y) <= tol

-- | Element-wise 'approx' with tolerance @tol@;
-- calls 'error' if the shapes differ.
approxVectorTol ::
   (Shape.C shape, Eq shape, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar -> Array shape a -> Array shape a -> Bool
approxVectorTol tol x y =
   if Array.shape x == Array.shape y
      then and $ zipWith (approx tol) (Array.toList x) (Array.toList y)
      else error "approxVectorTol: shapes mismatch"

-- | 'approxVectorTol' with the default tolerance @1e-5@.
approxVector ::
   (Shape.C shape, Eq shape, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Array shape a -> Array shape a -> Bool
approxVector = approxVectorTol 1e-5

-- | Element-wise 'approxReal' with tolerance @tol@;
-- calls 'error' if the shapes differ.
approxRealVectorTol ::
   (Shape.C shape, Eq shape, Class.Real a) =>
   a -> Array shape a -> Array shape a -> Bool
approxRealVectorTol tol x y =
   if Array.shape x == Array.shape y
      then and $ zipWith (approxReal tol) (Array.toList x) (Array.toList y)
      else error "approxRealVectorTol: shapes mismatch"

-- | 'approxVectorTol' applied to the storage of two array matrices.
approxArrayTol ::
   (Shape.C shape, Eq shape, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar -> ArrayMatrix shape a -> ArrayMatrix shape a -> Bool
approxArrayTol tol x y =
   approxVectorTol tol (ArrMatrix.toVector x) (ArrMatrix.toVector y)

-- | 'approxVector' (default tolerance) applied to the storage of two
-- array matrices.
approxArray ::
   (Shape.C shape, Eq shape, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> Bool
approxArray x y = approxVector (ArrMatrix.toVector x) (ArrMatrix.toVector y)

-- | Approximate equality of two matrices that may differ in storage order;
-- @y@ is adapted to the order of @x@ before comparing.
approxMatrix ::
   (ArrMatrix.ShapeOrder shape, Eq shape, Class.Floating a,
    RealOf a ~ ar, Class.Real ar) =>
   ar -> ArrayMatrix shape a -> ArrayMatrix shape a -> Bool
approxMatrix tol x y = approxArrayTol tol x $ Matrix.adaptOrder x y

-- | Apply 'Matrix.conjugate' only for 'HH.Conjugated'.
maybeConjugate ::
   (Matrix.Complex typ, Class.Floating a) =>
   HH.Conjugation -> Matrix typ a -> Matrix typ a
maybeConjugate HH.NonConjugated = id
maybeConjugate HH.Conjugated = Matrix.conjugate


type NonEmptyInt = ():+:ShapeInt
type EInt = Either () Int


-- | Random integral value in @[-n, n]@, converted to a real type.
genReal :: (Class.Real a) => Integer -> QC.Gen a
genReal n = fromInteger <$> QC.choose (-n,n)

-- | Random complex number whose real and imaginary parts are 'genReal'.
genComplex :: (Class.Real a) => Integer -> QC.Gen (Complex a)
genComplex n = liftA2 (Complex.:+) (genReal n) (genReal n)

-- | Random element of any supported floating type, bounded by @n@
-- (per component for complex types).
genElement :: (Class.Floating a) => Integer -> QC.Gen a
genElement n =
   Class.switchFloating (genReal n) (genReal n) (genComplex n) (genComplex n)

-- | Random array of the given shape with elements from 'genElement'.
genVector ::
   (Shape.C shape, Class.Floating a) =>
   Integer -> shape -> QC.Gen (Array shape a)
genVector maxElem shape =
   Array.fromList shape <$> replicateM (Shape.size shape) (genElement maxElem)

-- | Random array matrix of the given shape with elements from 'genElement'.
genArray ::
   (Shape.C shape, Class.Floating a) =>
   Integer -> shape -> QC.Gen (ArrayMatrix shape a)
genArray maxElem shape = fmap ArrMatrix.lift0 $ genVector maxElem shape

-- | Random array matrix whose element at each index is generated by @f@.
genArrayIndexed ::
   (Shape.Indexed shape, Class.Floating a) =>
   shape -> (Shape.Index shape -> QC.Gen a) -> QC.Gen (ArrayMatrix shape a)
genArrayIndexed shape f =
   ArrMatrix.lift0 . Array.fromList shape <$> traverse f (Shape.indices shape)

-- | Random matrix where diagonal elements (row index equals column index)
-- come from @diag@ and all others from 'genElement'.
genArrayExtraDiag ::
   (Shape.Indexed shape, Shape.Index shape ~ (i,i), Eq i, Class.Floating a) =>
   Integer -> shape -> (i -> QC.Gen a) -> QC.Gen (ArrayMatrix shape a)
genArrayExtraDiag maxElem shape diag =
   genArrayIndexed shape $ \(r,c) ->
      if r==c then diag r else genElement maxElem

-- | Pick one list element at random, paired with the remaining elements
-- (via 'ListHT.removeEach').  Fails if the list is empty, as
-- 'QC.elements' does.
select :: [a] -> QC.Gen (a, [a])
select = QC.elements . ListHT.removeEach

-- | Random vector of @n@ pairwise distinct elements, drawn without
-- replacement from the integers in @[-k, k]@ (per component for complex
-- types).  @k@ is @maxElemS@ for the first and third
-- 'Class.switchFloating' branches and @maxElemD@ for the second and fourth.
-- Fails if the candidate pool has fewer than @n@ elements.
genDistinct ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Integer -> Integer -> ShapeInt -> QC.Gen (Vector ShapeInt a)
genDistinct maxElemS maxElemD size@(Shape.ZeroBased n) = do
   let range k = map fromInteger [(-k)..k]
   fmap (Vector.fromList size) $
      -- the state is the pool of not yet chosen candidates
      MS.evalStateT (replicateM n $ MS.StateT select) $
      Class.switchFloating
         (range maxElemS) (range maxElemD)
         (liftA2 (:+) (range maxElemS) (range maxElemS))
         (liftA2 (:+) (range maxElemD) (range maxElemD))

-- | Random storage order.
genOrder :: QC.Gen Order
genOrder = QC.elements [RowMajor, ColumnMajor]


-- | Heuristic invertibility test: absolute determinant greater than @0.1@.
invertible ::
   (Matrix.Determinant typ, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix typ a -> Bool
invertible a = absolute (Matrix.determinant a) > 0.1

-- | Heuristic full-rank test for tall matrices:
-- 'Ortho.determinantAbsolute' greater than @0.1@.
fullRankTall ::
   (Shape.C height, Shape.C width, Class.Floating a,
    RealOf a ~ ar, Class.Real ar) =>
   Matrix.Tall height width a -> Bool
fullRankTall a = Ortho.determinantAbsolute a > 0.1

-- | Whether the matrix approximates the identity matrix of the same shape
-- within @tol@.
isIdentity ::
   (ArrMatrix.SquareShape shape, ArrMatrix.ShapeOrder shape, Eq shape,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar -> ArrayMatrix shape a -> Bool
isIdentity tol eye = approxArrayTol tol eye (Matrix.identityFrom eye)

-- | Whether the Gramian of the matrix ('Herm.gramian') approximates the
-- identity within @tol@.
isUnitary ::
   (Extent.C vert, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar -> Matrix.Full vert Extent.Small ShapeInt ShapeInt a -> Bool
isUnitary tol = isIdentity tol . Herm.gramian . Matrix.fromFull


-- | Sum a list of matrices of shape @sh@; the empty sum is the zero matrix.
-- A strict left fold avoids building a chain of deferred additions.
addMatrices ::
   (ArrMatrix.Homogeneous sh, Eq sh, Class.Floating a) =>
   sh -> [ArrayMatrix sh a] -> ArrayMatrix sh a
addMatrices sh = foldl' (ArrMatrix.lift2 Vector.add) (ArrMatrix.zero sh)


infixl 3 !|||
infixl 2 !===

-- | Place two general matrices side by side
-- ('Matrix.beside' with 'Matrix.leftBias' and 'Extent.appendAny').
(!|||) ::
   (Shape.C height, Eq height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   Matrix.General height widthA a ->
   Matrix.General height widthB a ->
   Matrix.General height (widthA:+:widthB) a
(!|||) = Matrix.beside Matrix.leftBias Extent.appendAny

-- | Stack two general matrices vertically
-- ('Matrix.above' with 'Matrix.leftBias' and 'Extent.appendAny').
(!===) ::
   (Shape.C width, Eq width, Shape.C heightA, Shape.C heightB,
    Class.Floating a) =>
   Matrix.General heightA width a ->
   Matrix.General heightB width a ->
   Matrix.General (heightA:+:heightB) width a
(!===) = Matrix.above Matrix.leftBias Extent.appendAny


-- | A value labelled with a phantom type @tag@.
newtype Tagged tag a = Tagged a
   deriving (Show)

type TaggedGen tag a = Tagged tag (QC.Gen a)

instance Functor (Tagged tag) where
   fmap f (Tagged a) = Tagged (f a)

instance Applicative (Tagged tag) where
   pure = Tagged
   Tagged f <*> Tagged a = Tagged (f a)

-- | 'QC.forAll' lifted to tagged generators.
checkForAllPlain ::
   (Show a, QC.Testable test) =>
   TaggedGen tag a -> (a -> test) -> Tagged tag QC.Property
checkForAllPlain (Tagged gen) test = Tagged $ QC.forAll gen test

-- | Like 'checkForAllPlain', but the generator also reports whether the
-- generated dimensions match.  With 'Match' the test must hold.  With
-- 'Mismatch' the test result must be bottom, i.e. the tested routine must
-- abort.
checkForAll ::
   (Show a, QC.Testable test) =>
   TaggedGen tag (a, Match) -> (a -> test) -> Tagged tag QC.Property
checkForAll taggedGen test =
   checkForAllPlain taggedGen $ \(a,match) ->
      case match of
         Match -> QC.property $ test a
         Mismatch -> QC.property $ isBottom $ test a


{- |
In @DontForceMatch@ mode the test generators need not produce
matching dimensions.
If the dimensions actually mismatch, a 'Mismatch' value is returned,
and the test driver asserts that the tested routine aborts with an error.

However, a typical test has the form
\"generic implementation = specialized implementation\".
If the generic implementation checks the sizes correctly,
the tester cannot detect a missing check in the specialized implementation.
So far, the proposed way to avoid this problem
is to add a test that relies solely on the function being tested.
If you have no better idea, compare an implementation with itself.
-}
data Match = Mismatch | Match
   deriving (Eq, Show)

-- | Dimensions match overall only if every part matches.
instance Semigroup Match where
   Match <> Match = Match
   _ <> _ = Mismatch

-- | Canonical instance: 'mappend' is defined via '(<>)', not the reverse.
instance Monoid Match where
   mempty = Match
   mappend = (<>)


-- | Qualify test names: prepend @msg@ and a dot to every name.
prefix :: String -> [(String, test)] -> [(String, test)]
prefix msg = map (\(str,test) -> (msg ++ "." ++ str, test))