{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}

{-|
Module      : Data.BitVector.Sized.Overflow
Copyright   : (c) Galois Inc. 2020
License     : BSD-3
Maintainer  : Ben Selfridge <benselfridge@galois.com>
Stability   : experimental
Portability : portable

This module provides alternative definitions of certain bitvector
functions that might produce signed or unsigned overflow. Instead of
producing a pure value, these versions produce the same value along
with overflow flags. We only provide definitions for operators that
might actually overflow.

-}

module Data.BitVector.Sized.Overflow
  ( Overflow(..)
  , UnsignedOverflow(..)
  , SignedOverflow(..)
  , ofUnsigned
  , ofSigned
  , ofResult
  -- * Overflowing bitwise operators
  , shlOf
  -- * Overflowing arithmetic operators
  , addOf
  , subOf
  , mulOf
  , squotOf
  , sremOf
  , sdivOf
  , smodOf
  ) where

import qualified Data.Bits as B
import Numeric.Natural
import GHC.TypeLits

import Data.Parameterized ( NatRepr )
import qualified Data.Parameterized.NatRepr as P

import Data.BitVector.Sized.Internal ( BV(..)
                                     , mkBV'
                                     , asUnsigned
                                     , asSigned
                                     , shiftAmount
                                     )


----------------------------------------
-- Unsigned and signed overflow datatypes

-- | Datatype representing the possibility of unsigned overflow.
--
-- 'UnsignedOverflow' records that an operation's result fell outside the
-- unsigned range of the bitvector width; 'NoUnsignedOverflow' records that
-- no unsigned overflow was detected.
data UnsignedOverflow = UnsignedOverflow
                      | NoUnsignedOverflow
  deriving (Show, Eq)

-- | Combining two flags yields 'UnsignedOverflow' if either one is
-- 'UnsignedOverflow', so an overflow anywhere in a chain of operations is
-- remembered.
instance Semigroup UnsignedOverflow where
  NoUnsignedOverflow <> NoUnsignedOverflow = NoUnsignedOverflow
  _                  <> _                  = UnsignedOverflow

-- | The identity flag is 'NoUnsignedOverflow': combining with it leaves the
-- other flag unchanged.
instance Monoid UnsignedOverflow where
  mempty = NoUnsignedOverflow

