{-# LANGUAGE AllowAmbiguousTypes #-}
module Data.Connection.Trip (
  -- * Triple
    Trip(..)
  , tripl
  , tripr
  , unitl
  , unitr
  , counitl
  , counitr
  , strong'
  , choice'
  -- * Rounding
  , half
  , tied
  , above
  , below
  , roundOn
  , floorOn
  , ceilingOn
  , truncateOn
) where

import Control.Category (Category)
import Data.Bifunctor (bimap)
import Data.Bool
import Data.Connection.Conn
import Data.Prd
import Prelude hiding (until, Ord(..), Bounded)
import qualified Control.Category as C

---------------------------------------------------------------------
-- Adjoint triples
---------------------------------------------------------------------

-- | An adjoint triple of Galois connections.
--
-- An adjoint triple is a chain of connections of length 2:
--
-- \(f \dashv g \dashv h \)
--
-- The three fields are, in order, /f/, /g/ and /h/.
--
-- For further information see "Data.Connection.Property" and
-- <https://ncatlab.org/nlab/show/adjoint+triple>.
--
data Trip a b = Trip (a -> b) (b -> a) (a -> b)

-- The outer functions compose covariantly, the middle one contravariantly.
instance Category Trip where
  id = Trip id id id

  Trip outerL outerM outerR . Trip innerL innerM innerR =
    Trip (outerL . innerL) (innerM . outerM) (outerR . innerR)

-- | The connection built from the first two functions of the triple.
--
tripl :: Trip a b -> Conn a b
tripl (Trip lower middle _) = Conn lower middle

-- | The connection built from the last two functions of the triple.
--
tripr :: Trip a b -> Conn b a
tripr (Trip _ middle upper) = Conn middle upper

-- | The 'unit' of 'tripl'.
--
unitl :: Trip a b -> a -> a
unitl t = unit (tripl t)

-- | The 'unit' of 'tripr'.
--
unitr :: Trip a b -> b -> b
unitr t = unit (tripr t)

-- | The 'counit' of 'tripl'.
--
counitl :: Trip a b -> b -> b
counitl t = counit (tripl t)

-- | The 'counit' of 'tripr'.
--
counitr :: Trip a b -> a -> a
counitr t = counit (tripr t)

-- | Run two triples side by side on the components of a pair.
--
strong' :: Trip a b -> Trip c d -> Trip (a, c) (b, d)
strong' (Trip ab ba ab') (Trip cd dc cd') =
  Trip (bimap ab cd) (bimap ba dc) (bimap ab' cd')

-- | Run the first triple on 'Left' values and the second on 'Right' values.
--
choice' :: Trip a b -> Trip c d -> Trip (Either a c) (Either b d)
choice' (Trip ab ba ab') (Trip cd dc cd') =
  Trip (bimap ab cd) (bimap ba dc) (bimap ab' cd')

---------------------------------------------------------------------
-- Rounding
---------------------------------------------------------------------

-- | Determine which half of the interval between two representations of /a/
-- a particular value lies in.
--
-- The result is the 'pcompare' of @x '-' 'unitl' t x@ against
-- @'counitr' t x '-' x@.
--
half :: (Num a, Prd a) => Trip a b -> a -> Maybe Ordering
half t x = pcompare fromUnit toCounit
  where
    fromUnit = x - unitl t x
    toCounit = counitr t x - x

-- | Determine whether /x/ lies above the halfway point between two representations.
--
-- @ 'above' t x '==' (x '-' 'unitl' t x) '`gt`' ('counitr' t x '-' x) @
--
above :: (Num a, Prd a) => Trip a b -> a -> Bool
above t x = case half t x of
  Just GT -> True
  _       -> False

-- | Determine whether /x/ lies below the halfway point between two representations.
--
-- @ 'below' t x '==' (x '-' 'unitl' t x) '`lt`' ('counitr' t x '-' x) @
--
below :: (Num a, Prd a) => Trip a b -> a -> Bool
below t x = case half t x of
  Just LT -> True
  _       -> False

-- | Determine whether /x/ lies exactly halfway between two representations.
--
-- @ 'tied' t x '==' (x '-' 'unitl' t x) '=~' ('counitr' t x '-' x) @
--
tied :: (Num a, Prd a) => Trip a b -> a -> Bool
tied t x = case half t x of
  Just EQ -> True
  _       -> False

-- | Return the nearest value to /x/.
--
-- If /x/ lies halfway between two values, then return the value with the
-- larger absolute value (i.e. round away from zero).
--
-- The tie-breaking branch is also taken whenever 'half' yields 'Nothing'.
--
-- @ roundOn C.id == id @
--
roundOn :: (Prd a, Num a) => Trip a b -> a -> b
roundOn t x = case half t x of
  Just GT -> ceilingOn t x -- upper half interval
  Just LT -> floorOn t x   -- lower half interval
  _       -> bool (ceilingOn t x) (floorOn t x) (x <= 0)

-- | Apply 'connr' of the connection 'tripr'.
--
-- @ floorOn C.id == id @
--
floorOn :: Trip a b -> a -> b
floorOn t = connr (tripr t)

-- | Apply 'connl' of the connection 'tripl'.
--
-- @ ceilingOn C.id == id @
--
ceilingOn :: Trip a b -> a -> b
ceilingOn t = connl (tripl t)

-- | Use 'floorOn' when @x >= 0@ and 'ceilingOn' otherwise.
--
-- @ truncateOn C.id == id @
--
truncateOn :: (Num a, Prd a) => Trip a b -> a -> b
truncateOn t x
  | x >= 0    = floorOn t x
  | otherwise = ceilingOn t x

---------------------------------------------------------------------
-- Internal
---------------------------------------------------------------------

-- Everything below is commented out and is not compiled.
{-

-- | The four primary IEEE rounding modes.
--
-- See <https://en.wikipedia.org/wiki/Rounding>.
--
data Mode =
    RNZ -- ^ round to nearest with ties towards 0
  | RTP -- ^ round towards pos infinity
  | RTN -- ^ round towards neg infinity
  | RTZ -- ^ round towards 0
  deriving (Eq, Show)

-- >>> addOn ratf32 RTN 1 2
-- 3.0
-- minSubf
addOn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b -> b
addOn t@(Trip _ f _) rm x y = rnd t rm (addSgn t rm x y) (f x + f y)

negOn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b
negOn t@(Trip _ f _) rm x = rnd t rm (neg' t rm x) (0 - f x)

subOn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b -> b
subOn t@(Trip _ f _) rm x y = rnd t rm (subSgn t rm x y) (f x - f y)

mulOn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b -> b
mulOn t@(Trip _ f _) rm x y = rnd t rm (xorSgn t rm x y) (f x * f y)

{-
big = shiftf (-1) maximal
λ> fmaOn ratf32 RTN big 2 (-big)
3.4028235e38
λ> big * 2 - big
Infinity
-}
fmaOn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b -> b -> b
fmaOn t@(Trip _ f _) rm x y z = rnd t rm (fmaSgn t rm x y z) $ f x * f y + f z

{-
λ> remOn @Int RTP 17 5
-3
λ> remOn @Int RNZ 17 5
2
-}
remOn :: (Prd a, Prd b, Fractional a) => Trip a b -> Mode -> b -> b -> b
remOn t rm x y = fmaOn t rm (negOn t rm $ divOn t rm x y) y x

{-
λ> divOn @Int RNZ 17 5
3
λ> divOn @Int RTP 17 5
4
-}

-- when pos numbers are divided by −0 we return minus infinity rather than pos:
-- >>> divOn C.id RNZ 1 (shiftf (-1) 0)
-- -Infinity
divOn :: (Prd a, Prd b, Fractional a) => Trip a b -> Mode -> b -> b -> b
divOn t@(Trip _ f _) rm x y = rnd t rm (xorSgn t rm x y) (f x / f y)

-- requires that sign be flipped back in /a/.
divOn' :: (Prd a, Prd b, Fractional a) => Trip a b -> Mode -> b -> b -> b
divOn' t@(Trip _ f _) rm x y | xorSgn t rm x y = rnd t rm True (negate $ f x / f y)
                             | otherwise = rnd t rm False (f x / f y)

{-
rndOn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b
rndOn t@(Trip f g h) rm x = rnd t rm (neg' t rm x) (g x)
-}

-- Determine the sign of 0 when /a/ contains signed 0s
rsz :: (Prd a, Prd b) => Trip a b -> Bool -> a -> b
rsz t = bool (floorOn t) (ceilingOn t)

rnd :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> Bool -> a -> b
rnd t RNZ s x = bool (roundOn t x) (rsz t s x) $ x =~ 0
rnd t RTP s x = bool (ceilingOn t x) (rsz t s x) $ x =~ 0
rnd t RTN s x = bool (floorOn t x) (rsz t s x) $ x =~ 0
rnd t RTZ s x = bool (truncateOn t x) (rsz t s x) $ x =~ 0

neg' :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> Bool
neg' t rm x = x < rnd t rm False 0

--pos' :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> Bool
--pos' t rm x = x > rnd t rm False 0

-- | Determine signed-0 behavior under addition.
addSgn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b -> Bool
addSgn t rm x y | rm == RTN = neg' t rm x || neg' t rm y
                | otherwise = neg' t rm x && neg' t rm y

subSgn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b -> Bool
subSgn t rm x y = not (addSgn t rm x y)

-- | Determine signed-0 behavior under multiplication and division.
xorSgn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b -> Bool
xorSgn t rm x y = neg' t rm x `xor` neg' t rm y

fmaSgn :: (Prd a, Prd b, Num a) => Trip a b -> Mode -> b -> b -> b -> Bool
fmaSgn t rm x y z = addSgn t rm (mulOn t rm x y) z

-}