{- | Machine Fortran INTEGER values.

This module stores Fortran INTEGER values in a matching Haskell machine integer
type. For example, an @INTEGER(4)@ would be stored in an 'Int32'. This way, we
get both efficient operations and common overflow behaviour (which hopefully
matches most Fortran compilers), and we explicitly encode kinding semantics by
promoting integral types.
-}

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}

module Language.Fortran.Repr.Value.Scalar.Int.Machine
  ( FInt(..)
  , SomeFInt
  , type IsFInt

  , fIntUOp
  , fIntUOp'
  , fIntUOpInplace
  , fIntUOpInplace'
  , fIntUOpInternal

  , fIntBOp
  , fIntBOp'
  , fIntBOpInplace
  , fIntBOpInplace'
  , fIntBOpInternal

  , withFInt
  ) where

import Language.Fortran.Repr.Type.Scalar.Int
import Language.Fortran.Repr.Value.Scalar.Common
import Data.Int
import Data.Functor.Const

import Data.Bits ( Bits )

import Language.Fortran.Repr.Util ( natVal'' )
import GHC.TypeNats

-- | A Fortran @INTEGER@ value, indexed by its kind. Each constructor wraps the
--   Haskell machine integer type whose width matches that kind.
data FInt (k :: FTInt) where
    -- | @INTEGER(1)@, stored as an 'Int8'.
    FInt1 :: Int8  -> FInt 'FTInt1
    -- | @INTEGER(2)@, stored as an 'Int16'.
    FInt2 :: Int16 -> FInt 'FTInt2
    -- | @INTEGER(4)@, stored as an 'Int32'.
    FInt4 :: Int32 -> FInt 'FTInt4
    -- | @INTEGER(8)@, stored as an 'Int64'.
    FInt8 :: Int64 -> FInt 'FTInt8

-- Instances are derived standalone because each constructor refines the kind
-- index, which an attached @deriving@ clause cannot handle.
deriving stock instance Eq   (FInt k)
deriving stock instance Ord  (FInt k)
deriving stock instance Show (FInt k)

-- | Operations available on every machine integer type that backs an 'FInt':
--   integral arithmetic ('Integral') and bitwise operations ('Bits').
type IsFInt a = (Bits a, Integral a)

-- | An 'FInt' whose kind has been hidden.
type SomeFInt = SomeFKinded FTInt FInt
deriving stock instance Show SomeFInt

-- | Two 'SomeFInt's are equal when they hold the same integer value. The
--   comparison goes through 'fIntBOp', which first widens the narrower operand
--   to the wider kind, so e.g. an @INTEGER(1)@ and an @INTEGER(4)@ holding the
--   same value compare equal.
instance Eq SomeFInt where
    (==) (SomeFKinded lhs) (SomeFKinded rhs) = fIntBOp (==) lhs rhs

