{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
-- | Intervals whose lower end point carries 'TowardNegInf' rounding and whose
-- upper end point carries 'TowardInf' rounding, plus an explicit empty
-- interval.
module Numeric.Rounded.Hardware.Interval
  ( Interval(..)
  , increasing
  , maxI
  , minI
  , powInt
  , null
  , inf
  , sup
  , width
  , widthUlp
  , hull
  , intersection
  ) where

import Control.DeepSeq (NFData (..))
import Control.Monad
import Control.Monad.ST
import qualified Data.Array.Base as A
import Data.Coerce
import Data.Ix
import Data.Primitive
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import GHC.Float (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import Numeric.Rounded.Hardware.Internal
import qualified Numeric.Rounded.Hardware.Interval.Class as C
import qualified Numeric.Rounded.Hardware.Interval.NonEmpty as NE
import Prelude hiding (null)

-- | Either a pair of end points @I lo hi@ (lower end tagged
-- 'TowardNegInf', upper end tagged 'TowardInf') or the empty interval.
data Interval a
  = I !(Rounded 'TowardNegInf a) !(Rounded 'TowardInf a)
  | Empty
  deriving (Show,Generic)

instance NFData a => NFData (Interval a)

-- | Apply @f@ to each end point, each with its own rounding direction.
-- 'Empty' is returned unchanged.
--
-- NOTE(review): the result only encloses the image when @f@ is
-- monotonically non-decreasing (as the name suggests); this is not checked.
increasing :: (forall r. Rounding r => Rounded r a -> Rounded r a) -> Interval a -> Interval a
increasing f (I a b) = I (f a) (f b)
increasing _ Empty = Empty
{-# INLINE increasing #-}

-- Arithmetic is delegated to the non-empty interval implementation;
-- any 'Empty' operand yields 'Empty'.
instance (Num a, RoundedRing a) => Num (Interval a) where
  (+) = liftBinaryNE (+)
  (-) = liftBinaryNE (-)
  negate = liftUnaryNE negate
  (*) = liftBinaryNE (*)
  abs = liftUnaryNE abs
  signum = liftUnaryNE signum
  fromInteger x = case intervalFromInteger x of
    (y, y') -> I y y'
  {-# INLINE (+) #-}
  {-# INLINE (-) #-}
  {-# INLINE negate #-}
  {-# INLINE (*) #-}
  {-# INLINE abs #-}
  {-# INLINE signum #-}
  {-# INLINE fromInteger #-}

instance (Num a, RoundedFractional a) => Fractional (Interval a) where
  recip = liftUnaryNE recip
  (/) = liftBinaryNE (/)
  fromRational x = case intervalFromRational x of
    (y, y') -> I y y'
  {-# INLINE recip #-}
  {-# INLINE (/) #-}
  {-# INLINE fromRational #-}

-- | End-point-wise maximum: the lower end is the max of the lower ends and
-- the upper end is the max of the upper ends. 'Empty' if either is empty.
maxI :: Ord a => Interval a -> Interval a -> Interval a
maxI (I a a') (I b b') = I (max a b) (max a' b')
maxI _ _ = Empty
{-# SPECIALIZE maxI :: Interval Float -> Interval Float -> Interval Float #-}
{-# SPECIALIZE maxI :: Interval Double -> Interval Double -> Interval Double #-}

-- | End-point-wise minimum: the lower end is the min of the lower ends and
-- the upper end is the min of the upper ends. 'Empty' if either is empty.
minI :: Ord a => Interval a -> Interval a -> Interval a
minI (I a a') (I b b') = I (min a b) (min a' b')
minI _ _ = Empty
{-# SPECIALIZE minI :: Interval Float -> Interval Float -> Interval Float #-}
{-# SPECIALIZE minI :: Interval Double -> Interval Double -> Interval Double #-}

-- | Integer power of an interval.
--
-- For @n >= 0@ the result encloses @{ v ^ n | v <- x }@. Negative exponents
-- are rejected by '(^)', which raises an error.
powInt :: (Ord a, Num a, RoundedRing a) => Interval a -> Int -> Interval a
powInt (I a a') n
  -- v ^ 0 == 1 for every v (0 ^ 0 included). Without this guard an interval
  -- straddling zero would fall through to the last case and give [0, 1].
  | n == 0 = I 1 1
  -- Odd powers are monotone, and even powers are monotone on [0, inf).
  | odd n || 0 <= a = I (a^n) (a'^n)
  -- Entirely non-positive: |v| reverses the order of the end points.
  | a' <= 0 = I ((coerce (abs a'))^n) ((coerce (abs a))^n)
  -- Straddles zero with an even exponent: the minimum is 0.
  | otherwise = I 0 (max ((coerce (abs a))^n) (a'^n))
powInt Empty _ = Empty
{-# SPECIALIZE powInt :: Interval Float -> Int -> Interval Float #-}
{-# SPECIALIZE powInt :: Interval Double -> Int -> Interval Double #-}

-- | 'True' exactly for 'Empty'.
null :: Interval a -> Bool
null Empty = True
null _ = False

-- | Lower end point. Calls 'error' on 'Empty'.
inf :: Interval a -> Rounded 'TowardNegInf a
inf (I x _) = x
inf _ = error "empty interval"

-- | Upper end point. Calls 'error' on 'Empty'.
sup :: Interval a -> Rounded 'TowardInf a
sup (I _ y) = y
sup _ = error "empty interval"

-- | Upper end minus lower end, computed with upward rounding. 'Empty' has
-- width 0.
width :: (Num a, RoundedRing a) => Interval a -> Rounded 'TowardInf a
width (I x y) = y - coerce x
width Empty = 0

-- | Distance between the end points as reported by 'distanceUlp'
-- (presumably in units in the last place; see
-- "Numeric.Rounded.Hardware.Internal"). 'Empty' gives @Just 0@.
widthUlp :: (RealFloat a) => Interval a -> Maybe Integer
widthUlp (I x y) = distanceUlp (getRounded x) (getRounded y)
widthUlp Empty = Just 0

-- | Smallest interval containing both arguments. 'Empty' is the identity.
hull :: RoundedRing a => Interval a -> Interval a -> Interval a
hull (I x y) (I x' y') = I (min x x') (max y y')
hull Empty v = v
hull u Empty = u

-- | Set intersection. Returns 'Empty' when the arguments do not overlap or
-- when either argument is empty.
intersection :: RoundedRing a => Interval a -> Interval a -> Interval a
intersection (I x y) (I x' y')
  | getRounded x'' <= getRounded y'' = I x'' y''
  where
    x'' = max x x'
    y'' = min y y'
intersection _ _ = Empty

-- | Lift a unary operation on non-empty intervals to 'Interval', mapping
-- 'Empty' to 'Empty'.
liftUnaryNE :: (NE.Interval a -> NE.Interval a) -> Interval a -> Interval a
liftUnaryNE f (I x x') = case f (NE.I x x') of
  NE.I y y' -> I y y'
liftUnaryNE _f Empty = Empty
{-# INLINE [1] liftUnaryNE #-}

-- | Lift a binary operation on non-empty intervals to 'Interval'. The
-- result is 'Empty' if either operand is 'Empty'.
liftBinaryNE :: (NE.Interval a -> NE.Interval a -> NE.Interval a) -> Interval a -> Interval a -> Interval a
liftBinaryNE f (I x x') (I y y') = case f (NE.I x x') (NE.I y y') of
  NE.I z z' -> I z z'
liftBinaryNE _f _ _ = Empty
{-# INLINE [1] liftBinaryNE #-}

instance (Num a, RoundedFractional a, RoundedSqrt a, Eq a, RealFloat a, RealFloatConstants a) => Floating (Interval a) where
  pi = I pi_down pi_up
  exp = liftUnaryNE exp
  log = liftUnaryNE log
  sqrt = liftUnaryNE sqrt
  (**) = liftBinaryNE (**)
  logBase = liftBinaryNE logBase
  sin = liftUnaryNE sin
  cos = liftUnaryNE cos
  tan = liftUnaryNE tan
  asin = liftUnaryNE asin
  acos = liftUnaryNE acos
  atan = liftUnaryNE atan
  sinh = liftUnaryNE sinh
  cosh = liftUnaryNE cosh
  tanh = liftUnaryNE tanh
  asinh = liftUnaryNE asinh
  acosh = liftUnaryNE acosh
  atanh = liftUnaryNE atanh
  log1p = liftUnaryNE log1p
  expm1 = liftUnaryNE expm1
  log1pexp = liftUnaryNE log1pexp
  log1mexp = liftUnaryNE log1mexp
  {-# INLINE exp #-}
  {-# INLINE log #-}
  {-# INLINE sqrt #-}
  {-# INLINE (**) #-}
  {-# INLINE logBase #-}
  {-# INLINE sin #-}
  {-# INLINE cos #-}
  {-# INLINE tan #-}
  {-# INLINE asin #-}
  {-# INLINE acos #-}
  {-# INLINE atan #-}
  {-# INLINE sinh #-}
  {-# INLINE cosh #-}
  {-# INLINE tanh #-}
  {-# INLINE asinh #-}
  {-# INLINE acosh #-}
  {-# INLINE atanh #-}
  {-# INLINE log1p #-}
  {-# INLINE expm1 #-}
  {-# INLINE log1pexp #-}
  {-# INLINE log1mexp #-}

instance (Num a, RoundedRing a, RealFloat a) => C.IsInterval (Interval a) where
  type EndPoint (Interval a) = a
  makeInterval = I
  width = width
  withEndPoints f (I x y) = f x y
  withEndPoints _ Empty = Empty
  hull = hull
  intersection = intersection
  maybeIntersection x y = case intersection x y of
    Empty -> Nothing
    z -> Just z
  equalAsSet (I x y) (I x' y') = x == x' && y == y'
  equalAsSet Empty Empty = True
  equalAsSet _ _ = False
  subset (I x y) (I x' y') = x' <= x && y <= y'
  subset Empty _ = True
  subset I{} Empty = False
  weaklyLess (I x y) (I x' y') = x <= x' && y <= y'
  weaklyLess Empty Empty = True
  weaklyLess _ _ = False
  precedes (I _ y) (I x' _) = getRounded y <= getRounded x'
  precedes _ _ = True
  -- (<#) is strict '<', except that two equal infinite end points also count.
  interior (I x y) (I x' y') = getRounded x' <# getRounded x && getRounded y <# getRounded y'
    where s <# t = s < t || (s == t && isInfinite s)
  interior Empty _ = True
  interior I{} Empty = False
  strictLess (I x y) (I x' y') = getRounded x <# getRounded x' && getRounded y <# getRounded y'
    where s <# t = s < t || (s == t && isInfinite s)
  strictLess Empty Empty = True
  strictLess _ _ = False
  strictPrecedes (I _ y) (I x' _) = getRounded y < getRounded x'
  strictPrecedes _ _ = True
  disjoint (I x y) (I x' y') = getRounded y < getRounded x' || getRounded y' < getRounded x
  disjoint _ _ = True

--
-- Instance for Data.Vector.Unboxed.Unbox
--

newtype instance VUM.MVector s (Interval a) = MV_Interval (VUM.MVector s (a, a))
newtype instance VU.Vector (Interval a) = V_Interval (VU.Vector (a, a))

-- | Encode an interval as a raw end-point pair. 'Empty' becomes
-- @(1/0, -1/0)@, an inverted pair that 'pairToInterval' maps back to 'Empty'.
intervalToPair :: Fractional a => Interval a -> (a, a)
intervalToPair (I (Rounded x) (Rounded y)) = (x, y)
intervalToPair Empty = (1/0, -1/0)
{-# INLINE intervalToPair #-}

-- | Decode a raw end-point pair. A pair whose upper value is below its
-- lower value decodes to 'Empty'.
pairToInterval :: Ord a => (a, a) -> Interval a
pairToInterval (x, y)
  | y < x = Empty
  | otherwise = I (Rounded x) (Rounded y)
{-# INLINE pairToInterval #-}

-- Mutable unboxed interval vectors wrap a mutable unboxed vector of
-- end-point pairs; element reads/writes convert with 'pairToInterval' and
-- 'intervalToPair'.
instance (VU.Unbox a, Ord a, Fractional a) => VGM.MVector VUM.MVector (Interval a) where
  basicLength (MV_Interval pairs) = VGM.basicLength pairs
  basicUnsafeSlice start len (MV_Interval pairs) =
    MV_Interval (VGM.basicUnsafeSlice start len pairs)
  basicOverlaps (MV_Interval lhs) (MV_Interval rhs) = VGM.basicOverlaps lhs rhs
  basicUnsafeNew len = fmap MV_Interval (VGM.basicUnsafeNew len)
  basicInitialize (MV_Interval pairs) = VGM.basicInitialize pairs
  basicUnsafeReplicate len iv =
    fmap MV_Interval (VGM.basicUnsafeReplicate len (intervalToPair iv))
  basicUnsafeRead (MV_Interval pairs) ix =
    fmap pairToInterval (VGM.basicUnsafeRead pairs ix)
  basicUnsafeWrite (MV_Interval pairs) ix iv =
    VGM.basicUnsafeWrite pairs ix (intervalToPair iv)
  basicClear (MV_Interval pairs) = VGM.basicClear pairs
  basicSet (MV_Interval pairs) iv = VGM.basicSet pairs (intervalToPair iv)
  basicUnsafeCopy (MV_Interval dst) (MV_Interval src) = VGM.basicUnsafeCopy dst src
  basicUnsafeMove (MV_Interval dst) (MV_Interval src) = VGM.basicUnsafeMove dst src
  basicUnsafeGrow (MV_Interval pairs) extra =
    fmap MV_Interval (VGM.basicUnsafeGrow pairs extra)
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicOverlaps #-}
  {-# INLINE basicUnsafeNew #-}
  {-# INLINE basicInitialize #-}
  {-# INLINE basicUnsafeReplicate #-}
  {-# INLINE basicUnsafeRead #-}
  {-# INLINE basicUnsafeWrite #-}
  {-# INLINE basicClear #-}
  {-# INLINE basicSet #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE basicUnsafeMove #-}
  {-# INLINE basicUnsafeGrow #-}

-- Immutable counterpart: same pair-backed representation.
instance (VU.Unbox a, Ord a, Fractional a) => VG.Vector VU.Vector (Interval a) where
  basicUnsafeFreeze (MV_Interval pairs) = fmap V_Interval (VG.basicUnsafeFreeze pairs)
  basicUnsafeThaw (V_Interval pairs) = fmap MV_Interval (VG.basicUnsafeThaw pairs)
  basicLength (V_Interval pairs) = VG.basicLength pairs
  basicUnsafeSlice start len (V_Interval pairs) =
    V_Interval (VG.basicUnsafeSlice start len pairs)
  basicUnsafeIndexM (V_Interval pairs) ix =
    fmap pairToInterval (VG.basicUnsafeIndexM pairs ix)
  basicUnsafeCopy (MV_Interval dst) (V_Interval src) = VG.basicUnsafeCopy dst src
  -- Forces the element to WHNF before continuing.
  elemseq (V_Interval _) iv rest = iv `seq` rest
  {-# INLINE basicUnsafeFreeze #-}
  {-# INLINE basicUnsafeThaw #-}
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicUnsafeIndexM #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE elemseq #-}

instance (VU.Unbox a, Ord a, Fractional a) => VU.Unbox (Interval a)

--
-- Instances for Data.Array.Unboxed
--

-- Layout: element @i@ occupies two consecutive @a@ slots in the byte array,
-- the lower end at slot @2*i@ and the upper end at slot @2*i+1@.
instance (Prim a, Ord a, Fractional a) => A.MArray (A.STUArray s) (Interval a) (ST s) where
  getBounds (A.STUArray lo hi _ _) = pure (lo, hi)
  getNumElements (A.STUArray _ _ count _) = pure count
  -- newArray: Use default
  unsafeNewArray_ = A.newArray_
  newArray_ lohi@(lo, hi) = do
    let count = rangeSize lohi
        slots = 2 * count -- two end points per element
    buf@(MutableByteArray raw) <- newByteArray (slots * sizeOf (undefined :: a))
    -- Zero-fill, so every fresh element reads back as I 0 0.
    setByteArray buf 0 slots (0 :: a)
    pure (A.STUArray lo hi count raw)
  unsafeRead (A.STUArray _ _ _ raw) ix = do
    let buf = MutableByteArray raw
    lo <- readByteArray buf (2 * ix)
    hi <- readByteArray buf (2 * ix + 1)
    pure (pairToInterval (lo, hi))
  unsafeWrite (A.STUArray _ _ _ raw) ix iv = do
    let buf = MutableByteArray raw
        (lo, hi) = intervalToPair iv
    writeByteArray buf (2 * ix) lo
    writeByteArray buf (2 * ix + 1) hi

instance (Prim a, Ord a, Fractional a) => A.IArray A.UArray (Interval a) where
  bounds (A.UArray lo hi _ _) = (lo, hi)
  numElements (A.UArray _ _ count _) = count
  unsafeArray lohi ies = runST $ do
    marr <- A.newArray_ lohi
    forM_ ies (uncurry (A.unsafeWrite marr))
    A.unsafeFreezeSTUArray marr
  unsafeAt (A.UArray _ _ _ raw) ix =
    pairToInterval (indexByteArray buf (2 * ix), indexByteArray buf (2 * ix + 1))
    where
      buf = ByteArray raw
  -- unsafeReplace, unsafeAccum, unsafeAccumArray: Use default