-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | Provides a singleton type for a subset of 'Nat's, represented by 'Int'.
--
-- This is particularly useful when working with length-indexed array types,
-- since the array primitives generally expect lengths and indices to be
-- 'Int's.  Thus, there's no need to pay the runtime cost of lugging around
-- 'Natural's to handle greater-than-maxInt-length arrays, since the underlying
-- primitives don't handle them either.
--
-- An @'SInt' n@ is trusted absolutely by downstream code to contain an 'Int'
-- @n'@ s.t. @fromIntegral n' == natVal' \@n proxy#@.  In particular, this
-- trust extends to a willingness to use two runtime-equal 'SInt's as proof
-- that their type parameters are equal, or to use GHC primitives in a way
-- that's only memory-safe if this property holds.  This means it should be
-- considered /unsafe/ to construct an 'SInt' in any way that's not statically
-- guaranteed to produce the correct runtime value, and to construct one with
-- an incorrect runtime value is equivalent to using 'unsafeCoerce'
-- incorrectly.
--
-- 'SInt' should be seen as a more efficient implementation of
-- @data SNat n = KnownNat n => SNat@, so that constructing an incorrect 'SInt'
-- would be equivalent to producing an incorrect 'KnownNat' instance.
--
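-- 'SInt's are constructed safely by 'staticSIntVal' with no overhead,
-- by 'sintVal' (or 'trySIntVal', which returns 'Nothing' instead of failing)
-- with runtime bounds checks based on a 'KnownNat' instance, by 'withSInt'
-- from a non-negative runtime 'Int', or by various arithmetic functions.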

{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}

#include "MachDeps.h"

module Data.SInt
         ( SInt(SI#, SI, unSInt), trySIntVal, sintVal, reifySInt, withSInt
         , addSInt, subSInt, subSIntLE, subSIntL, mulSInt, divSIntL, divSIntR
         , staticSIntVal
           -- * Internal
         , IntMaxP1
         ) where

import Data.Proxy (Proxy(..))
import GHC.Exts (Int(I#), addIntC#, mulIntMayOflo#, proxy#)
import GHC.Stack (HasCallStack)
import GHC.TypeNats
         ( type (<=), type (+), type (-), type (*), type (^), CmpNat
         , KnownNat, Nat, natVal', SomeNat(..), someNatVal
         )
import Numeric.Natural (Natural)

import Data.Portray (Portray)
import Data.Portray.Diff (Diff)

#if MIN_VERSION_base(4,15,0)
import Unsafe.Coerce (unsafeEqualityProof, UnsafeEquality(..))
#else
import Data.Type.Equality ((:~:)(..))
import Unsafe.Coerce (unsafeCoerce)
#endif

-- | A singleton type linking a runtime 'Int' and a type-level 'Nat'.
--
-- The wrapped 'Int' is trusted to equal @n@; building one with any other
-- value is as unsafe as misusing 'unsafeCoerce' (see the module header).
-- All instances are inherited from the underlying 'Int'.
newtype SInt (n :: Nat) = MkSInt Int
  deriving newtype Show
  deriving newtype Portray
  deriving newtype Diff

-- @n@ never appears on the right-hand side, so GHC would infer a phantom
-- role for it, letting 'Data.Coerce.coerce' turn an @SInt m@ into an
-- @SInt n@ and forge the singleton invariant.  Pinning the role to nominal
-- forbids that.
type role SInt nominal

-- | Construct an 'SInt' unsafely.  Incorrect uses cause undefined behavior.
--
-- See the module intro for more details; prefer to use safe methods to
-- construct 'SInt's, and treat this constructor equivalently to
-- 'unsafeCoerce'.
--
-- Matching on 'SI#' only reads the stored 'Int' and is harmless; it is
-- /building/ an 'SInt' with it that requires the caller to guarantee the
-- 'Int' equals @n@.
pattern SI# :: Int -> SInt n
pattern SI# raw <- MkSInt raw
  where
    SI# raw = MkSInt raw
{-# COMPLETE SI# #-}

-- | Read-only view of an 'SInt' as its underlying 'Int'.
--
-- The synonym is unidirectional (@<-@).  Its record field 'unSInt' can
-- therefore be exported and used like a field selector, but neither the
-- pattern nor the field can build an 'SInt' or appear in a record update.
-- Construction stays confined to the unsafe 'SI#'.
pattern SI :: Int -> SInt n
pattern SI{unSInt} <- (MkSInt unSInt)
{-# COMPLETE SI #-}

-- | Use an 'Int' as an existentially-quantified 'SInt'.
--
-- The continuation must work for every @n@, so it cannot learn anything
-- about @n@ beyond what the 'SInt' carries.  That is why handing it
-- @SI# x@ here is safe.  Negative inputs have no corresponding 'Nat' and
-- are rejected with 'error'.
withSInt :: HasCallStack => Int -> (forall n. SInt n -> r) -> r
withSInt x k =
  if x >= 0
    then k (SI# x)
    else error "withSInt: negative value"

-- | The largest 'Int' on this platform, widened to 'Natural'.  This is the
-- upper bound on the value an 'SInt' can carry.
maxInt :: Natural
maxInt = fromInteger (toInteger (maxBound :: Int))

-- | Produce an 'SInt' for a given 'KnownNat', or 'Nothing' if out of range.
--
-- The 'KnownNat' value is checked against 'maxInt' first, so the narrowing
-- 'fromIntegral' only ever runs on values that fit in an 'Int'.
trySIntVal :: forall n. KnownNat n => Maybe (SInt n)
trySIntVal
  | requested > maxInt = Nothing
  | otherwise          = Just (MkSInt (fromIntegral requested))
  where
    requested :: Natural
    requested = natVal' @n proxy#
{-# INLINE trySIntVal #-}

-- | Produce an 'SInt' for a given 'KnownNat', or 'error' if out of range.
--
-- This is a thin wrapper over 'trySIntVal' that reports the offending 'Nat'
-- (with the caller's call stack) instead of returning 'Nothing'.
sintVal :: forall n. (HasCallStack, KnownNat n) => SInt n
sintVal = case trySIntVal @n of
  Nothing ->
    error ("Nat " ++ show (natVal' @n proxy#) ++ " out of range for Int.")
  Just checked -> checked
{-# INLINE sintVal #-}

-- | One more than the maximum representable 'Int' on the current platform.
type IntMaxP1 = 2 ^ (WORD_SIZE_IN_BITS - 1)

-- | Like 'sintVal', but with static proof that it's in-bounds.
--
-- The @CmpNat n IntMaxP1 ~ 'LT@ constraint already guarantees @n@ fits in an
-- 'Int', so no runtime check is needed.  The result optimizes down to a
-- primitive literal wrapped in the appropriate constructors; with 'sintVal'
-- the bounds check gets in the way.  Use 'staticSIntVal' for
-- statically-known sizes.  When all you have is a runtime 'KnownNat'
-- instance, use 'sintVal'.
staticSIntVal :: forall n. (CmpNat n IntMaxP1 ~ 'LT, KnownNat n) => SInt n
staticSIntVal = MkSInt $ fromIntegral $ natVal' @n proxy#
{-# INLINE staticSIntVal #-}

-- | Add two 'SInt's with bounds checks; 'error' if the result overflows.
--
-- 'addIntC#' returns the wrapped sum and a flag that is nonzero exactly when
-- the true sum does not fit in an 'Int'.  On overflow, the message reports
-- the exact sum, computed in 'Natural'.
addSInt :: HasCallStack => SInt m -> SInt n -> SInt (m + n)
addSInt (SI# a@(I# a#)) (SI# b@(I# b#)) =
  case addIntC# a# b# of
    (# _, carry# #) | I# carry# /= 0 ->
      error ("Nat " ++ show exact ++ " out of range for Int.")
    (# total#, _ #) -> SI# (I# total#)
  where
    -- Exact (non-wrapping) sum; only evaluated on the overflow path.
    exact :: Natural
    exact = fromIntegral a + fromIntegral b

-- | Multiply two 'SInt's with bounds checks; 'error' if the result overflows.
--
-- The fast path trusts 'mulIntMayOflo#' when it reports "no overflow".  A
-- nonzero answer from that primop only means "might overflow", so in that
-- case the wrapped machine product is re-checked against the exact
-- 'Natural' product before it is accepted.
mulSInt :: HasCallStack => SInt m -> SInt n -> SInt (m * n)
mulSInt (SI# m@(I# m')) (SI# n@(I# n')) =
  case mulIntMayOflo# m' n' of
    ovf
      -- The primop guarantees the product fits in an 'Int'.
      | I# ovf == 0 -> SI# mn
      -- Possibly-conservative overflow report: accept the wrapped product
      -- iff it equals the exact one.  @mn >= 0@ (not @> 0@) so that a genuine
      -- zero product such as @0 * large@ is accepted.  The guard also stops
      -- 'fromIntegral' from underflowing 'Natural' on a negative wrap.
      | mn >= 0 && fromIntegral mn == mnNat -> SI# mn
      | otherwise -> error $ "Nat " ++ show mnNat ++ " out of range for Int."
  where
    -- Wrapping machine product; only trusted after the checks above.
    mn :: Int
    mn = m * n
    -- Exact product, used for verification and for the error message.
    mnNat :: Natural
    mnNat = fromIntegral m * fromIntegral n

-- | Subtract two 'SInt's with bounds checks; 'error' if the result is negative.
--
-- Both operands are non-negative 'Int's, so the machine subtraction cannot
-- wrap.  The only failure is a negative difference, which has no 'Nat'.
subSInt :: HasCallStack => SInt m -> SInt n -> SInt (m - n)
subSInt (SI# m) (SI# n)
  | m >= n    = SI# diff
  | otherwise = error ("Nat " ++ show diff ++ " out of range.")
  where
    diff = m - n

-- | Subtract two 'SInt's, using an inequality constraint to rule out overflow.
--
-- The @n <= m@ constraint means the difference is statically non-negative,
-- so no runtime check is performed.
subSIntLE :: n <= m => SInt m -> SInt n -> SInt (m - n)
subSIntLE (SI# larger) (SI# smaller) = SI# (larger - smaller)

-- | "Un-add" an 'SInt' from another 'SInt', on the left.
--
-- This form of 'subSInt' is more convenient in certain cases when a type
-- signature ensures a particular 'SInt' is of the form @m + n@.  Because the
-- sum's type already contains @m@, the result is non-negative by
-- construction and no check is done.
subSIntL :: SInt (m + n) -> SInt m -> SInt n
subSIntL (SI# total) (SI# leftPart) = SI# (total - leftPart)

-- | "Un-multiply" an 'SInt' by another 'SInt', on the left.
--
-- This form of @divSInt@ is more convenient in certain cases when a type
-- signature ensures a particular 'SInt' is of the form @m * n@.
--
-- If the left factor is 0, @n@ is not determined by the product, and the
-- underlying 'div' throws a divide-by-zero exception.
divSIntL :: SInt (m * n) -> SInt m -> SInt n
divSIntL (SI# product') (SI# leftFactor) = SI# (div product' leftFactor)

-- | "Un-multiply" an 'SInt' by another 'SInt', on the right.
--
-- This form of @divSInt@ is more convenient in certain cases when a type
-- signature ensures a particular 'SInt' is of the form @m * n@.
--
-- If the right factor is 0, @m@ is not determined by the product, and the
-- underlying 'div' throws a divide-by-zero exception.
divSIntR :: SInt (m * n) -> SInt n -> SInt m
divSIntR (SI# product') (SI# rightFactor) = SI# (div product' rightFactor)

-- | Bring an 'SInt' back into the type level as a 'KnownNat' instance.
--
-- 'someNatVal' turns the runtime 'Int' into a fresh existential 'Nat' that
-- comes with its own 'KnownNat' dictionary.  We then assert, without proof,
-- that this fresh type equals @n@.  That assertion rests entirely on the
-- 'SInt' invariant (runtime value == type-level value), which is why
-- constructing a wrong 'SInt' is undefined behavior.
reifySInt :: forall n r. SInt n -> (KnownNat n => r) -> r
reifySInt (SI# i) k = case someNatVal (fromIntegral i) of
  SomeNat (Proxy :: Proxy fresh) ->
#if MIN_VERSION_base(4,15,0)
    case unsafeEqualityProof @fresh @n of
      UnsafeRefl -> k
#else
    case (unsafeCoerce Refl :: fresh :~: n) of
      Refl -> k
#endif