-- | Datatype representing the possibility of signed overflow.
--
-- 'SignedOverflow' records that an operation's result fell outside the
-- signed (two's complement) range of the bitvector width;
-- 'NoSignedOverflow' records that no signed overflow was detected.
data SignedOverflow = SignedOverflow
                    | NoSignedOverflow
  deriving (Show, Eq)

-- | Combining two flags yields 'SignedOverflow' if either one is
-- 'SignedOverflow', so an overflow anywhere in a chain of operations is
-- remembered.
instance Semigroup SignedOverflow where
  NoSignedOverflow <> NoSignedOverflow = NoSignedOverflow
  _                <> _                = SignedOverflow

-- | The identity flag is 'NoSignedOverflow': combining with it leaves the
-- other flag unchanged.
instance Monoid SignedOverflow where
  mempty = NoSignedOverflow

----------------------------------------
-- Overflow wrapper
-- | A value annotated with overflow information: an unsigned-overflow flag,
-- a signed-overflow flag, and the computed result.
data Overflow a =
  Overflow UnsignedOverflow SignedOverflow a
  deriving (Show, Eq)

-- | Return 'True' if a computation caused unsigned overflow, i.e. its
-- unsigned flag is 'UnsignedOverflow'.
ofUnsigned :: Overflow a -> Bool
ofUnsigned (Overflow UnsignedOverflow _ _) = True
ofUnsigned _                               = False

-- | Return 'True' if a computation caused signed overflow, i.e. its signed
-- flag is 'SignedOverflow'.
ofSigned :: Overflow a -> Bool
ofSigned (Overflow _ SignedOverflow _) = True
ofSigned _                             = False

-- | Return the result of a computation, discarding both overflow flags.
ofResult :: Overflow a -> a
ofResult (Overflow _ _ res) = res

-- | An 'Overflow' holds exactly one value; folding visits that value and
-- ignores the flags.
instance Foldable Overflow where
  foldMap f (Overflow _ _ a) = f a

-- | Traversing runs the effect on the wrapped value and keeps both overflow
-- flags unchanged.
--
-- Defined via 'traverse' directly; the default 'sequenceA' is
-- @traverse id@, which gives @Overflow uof sof <$> a@.
instance Traversable Overflow where
  traverse f (Overflow uof sof a) = Overflow uof sof <$> f a

-- | Mapping transforms the result and leaves both overflow flags untouched.
instance Functor Overflow where
  fmap f (Overflow uof sof a) = Overflow uof sof (f a)

-- | 'pure' wraps a value with no overflow ('mempty' for both flags);
-- '<*>' combines the flags of both sides with '<>', so overflow on either
-- side is kept.
instance Applicative Overflow where
  pure = Overflow mempty mempty
  Overflow uof sof f <*> Overflow uof' sof' a =
    Overflow (uof <> uof') (sof <> sof') (f a)

-- | Monad for bitvector operations which might produce signed or
-- unsigned overflow. Each bind combines the flags of the current step with
-- those of the continuation using '<>', so the final flags report overflow
-- if any step overflowed.
instance Monad Overflow where
  Overflow uof sof a >>= k =
    -- Lazy let-pattern: the continuation's result is only inspected when
    -- one of its fields is demanded.
    let Overflow uof' sof' b = k a
    in Overflow (uof <> uof') (sof <> sof') b

-- | Classify an unbounded 'Integer' result against the unsigned range of
-- width @w@, @[P.minUnsigned w, P.maxUnsigned w]@.
getUof :: NatRepr w -> Integer -> UnsignedOverflow
getUof w x
  | x < P.minUnsigned w || x > P.maxUnsigned w = UnsignedOverflow
  | otherwise                                  = NoUnsignedOverflow

-- | Classify an unbounded 'Integer' result against the signed range of
-- width @w@, @[P.minSigned w, P.maxSigned w]@.
--
-- Those bounds require @1 <= w@, so width 0 is split off with
-- 'P.isZeroOrGT1' and never reports signed overflow.
getSof :: NatRepr w -> Integer -> SignedOverflow
getSof w x =
  case P.isZeroOrGT1 w of
    Left P.Refl -> NoSignedOverflow
    -- Matching on the proof brings @1 <= w@ into scope for the guards.
    Right P.LeqProof
      | x < P.minSigned w || x > P.maxSigned w -> SignedOverflow
      | otherwise                              -> NoSignedOverflow

-- | Lift an 'Integer' operation to a bitvector operation that reports
-- overflow.
--
-- The operation is applied twice:
--
-- * to the unsigned interpretations of the operands, giving the unsigned
--   overflow flag and the bits of the returned bitvector;
-- * to the signed interpretations, giving the signed overflow flag.
--
-- This only works if the operation has equivalent signed and unsigned
-- interpretations on bitvectors (as @+@, @-@ and @*@ do), because the
-- returned bits come from the unsigned computation alone.
liftBinary :: (1 <= w) => (Integer -> Integer -> Integer)
           -> NatRepr w
           -> BV w -> BV w -> Overflow (BV w)
liftBinary op w xv yv =
  Overflow (getUof w ures) (getSof w sres) (mkBV' w ures)
  where
    ures = asUnsigned xv `op` asUnsigned yv
    sres = asSigned w xv `op` asSigned w yv

-- | Bitvector add, reporting both signed and unsigned overflow.
addOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
addOf = liftBinary (+)

-- | Bitvector subtract, reporting both signed and unsigned overflow
-- (unsigned overflow includes a negative difference).
subOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
subOf = liftBinary (-)

-- | Bitvector multiply, reporting both signed and unsigned overflow.
mulOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
mulOf = liftBinary (*)

-- | Left shift by a 'Natural' amount, reporting both signed and unsigned
-- overflow.
--
-- The amount is first normalized with 'shiftAmount' (see that function
-- for the handling of amounts that are large relative to the width).
-- Overflow is detected by shifting the unbounded signed and unsigned
-- interpretations and checking them against the width's ranges.
shlOf :: (1 <= w) => NatRepr w -> BV w -> Natural -> Overflow (BV w)
shlOf w xv shf =
  Overflow (getUof w ures) (getSof w sres) (mkBV' w ures)
  where
    -- Computed once and shared by both shifts.
    amt  = shiftAmount w shf
    ures = asUnsigned xv `B.shiftL` amt
    sres = asSigned w xv `B.shiftL` amt

-- | Bitvector division (signed). Rounds toward zero. Division by zero
-- yields a runtime error when the result is evaluated.
--
-- Unsigned overflow is never reported. Signed overflow is reported only
-- for @minSigned / -1@, whose true quotient @2^(w-1)@ exceeds
-- 'P.maxSigned'.
squotOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
squotOf w bv1 bv2 = Overflow NoUnsignedOverflow sof (mkBV' w (x `quot` y))
  where
    x = asSigned w bv1
    y = asSigned w bv2
    sof | x == P.minSigned w && y == -1 = SignedOverflow
        | otherwise                     = NoSignedOverflow

-- | Bitvector remainder after division (signed), when rounded toward zero
-- (Haskell 'rem'). Division by zero yields a runtime error when the result
-- is evaluated.
--
-- Unsigned overflow is never reported. Signed overflow is flagged for
-- @minSigned `rem` -1@, the same input that overflows 'squotOf', even
-- though the remainder itself (0) is representable.
sremOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
sremOf w bv1 bv2 = Overflow NoUnsignedOverflow sof (mkBV' w (x `rem` y))
  where
    x = asSigned w bv1
    y = asSigned w bv2
    sof | x == P.minSigned w && y == -1 = SignedOverflow
        | otherwise                     = NoSignedOverflow

-- | Bitvector division (signed). Rounds toward negative infinity (Haskell
-- 'div'), unlike 'squotOf', which rounds toward zero. Division by zero
-- yields a runtime error when the result is evaluated.
--
-- Unsigned overflow is never reported. Signed overflow is reported only
-- for @minSigned `div` -1@, whose true quotient @2^(w-1)@ exceeds
-- 'P.maxSigned'.
sdivOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
sdivOf w bv1 bv2 = Overflow NoUnsignedOverflow sof (mkBV' w (x `div` y))
  where
    x = asSigned w bv1
    y = asSigned w bv2
    sof | x == P.minSigned w && y == -1 = SignedOverflow
        | otherwise                     = NoSignedOverflow

-- | Bitvector remainder after division (signed), when rounded toward
-- negative infinity (Haskell 'mod'). A nonzero result has the sign of the
-- divisor, unlike 'sremOf', whose result has the sign of the dividend.
-- Division by zero yields a runtime error when the result is evaluated.
--
-- Unsigned overflow is never reported. Signed overflow is flagged for
-- @minSigned `mod` -1@, the same input that overflows 'sdivOf', even
-- though the remainder itself (0) is representable.
smodOf :: (1 <= w) => NatRepr w -> BV w -> BV w -> Overflow (BV w)
smodOf w bv1 bv2 = Overflow NoUnsignedOverflow sof (mkBV' w (x `mod` y))
  where
    x = asSigned w bv1
    y = asSigned w bv2
    sof | x == P.minSigned w && y == -1 = SignedOverflow
        | otherwise                     = NoSignedOverflow