{-# OPTIONS_GHC -Wall #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} module IsClose ( -- Tolerance type Tol (..) -- AllClose class , AllClose (..) , allCloseRawUnit , allCloseRawRealFrac , allCloseRawRealFloat -- * Re-exports , Default (..) -- * HUnit , assertAllClose ) where import Data.Default.Class import Data.List.NonEmpty (NonEmpty (..)) import Data.Map (Map) import qualified Data.Map as Map import Data.Monoid import Data.Semigroup import qualified Numeric.Optimization.MIP as MIP import GHC.Stack (HasCallStack) import Test.HUnit import Text.Printf -- ------------------------------------------------------------------------ -- | Tolerance -- -- Values @a@ and @b@ are considered /close/ if @abs (a - b) <= atol + rtol * abs b@. data Tol a = Tol { rtol :: a -- ^ The relative tolerance parameter (default: @1e-05@) , atol :: a -- ^ The absolute tolerance parameter (default: @1e-08@) , equalNan :: Bool -- ^ Whether to compare NaN’s as equal (default: @False@) } deriving (Show) instance RealFrac a => Default (Tol a) where def = Tol { rtol = 1e-05 , atol = 1e-08 , equalNan = False } -- ------------------------------------------------------------------------ class Real r => AllClose r a where -- | Returns number of mismatches, number of elements, maximal absolute difference, and maximal relative difference. -- Returns @'Ap' 'Nothing'@ if given values are incomparable. allCloseRaw :: Tol r -> a -> a -> Ap Maybe (Sum Int, Sum Int, Max r, Max r) -- | Returns 'True' if the two arrays are equal within the given tolerance; 'False' otherwise. allClose :: Tol r -> a -> a -> Bool allClose tol x y = case getAp (allCloseRaw tol x y) of Nothing -> False Just (Sum numMismatched, _, _, _) -> numMismatched == 0 allCloseRawRealFrac :: RealFrac r => Tol r -> r -> r -> Ap Maybe (Sum Int, Sum Int, Max r, Max r) allCloseRawRealFrac t a b = Ap $ Just $ ( Sum $ if abs (a - b) <= atol t + rtol t * abs b then 0 else 1 , Sum 1 , Max (abs (a - b)) , Max (abs (a - b) / abs b) ) allCloseRawRealFloat :: RealFloat r => Tol r -> r -> r -> Ap Maybe (Sum Int, Sum Int, Max r, Max r) allCloseRawRealFloat t a b | isNaN a /= isNaN b = Ap Nothing | otherwise = Ap $ Just $ ( Sum $ if (equalNan t && isNaN a && isNaN b) || a == b || abs (a - b) <= atol t + rtol t * abs b then 0 else 1 , Sum 1 , Max (abs (a - b)) , Max (abs (a - b) / abs b) ) allCloseRawUnit :: Num r => Ap Maybe (Sum Int, Sum Int, Max r, Max r) allCloseRawUnit = Ap (Just (Sum 0, Sum 0, Max 0, Max 0)) instance AllClose Rational Rational where allCloseRaw = allCloseRawRealFrac instance AllClose Double Double where allCloseRaw = allCloseRawRealFloat instance (AllClose r a) => AllClose r (Maybe a) where allCloseRaw tol (Just a) (Just b) = allCloseRaw tol a b allCloseRaw _ Nothing Nothing = allCloseRawUnit allCloseRaw _ _ _ = Ap Nothing instance (AllClose r v) => AllClose r [v] where allCloseRaw tol xs ys | length xs == length ys = sconcat (allCloseRawUnit :| [allCloseRaw tol a b | (a,b) <- zip xs ys]) | otherwise = Ap Nothing instance (Ord k, AllClose r v) => AllClose r (Map k v) where allCloseRaw tol m1 m2 | Map.keys m1 == Map.keys m2 = sconcat (allCloseRawUnit :| [allCloseRaw tol a b | (a,b) <- zip (Map.elems m1) (Map.elems m2)]) | otherwise = Ap Nothing instance (Real r, AllClose r r) => AllClose r (MIP.Solution r) where allCloseRaw tol a b | MIP.solStatus a == MIP.solStatus b = allCloseRaw tol (MIP.solObjectiveValue a) (MIP.solObjectiveValue b) <> allCloseRaw tol (MIP.solVariables a) (MIP.solVariables b) | otherwise = Ap Nothing -- ------------------------------------------------------------------------ -- | Assert that two objects are equal up to desired tolerance. assertAllClose :: (HasCallStack, AllClose r a, Show r, Show a) => Tol r -> a -- ^ actual -> a -- ^ desired -> Assertion assertAllClose tol a b = case getAp (allCloseRaw tol a b) of Nothing -> assertString $ unlines $ header ++ ["x and y nan location mismatch:"] ++ footer Just (Sum numMismatch, Sum numTotal, Max absDiff, Max relDiff) | numMismatch == 0 -> return () | otherwise -> assertString $ unlines $ header ++ [ printf "Mismatched elements: %d / %d (%f%%)" numMismatch numTotal (fromIntegral numMismatch * 100 / fromIntegral numTotal :: Double) , " Max absolute difference: " ++ show absDiff , " Max relative difference: " ++ show relDiff ] ++ footer where header, footer :: [String] header = [printf "Not equal to tolerance rtol=%s, atol=%s" (show (rtol tol)) (show (atol tol)), ""] footer = [" x: " ++ show a, " y: " ++ show b] -- ------------------------------------------------------------------------