```-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SafeInt
-- Copyright   :  (c) 2010 Well-Typed LLP
--
-- Maintainer  :  Andres Loeh <andres@well-typed.com>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- Defines a variant of Haskell's Int type that is overflow-checked. If
-- an overflow or arithmetic error occurs, a run-time exception is thrown.
--
--------------------------------------------------------------------------

{-# LANGUAGE MagicHash, UnboxedTuples #-}

module Data.SafeInt (SafeInt(..), fromSafe, toSafe) where

import GHC.Prim
import GHC.Base
import GHC.Err
import GHC.Num
import GHC.Word
import GHC.Real
import GHC.Types

-- | An 'Int' whose arithmetic raises a run-time exception on overflow or
-- arithmetic error instead of silently wrapping around.
newtype SafeInt = SI Int

-- | Unwrap a 'SafeInt' to the plain 'Int' it holds.
fromSafe :: SafeInt -> Int
fromSafe (SI n) = n

-- | Wrap a plain 'Int'.  This cannot fail: every 'Int' is a valid 'SafeInt'.
toSafe :: Int -> SafeInt
toSafe = SI

-- | A 'SafeInt' is rendered exactly like the underlying 'Int'.
instance Show SafeInt where
  showsPrec p x = showsPrec p (fromSafe x)

-- | 'readsPrec' is a method of 'Read', not 'Show', so it lives in its own
-- instance.  Input is parsed with 'Integer' syntax and converted through
-- the overflow-checked 'fromInteger' of 'SafeInt'.  An out-of-range
-- literal therefore raises 'overflowError' once the result is forced,
-- instead of depending on how the 'Int' parser treats it.
instance Read SafeInt where
  readsPrec p s =
    [ (fromInteger n, rest)
    | (n, rest) <- (readsPrec p s :: [(Integer, String)])
    ]

-- | Two 'SafeInt's are equal exactly when their underlying 'Int's are.
instance Eq SafeInt where
  (==) (SI a) (SI b) = a `eqInt` b
  (/=) (SI a) (SI b) = a `neInt` b

-- | Ordering follows the underlying 'Int'.  'compare', 'min' and 'max'
-- are left to the class defaults.
instance Ord SafeInt where
  (<)  (SI a) (SI b) = a `ltInt` b
  (<=) (SI a) (SI b) = a `leInt` b
  (>)  (SI a) (SI b) = a `gtInt` b
  (>=) (SI a) (SI b) = a `geInt` b

-- | In the 'Num' instance, we plug in our own addition, multiplication
-- and subtraction functions, which perform overflow checking.
instance Num SafeInt where

  -- The three binary operations go through the overflow-checked helpers.
  (+) = plusSI
  (-) = minusSI
  (*) = timesSI

  -- minInt has no positive counterpart, so negating it overflows.
  negate (SI n) = if n == minInt then overflowError else SI (negate n)

  -- Non-negative values are returned as-is; negative ones go through the
  -- checked 'negate' (so abs minBound raises 'overflowError').
  abs x = if x < 0 then negate x else x

  -- -1, 0 or 1, as for 'Int'.
  signum (SI n) = SI (signum n)

  -- Literals and conversions outside the 'Int' range are rejected rather
  -- than truncated.
  fromInteger i
    | minBoundInteger <= i && i <= maxBoundInteger = SI (fromInteger i)
    | otherwise                                    = overflowError

-- | The inclusive bounds of 'Int', widened to 'Integer'.  They are used to
-- range-check values before they are narrowed back to 'Int'.
maxBoundInteger, minBoundInteger :: Integer
maxBoundInteger = toInteger maxInt
minBoundInteger = toInteger minInt

-- | 'SafeInt' covers exactly the range of 'Int'.
instance Bounded SafeInt where
  minBound = toSafe minInt
  maxBound = toSafe maxInt

-- | Enumeration mirrors 'Int'.  The list-producing methods use the
-- fusion-friendly helpers eftInt / efdInt / efdtInt.
instance Enum SafeInt where

  -- Delegate to Int's succ/pred, which reject maxBound/minBound.
  succ (SI x) = SI (succ x)
  pred (SI x) = SI (pred x)

  -- Both are total: every Int is a SafeInt and vice versa.
  -- (fromEnum is part of the minimal definition and was previously missing.)
  toEnum          = SI
  fromEnum (SI x) = x

  -- The upper bound is unboxed with a case rather than a strict
  -- pattern binding, so no BangPatterns extension is needed.
  {-# INLINE enumFrom #-}
  enumFrom (SI (I# x)) = case maxInt of I# y -> eftInt x y

  {-# INLINE enumFromTo #-}
  enumFromTo (SI (I# x)) (SI (I# y)) = eftInt x y

  {-# INLINE enumFromThen #-}
  enumFromThen (SI (I# x1)) (SI (I# x2)) = efdInt x1 x2

  {-# INLINE enumFromThenTo #-}
  enumFromThenTo (SI (I# x1)) (SI (I# x2)) (SI (I# y)) = efdtInt x1 x2 y

-- The following code is copied/adapted from GHC.Enum.

-- List fusion for eftInt.  Before phase 1 ([~1]), eftInt is rewritten into
-- build/eftIntFB form so it can fuse with a list consumer.  From phase 1
-- on ([1]), any eftIntFB (:) [] that did not fuse is turned back into the
-- plain list producer eftInt.
{-# RULES
"eftInt"        [~1] forall x y. eftInt x y = build (\ c n -> eftIntFB c n x y)
"eftIntList"    [1] eftIntFB  (:) [] = eftInt
#-}

-- | @eftInt x y@ is the list @[x .. y]@; it is empty when @x > y@.
-- Primop comparisons return 'Int#', hence the 'isTrue#' wrappers.
eftInt :: Int# -> Int# -> [SafeInt]
eftInt x0 y
  | isTrue# (x0 ># y) = []
  | otherwise         = go x0
  where
    -- Stop on equality rather than testing x > y, so that y == maxBound
    -- never needs the unrepresentable maxBound + 1.
    go x = SI (I# x) : if isTrue# (x ==# y) then [] else go (x +# 1#)

{-# INLINE [0] eftIntFB #-}
-- Fusible (build-style) form of eftInt: folds c/n over [x0 .. y].
eftIntFB :: (SafeInt -> r -> r) -> r -> Int# -> Int# -> r
eftIntFB c n x0 y
  | isTrue# (x0 ># y) = n
  | otherwise         = go x0
  where
    -- Watch out for y=maxBound; hence ==, not >
    -- Be very careful not to have more than one "c"
    -- so that when eftIntFB is inlined we can inline
    -- whatever is bound to "c"
    go x = SI (I# x) `c` if isTrue# (x ==# y) then n else go (x +# 1#)

-- List fusion for efdtInt.  Before phase 1 ([~1]), efdtInt is rewritten
-- into build/efdtIntFB form so it can fuse with a consumer.  From phase 1
-- on ([1]), an unfused efdtIntFB (:) [] is turned back into efdtInt.
-- The rule's second line must be indented: at column 0 it would be split
-- off by the top-level layout rule.
{-# RULES
"efdtInt"       [~1] forall x1 x2 y.
                efdtInt x1 x2 y = build (\ c n -> efdtIntFB c n x1 x2 y)
"efdtIntUpList" [1]  efdtIntFB (:) [] = efdtInt
#-}

-- | @efdInt x1 x2@ is @[x1, x2 ..]@.  It counts up to maxBound when the
-- step is non-negative, and down to minBound otherwise.
efdInt :: Int# -> Int# -> [SafeInt]
efdInt x1 x2
  | isTrue# (x2 >=# x1) = case maxInt of I# y -> efdtIntUp x1 x2 y
  | otherwise           = case minInt of I# y -> efdtIntDn x1 x2 y

-- | @efdtInt x1 x2 y@ is @[x1, x2 .. y]@.  It dispatches on the direction
-- of the step.
efdtInt :: Int# -> Int# -> Int# -> [SafeInt]
efdtInt x1 x2 y
  | isTrue# (x2 >=# x1) = efdtIntUp x1 x2 y
  | otherwise           = efdtIntDn x1 x2 y

{-# INLINE [0] efdtIntFB #-}
-- Fusible form of efdtInt: folds c/n over [x1, x2 .. y], dispatching on
-- the direction of the step.
efdtIntFB :: (SafeInt -> r -> r) -> r -> Int# -> Int# -> Int# -> r
efdtIntFB c n x1 x2 y
  | isTrue# (x2 >=# x1) = efdtIntUpFB c n x1 x2 y
  | otherwise           = efdtIntDnFB c n x1 x2 y

-- | @[x1, x2 .. y]@ for a non-negative step.  Requires x2 >= x1.
-- Be careful about overflow: no intermediate value may exceed y.
efdtIntUp :: Int# -> Int# -> Int# -> [SafeInt]
efdtIntUp x1 x2 y
  | isTrue# (y <# x2) = if isTrue# (y <# x1) then [] else [SI (I# x1)]
  | otherwise =  -- Common case: x1 <= x2 <= y
      -- Unlifted let-bound variables are strict already; no bangs needed.
      let delta = x2 -# x1   -- >= 0
          y'    = y -# delta -- x1 <= y' <= y; hence y' is representable

          -- Invariant: x <= y
          -- Note that: z <= y' => z + delta won't overflow
          -- so we are guaranteed not to overflow if/when we recurse
          go_up x
            | isTrue# (x ># y') = [SI (I# x)]
            | otherwise         = SI (I# x) : go_up (x +# delta)
      in SI (I# x1) : go_up x2

-- | Fusible form of 'efdtIntUp': folds c/n over @[x1, x2 .. y]@.
-- Requires x2 >= x1.  Be careful about overflow.
efdtIntUpFB :: (SafeInt -> r -> r) -> r -> Int# -> Int# -> Int# -> r
efdtIntUpFB c n x1 x2 y
  | isTrue# (y <# x2) = if isTrue# (y <# x1) then n else SI (I# x1) `c` n
  | otherwise =  -- Common case: x1 <= x2 <= y
      let delta = x2 -# x1   -- >= 0
          y'    = y -# delta -- x1 <= y' <= y; hence y' is representable

          -- Invariant: x <= y
          -- Note that: z <= y' => z + delta won't overflow
          -- so we are guaranteed not to overflow if/when we recurse
          go_up x
            | isTrue# (x ># y') = SI (I# x) `c` n
            | otherwise         = SI (I# x) `c` go_up (x +# delta)
      in SI (I# x1) `c` go_up x2

-- | @[x1, x2 .. y]@ for a non-positive step.  Requires x2 <= x1.
-- Be careful about underflow: no intermediate value may go below y.
efdtIntDn :: Int# -> Int# -> Int# -> [SafeInt]
efdtIntDn x1 x2 y
  | isTrue# (y ># x2) = if isTrue# (y ># x1) then [] else [SI (I# x1)]
  | otherwise =  -- Common case: x1 >= x2 >= y
      let delta = x2 -# x1   -- <= 0
          y'    = y -# delta -- y <= y' <= x1; hence y' is representable

          -- Invariant: x >= y
          -- Note that: z >= y' => z + delta won't underflow
          -- so we are guaranteed not to underflow if/when we recurse
          go_dn x
            | isTrue# (x <# y') = [SI (I# x)]
            | otherwise         = SI (I# x) : go_dn (x +# delta)
      in SI (I# x1) : go_dn x2

-- | Fusible form of 'efdtIntDn': folds c/n over @[x1, x2 .. y]@.
-- Requires x2 <= x1.  Be careful about underflow.
efdtIntDnFB :: (SafeInt -> r -> r) -> r -> Int# -> Int# -> Int# -> r
efdtIntDnFB c n x1 x2 y
  | isTrue# (y ># x2) = if isTrue# (y ># x1) then n else SI (I# x1) `c` n
  | otherwise =  -- Common case: x1 >= x2 >= y
      let delta = x2 -# x1   -- <= 0
          y'    = y -# delta -- y <= y' <= x1; hence y' is representable

          -- Invariant: x >= y
          -- Note that: z >= y' => z + delta won't underflow
          -- so we are guaranteed not to underflow if/when we recurse
          go_dn x
            | isTrue# (x <# y') = SI (I# x) `c` n
            | otherwise         = SI (I# x) `c` go_dn (x +# delta)
      in SI (I# x1) `c` go_dn x2

-- The following code is copied/adapted from GHC.Real.

-- | Exact conversion, via the 'Real' instance of the underlying 'Int'.
instance Real SafeInt where
  toRational (SI n) = toRational n

-- | Integer division with the same results as 'Int', except that a result
-- that does not fit in an 'Int' raises 'overflowError' instead of wrapping.
--
-- * Every operation raises 'divZeroError' for a zero divisor.
-- * The quotient-producing operations raise 'overflowError' for
--   @minBound@ divided by @-1@, whose quotient is @maxBound + 1@.
-- * A remainder or modulus by @-1@ is always 0, which is representable.
--   It is answered directly rather than reported as an overflow.  This
--   also keeps @minBound@ / @-1@ away from 'remInt' / 'modInt'.
instance Integral SafeInt where

  -- Int's own toInteger: portable across integer backends, unlike
  -- GHC.Integer.smallInteger.
  toInteger (SI x) = toInteger x

  SI a `quot` SI b
    | b == 0                     = divZeroError
    | a == minBound && b == (-1) = overflowError
    | otherwise                  = SI (a `quotInt` b)

  SI a `rem` SI b
    | b == 0                     = divZeroError
    | b == (-1)                  = SI 0
    | otherwise                  = SI (a `remInt` b)

  SI a `div` SI b
    | b == 0                     = divZeroError
    | a == minBound && b == (-1) = overflowError
    | otherwise                  = SI (a `divInt` b)

  SI a `mod` SI b
    | b == 0                     = divZeroError
    | b == (-1)                  = SI 0
    | otherwise                  = SI (a `modInt` b)

  SI a `quotRem` SI b
    | b == 0                     = divZeroError
    | a == minBound && b == (-1) = overflowError
    | otherwise                  = a `quotRemSafeInt` b

  SI a `divMod` SI b
    | b == 0                     = divZeroError
    | a == minBound && b == (-1) = overflowError
    | otherwise                  = a `divModSafeInt` b

-- | Truncated quotient and remainder of two 'Int's, both wrapped as
-- 'SafeInt'.  Both arguments are forced before the pair is built.  The
-- 'Integral' instance calls this only after excluding a zero divisor and
-- minBound / -1.
quotRemSafeInt :: Int -> Int -> (SafeInt, SafeInt)
quotRemSafeInt n d = n `seq` d `seq` (SI (quotInt n d), SI (remInt n d))

-- | Floored quotient and modulus of two 'Int's, both wrapped as 'SafeInt'.
-- Both arguments are forced before the pair is built.  The 'Integral'
-- instance calls this only after excluding a zero divisor and
-- minBound / -1.
divModSafeInt :: Int -> Int -> (SafeInt, SafeInt)
divModSafeInt n d = n `seq` d `seq` (SI (divInt n d), SI (modInt n d))

-- | Overflow-checked addition.  'addIntC#' returns the wrapped sum together
-- with an indicator that is 0 only when no overflow occurred.  (The
-- @case@ scrutinee was missing, leaving the alternatives unattached.)
plusSI :: SafeInt -> SafeInt -> SafeInt
plusSI (SI (I# x#)) (SI (I# y#)) =
  case addIntC# x# y# of
    (# r#, 0# #) -> SI (I# r#)
    (# _ , _  #) -> overflowError

-- | Overflow-checked subtraction.  'subIntC#' returns the wrapped
-- difference and an indicator that is 0 only when no overflow occurred.
minusSI :: SafeInt -> SafeInt -> SafeInt
minusSI (SI (I# a)) (SI (I# b)) = case subIntC# a b of
  (# diff, 0# #) -> SI (I# diff)
  _              -> overflowError

-- | Overflow-checked multiplication.
--
-- 'mulIntMayOflo#' is only a conservative test.  Per the GHC primop
-- documentation it may answer non-zero when an overflow is merely
-- possible.  A zero answer is trusted (fast path).  A non-zero answer is
-- confirmed by multiplying exactly in 'Integer' and range-checking the
-- product, so large but representable products are not rejected.
timesSI :: SafeInt -> SafeInt -> SafeInt
timesSI (SI a@(I# x#)) (SI b@(I# y#)) =
  case mulIntMayOflo# x# y# of
    0# -> SI (I# (x# *# y#))
    _  ->
      let exact = toInteger a * toInteger b
      in if exact < minBoundInteger || exact > maxBoundInteger
           then overflowError
           else SI (fromInteger exact)

-- Shortcut 'fromIntegral' for conversions into SafeInt.  These rules only
-- fire at the types of their right-hand sides: Int -> SafeInt wraps
-- directly with toSafe, and SafeInt -> SafeInt is the identity.
{-# RULES
"fromIntegral/Int->SafeInt"     fromIntegral = toSafe
"fromIntegral/SafeInt->SafeInt" fromIntegral = id :: SafeInt -> SafeInt
#-}

-- Specialized versions of several functions. They're specialized for
-- Int in the GHC base libraries. We try to get the same effect by
-- including specialized code and adding a rewrite rule.

-- | Left-to-right, overflow-checked sum of a list of 'SafeInt's.  The
-- running total is forced at every step, so a long list does not build
-- a chain of unevaluated additions.
sumS :: [SafeInt] -> SafeInt
sumS l = sum' l 0
  where
    sum' []     a = a
    sum' (x:xs) a = let a' = a + x in a' `seq` sum' xs a'

-- | Left-to-right, overflow-checked product of a list of 'SafeInt's.  The
-- running product is forced at every step, so a long list does not build
-- a chain of unevaluated multiplications.
productS :: [SafeInt] -> SafeInt
productS l = prod l 1
  where
    prod []     a = a
    prod (x:xs) a = let a' = a * x in a' `seq` prod xs a'

-- Use the specialised sumS / productS in place of the generic 'sum' and
-- 'product' at type [SafeInt].  (The same pragma was previously repeated
-- verbatim, with identical rule names; one copy is enough.)
{-# RULES
"sum/SafeInt"     sum     = sumS
"product/SafeInt" product = productS
  #-}

-- | Least common multiple, with the final multiplication and 'abs' done in
-- overflow-checked 'SafeInt' arithmetic.  It is 0 if either argument is 0.
-- The second argument is inspected first, so the first is not forced when
-- the second is 0.
lcmS :: SafeInt -> SafeInt -> SafeInt
lcmS (SI x) sy@(SI y)
  | y == 0 || x == 0 = SI 0
  | otherwise        = abs (SI (x `quot` gcd x y) * sy)

-- | 'gcd' specialised to 'SafeInt'.
--
-- The 'Int' version needs the absolute values of its arguments, and the
-- absolute value of minInt is not an 'Int' (cf. 'negate').  So 'Int' gcd
-- is trusted only when neither argument is minInt.  Otherwise the exact
-- gcd is computed in 'Integer' and converted back with the
-- overflow-checked 'fromInteger'.  An unrepresentable result therefore
-- raises 'overflowError' instead of coming back negative.
gcdS :: SafeInt -> SafeInt -> SafeInt
gcdS (SI a) (SI b)
  | a == minInt || b == minInt = fromInteger (gcd (toInteger a) (toInteger b))
  | otherwise                  = SI (gcd a b)

-- Use the specialised versions at type SafeInt.
{-# RULES
"lcm/SafeInt"     lcm = lcmS
"gcd/SafeInt"     gcd = gcdS
  #-}
```