-- | Low-level 'FInt' unary operator. Runs an operation over some 'FInt', and
--   stores it kinded. The caller chooses how the kind is used: it can wrap the
--   result back into an 'FInt', or be ignored by using 'Const'.
--
-- Exactly one of the four handlers runs: the one whose machine type matches
-- the constructor. Equations are ordered to match more common ops earlier.
fIntUOpInternal
    :: (Int8  -> ft 'FTInt1)
    -> (Int16 -> ft 'FTInt2)
    -> (Int32 -> ft 'FTInt4)
    -> (Int64 -> ft 'FTInt8)
    -> FInt k -> ft k
fIntUOpInternal _    _    onI4 _    (FInt4 x) = onI4 x
fIntUOpInternal _    _    _    onI8 (FInt8 x) = onI8 x
fIntUOpInternal _    onI2 _    _    (FInt2 x) = onI2 x
fIntUOpInternal onI1 _    _    _    (FInt1 x) = onI1 x

-- | Run an operation over some 'FInt', with a concrete function for each kind.
--   The kind only selects which function runs; it does not appear in the
--   result type.
fIntUOp'
    :: (Int8  -> r)
    -> (Int16 -> r)
    -> (Int32 -> r)
    -> (Int64 -> r)
    -> FInt k -> r
fIntUOp' k1f k2f k4f k8f x =
    getConst (fIntUOpInternal (asConst k1f) (asConst k2f) (asConst k4f) (asConst k8f) x)
  where
    -- Give each result a phantom kind via 'Const' so it fits the shape
    -- 'fIntUOpInternal' expects; 'getConst' strips it off again.
    asConst g = Const . g

-- | Run an operation over some 'FInt'. The operation must work at every
--   'IsFInt' type; it is instantiated at the machine type that matches the
--   value's kind.
fIntUOp
    :: forall r k
    .  (forall a. IsFInt a => a -> r)
    -> FInt k -> r
fIntUOp f = fIntUOp' (f @Int8) (f @Int16) (f @Int32) (f @Int64)

-- | Run an inplace operation over some 'FInt', with a concrete function for
--   each kind. The result is re-wrapped in the same constructor, so the kind
--   is unchanged.
fIntUOpInplace'
    :: (Int8  -> Int8)
    -> (Int16 -> Int16)
    -> (Int32 -> Int32)
    -> (Int64 -> Int64)
    -> FInt k -> FInt k
fIntUOpInplace' k1f k2f k4f k8f x =
    fIntUOpInternal
        (\i -> FInt1 (k1f i))
        (\i -> FInt2 (k2f i))
        (\i -> FInt4 (k4f i))
        (\i -> FInt8 (k8f i))
        x

-- | Run an inplace operation over some 'FInt', keeping its kind. The operation
--   is instantiated at the machine type that matches the value's kind.
fIntUOpInplace
    :: (forall a. IsFInt a => a -> a)
    -> FInt k -> FInt k
fIntUOpInplace f = fIntUOpInplace' (f @Int8) (f @Int16) (f @Int32) (f @Int64)

-- | Low-level 'FInt' binary operator. Combine two 'FInt's, coercing different
--   kinds, and store the result kinded.
--
-- When the kinds differ, the operand with the narrower machine type is widened
-- with 'fromIntegral' to the other operand's type before the matching handler
-- runs. Every widening goes to a larger signed type, so no value is lost.
--
-- Alternatives are ordered to match more common ops earlier.
fIntBOpInternal
    :: (Int8  -> Int8  -> ft 'FTInt1)
    -> (Int16 -> Int16 -> ft 'FTInt2)
    -> (Int32 -> Int32 -> ft 'FTInt4)
    -> (Int64 -> Int64 -> ft 'FTInt8)
    -> FInt kl -> FInt kr -> ft (FTIntCombine kl kr)
fIntBOpInternal k1f k2f k4f k8f il ir =
    case il of
      FInt4 l -> case ir of
        FInt4 r -> k4f l r
        FInt8 r -> k8f (fromIntegral l) r
        FInt2 r -> k4f l (fromIntegral r)
        FInt1 r -> k4f l (fromIntegral r)
      FInt8 l -> case ir of
        FInt4 r -> k8f l (fromIntegral r)
        FInt8 r -> k8f l r
        FInt2 r -> k8f l (fromIntegral r)
        FInt1 r -> k8f l (fromIntegral r)
      FInt2 l -> case ir of
        FInt4 r -> k4f (fromIntegral l) r
        FInt8 r -> k8f (fromIntegral l) r
        FInt2 r -> k2f l r
        FInt1 r -> k2f l (fromIntegral r)
      FInt1 l -> case ir of
        FInt4 r -> k4f (fromIntegral l) r
        FInt8 r -> k8f (fromIntegral l) r
        FInt2 r -> k2f (fromIntegral l) r
        FInt1 r -> k1f l r

-- | Combine two 'FInt's with a concrete function for each kind, returning an
--   unkinded result. Operands of different kinds are widened first, as
--   described on 'fIntBOpInternal'.
fIntBOp'
    :: (Int8  -> Int8  -> r)
    -> (Int16 -> Int16 -> r)
    -> (Int32 -> Int32 -> r)
    -> (Int64 -> Int64 -> r)
    -> FInt kl -> FInt kr -> r
fIntBOp' k1f k2f k4f k8f il ir = getConst $
    fIntBOpInternal (tagged k1f) (tagged k2f) (tagged k4f) (tagged k8f) il ir
  where
    -- Wrap the result in 'Const' so it can carry the phantom result kind.
    tagged g x y = Const (g x y)

-- | Combine two 'FInt's with an operation that works at every 'IsFInt' type,
--   returning an unkinded result (for example, a comparison).
fIntBOp
    :: (forall a. IsFInt a => a -> a -> r)
    -> FInt kl -> FInt kr -> r
fIntBOp f = fIntBOp' (f @Int8) (f @Int16) (f @Int32) (f @Int64)

-- | Combine two 'FInt's with a concrete function for each kind, wrapping the
--   result back into an 'FInt' tagged with the combined kind.
fIntBOpInplace'
    :: (Int8  -> Int8  -> Int8)
    -> (Int16 -> Int16 -> Int16)
    -> (Int32 -> Int32 -> Int32)
    -> (Int64 -> Int64 -> Int64)
    -> FInt kl -> FInt kr -> FInt (FTIntCombine kl kr)
fIntBOpInplace' k1f k2f k4f k8f =
    fIntBOpInternal
        (\l r -> FInt1 (k1f l r))
        (\l r -> FInt2 (k2f l r))
        (\l r -> FInt4 (k4f l r))
        (\l r -> FInt8 (k8f l r))

-- | Combine two 'FInt's with an operation that works at every 'IsFInt' type,
--   wrapping the result back into an 'FInt' tagged with the combined kind.
fIntBOpInplace
    :: (forall a. IsFInt a => a -> a -> a)
    -> FInt kl -> FInt kr -> FInt (FTIntCombine kl kr)
fIntBOpInplace f =
    fIntBOpInplace' (f @Int8) (f @Int16) (f @Int32) (f @Int64)

-- | Treat any 'FInt' as a 'Num', converting the stored machine integer with
--   'fromIntegral'.
--
-- TODO remove. means being explicit with coercions to real in eval.
withFInt :: Num a => FInt k -> a
withFInt = fIntUOp' fromIntegral fromIntegral fromIntegral fromIntegral

-- | The largest value of the integer kind @k@: the type-level 'FTIntMax'
--   converted to 'Int64' with 'fromIntegral'.
--
-- NOTE(review): 'fromIntegral' into 'Int64' wraps silently if @FTIntMax k@
-- exceeds @maxBound :: Int64@ (presumably the case for 'FTInt16') — confirm.
fIntMax :: forall (k :: FTInt). KnownNat (FTIntMax k) => Int64
fIntMax = fromIntegral upper
  where
    upper = natVal'' @(FTIntMax k)

-- | The smallest value of the integer kind @k@, as an 'Int64'.
--
-- 'FTIntMin' is a type-level 'Natural' ('KnownNat'), so it cannot be negative,
-- but the minimum of a signed machine integer always is. We therefore negate
-- it. The arithmetic is done in 'Integer' so that a magnitude that does not fit
-- in 'Int64' (e.g. @2^63@) cannot wrap before it is negated.
--
-- NOTE(review): assumes @FTIntMin k@ is the magnitude of the minimum (e.g.
-- @2^7@ for 'FTInt1') — confirm against the type family definition.
fIntMin :: forall (k :: FTInt). KnownNat (FTIntMin k) => Int64
fIntMin = fromInteger . negate . toInteger $ natVal'' @(FTIntMin k)

-- | Convert an 'FInt' to the kind selected by the singleton @ty@, checking
--   that its value fits in the target kind.
--
-- Returns @'Left' msg@ if the value is above or below the range of the target
-- kind's machine type, or if the target is @INTEGER(16)@, which this module
-- cannot represent.
--
-- The range check uses the target machine type's own 'minBound' and
-- 'maxBound', and compares in 'Integer', so no intermediate conversion can
-- wrap and negative values are accepted when they are in range. Bounds are
-- checked per target, so an @INTEGER(16)@ target always reports that it is
-- unsupported rather than a range error. The 'KnownNat' constraints are kept
-- so the signature stays compatible with existing callers.
--
-- TODO improve (always return answer, and a flag indicating if there was an
-- error)
fIntCoerceChecked
    :: forall kout kin
    .  (KnownNat (FTIntMax kout), KnownNat (FTIntMin kout))
    => SFTInt kout -> FInt kin -> Either String (FInt kout)
fIntCoerceChecked ty = fIntUOp $ \n ->
    let v = toInteger n

        -- Range-check v against the machine type t of the target kind, then
        -- narrow it and wrap it with the matching constructor.
        checkedInto
            :: forall t. (Bounded t, Integral t)
            => (t -> FInt kout) -> Either String (FInt kout)
        checkedInto wrap
          | v > toInteger (maxBound :: t) = Left "too large for new size"
          | v < toInteger (minBound :: t) = Left "too small for new size"
          | otherwise                     = Right $ wrap $ fromInteger v
    in  case ty of
          SFTInt1  -> checkedInto FInt1
          SFTInt2  -> checkedInto FInt2
          SFTInt4  -> checkedInto FInt4
          SFTInt8  -> checkedInto FInt8
          SFTInt16 -> Left "can't represent INTEGER(16) yet, sorry"

-- | Recover the term-level kind of an 'FInt' from its constructor.
--
-- can also define this (and stronger funcs) with singletons
fIntType :: FInt (k :: FTInt) -> FTInt
fIntType (FInt1 _) = FTInt1
fIntType (FInt2 _) = FTInt2
fIntType (FInt4 _) = FTInt4
fIntType (FInt8 _) = FTInt8