{- |
Copyright   :  (c) Henning Thielemann 2007

Maintainer  :  haskell@henning-thielemann.de
Stability   :  stable
Portability :  Haskell 98

A type for non-negative numbers.
It performs a run-time check at construction time (i.e. at run-time)
and is a member of the non-negative number type class
'Numeric.NonNegative.Class.C'.
-}
module Numeric.NonNegative.Wrapper
   (T, fromNumber, fromNumberMsg, fromNumberClip, fromNumberUnsafe, toNumber,
    Int, Integer, Float, Double, Ratio, Rational) where

import qualified Numeric.NonNegative.Class as NonNeg

import Test.QuickCheck (Arbitrary(..))
import Numeric.NonNegative.Utility(mapPair, mapSnd)
import Control.Monad (liftM)

import qualified Data.Ratio as R
import qualified Prelude as P
import Prelude hiding (Int, Integer, Float, Double, Rational)


newtype T a = Cons {unwrap :: a}
   deriving (Eq, Ord)

instance Show a => Show (T a) where
   showsPrec p (Cons a) = showsPrec p a


{- |
Convert a number to a non-negative number.
If a negative number is given, an error is raised.
-}
fromNumber :: (Ord a, Num a) =>
      a
   -> T a
fromNumber = fromNumberMsg "fromNumber"

fromNumberMsg :: (Ord a, Num a) =>
      String  {- ^ name of the calling function to be used in the error message -}
   -> a
   -> T a
fromNumberMsg funcName x =
   if x>=0
     then Cons x
     else error (funcName++": negative number")

fromNumberWrap :: (Ord a, Num a) =>
      String
   -> a
   -> T a
fromNumberWrap funcName =
   fromNumberMsg ("NonNegative.Wrapper."++funcName)

{- |
Convert a number to a non-negative number.
A negative number will be replaced by zero.
Use this function with care since it may hide bugs.
-}
fromNumberClip :: (Ord a, Num a) =>
      a
   -> T a
fromNumberClip = Cons . max 0

{- |
Wrap a number into a non-negative number without doing checks.
This routine exists entirely for efficiency reasons
and must be used only in cases where you are absolutely sure,
that the input number is non-negative.
-}
fromNumberUnsafe ::
      a
   -> T a
fromNumberUnsafe = Cons

{-
export only this in order to disable direct access to the record field
by record update syntax
-}
toNumber :: T a -> a
toNumber = unwrap


{- |
Results are not checked for positivity.
-}
lift :: (a -> a) -> (T a -> T a)
lift f = Cons . f . toNumber

liftWrap :: (Ord a, Num a) => String -> (a -> a) -> (T a -> T a)
liftWrap msg f = fromNumberWrap msg . f . toNumber


{- |
Results are not checked for positivity.
-}
lift2 :: (a -> a -> a) -> (T a -> T a -> T a)
lift2 f (Cons x) (Cons y) = Cons $ f x y



instance (Ord a, Num a) => NonNeg.C (T a) where
   (Cons x) -| (Cons y) = fromNumberClip (x-y)

instance (Ord a, Num a) => Num (T a) where
   (+)    = lift2 (+)
   (Cons x) - (Cons y) = fromNumberWrap "-" (x-y)
   negate = liftWrap "negate" negate
   fromInteger x = fromNumberWrap "fromInteger" (fromInteger x)
   (*)    = lift2 (*)
   abs    = lift abs
   signum = lift signum

instance Real a => Real (T a) where
   toRational = toRational . toNumber

{- required for Integral instance -}
instance (Ord a, Num a, Enum a) => Enum (T a) where
   toEnum   = fromNumberWrap "toEnum" . toEnum
   fromEnum = fromEnum . toNumber

instance (Ord a, Num a, Bounded a) => Bounded (T a) where
   minBound = fromNumberClip minBound
   maxBound = fromNumberWrap "maxBound" maxBound

instance Integral a => Integral (T a) where
   toInteger = toInteger . toNumber
   quot = lift2 quot
   rem  = lift2 rem
   quotRem (Cons x) (Cons y) =
      mapPair (Cons, Cons) (quotRem x y)
   div  = lift2 div
   mod  = lift2 mod
   divMod (Cons x) (Cons y) =
      mapPair (Cons, Cons) (divMod x y)

instance (Ord a, Fractional a) => Fractional (T a) where
   fromRational = fromNumberWrap "fromRational" . fromRational
   (/) = lift2 (/)


instance (RealFrac a) => RealFrac (T a) where
   properFraction = mapSnd fromNumberUnsafe . properFraction . toNumber
   truncate = truncate . toNumber
   round    = round    . toNumber
   ceiling  = ceiling  . toNumber
   floor    = floor    . toNumber

instance (Ord a, Floating a) => Floating (T a) where
   pi = fromNumber pi
   exp  = lift exp
   sqrt = lift sqrt
   log  = liftWrap "log" log
   (**) = lift2 (**)
   logBase (Cons x) = liftWrap "logBase" (logBase x)
   sin = liftWrap "sin" sin
   tan = liftWrap "tan" tan
   cos = liftWrap "cos" cos
   asin = liftWrap "asin" asin
   atan = liftWrap "atan" atan
   acos = liftWrap "acos" acos
   sinh = liftWrap "sinh" sinh
   tanh = liftWrap "tanh" tanh
   cosh = liftWrap "cosh" cosh
   asinh = liftWrap "asinh" asinh
   atanh = liftWrap "atanh" atanh
   acosh = liftWrap "acosh" acosh


instance (Num a, Arbitrary a) => Arbitrary (T a) where
   arbitrary = liftM (Cons . abs) arbitrary
   coarbitrary = undefined


type Int      = T P.Int
type Integer  = T P.Integer
type Ratio a  = T (R.Ratio a)
type Rational = T P.Rational
type Float    = T P.Float
type Double   = T P.Double