{-|
Copyright  :  (C) 2019, Myrtle Software Ltd
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>
-}

{-# LANGUAGE CPP           #-}
{-# LANGUAGE DataKinds     #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE TypeFamilies  #-}
{-# LANGUAGE TypeOperators #-}

#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE NoStarIsType  #-}
#endif

{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}

module Clash.Class.Exp (Exp, ExpResult, (^)) where

import qualified Prelude                       as P
import           Prelude                       hiding ((^))

import           Clash.Annotations.Primitive   (hasBlackBox)
import           Clash.Promoted.Nat            (SNat(..), snatToInteger)
import           Clash.Sized.Internal.Index    (Index)
import           Clash.Sized.Internal.Signed   (Signed)
import           Clash.Sized.Internal.Unsigned (Unsigned)

import           GHC.TypeLits
  (KnownNat, Nat, type (^), type (*))

-- | Type class implementing exponentiation with explicitly resizing results.
class Exp a where
  -- | Result type of raising a value of type @a@ to the exponent @n@. Each
  -- instance picks a type wide enough to hold the power of any base.
  type ExpResult a (n :: Nat)

  -- | Exponentiation with known exponent.
  --
  -- NOTE: a zero exponent always yields 1. For the instances in this module,
  -- @ExpResult a 0@ is @Index 1@, @Signed 0@ or @Unsigned 0@. Each of these
  -- can only represent 0, so the no-overflow guarantee below only holds for
  -- exponents of 1 or more.
  (^)
    :: a
    -- ^ Base
    -> SNat n
    -- ^ Exponent
    -> ExpResult a n
    -- ^ Resized result, guaranteed not to have overflowed (for exponents >= 1)

-- | Exponentiation of 'Index' values. A base of type @Index m@ is at most
-- @m - 1@. Raised to a power @n >= 1@, it is at most @m ^ n - 1@, which fits
-- in @Index (m ^ n)@.
instance KnownNat m => Exp (Index m) where
  type ExpResult (Index m) n = Index (m ^ n)

  (^) = expIndex#
  {-# INLINE (^) #-}

-- | Exponentiation of 'Signed' values. An @m@-bit signed base raised to a
-- power @n >= 1@ fits in @m * n@ bits.
instance KnownNat m => Exp (Signed m) where
  type ExpResult (Signed m) n = Signed (m * n)

  (^) = expSigned#
  {-# INLINE (^) #-}

-- | Exponentiation of 'Unsigned' values. An @m@-bit unsigned base raised to a
-- power @n >= 1@ fits in @m * n@ bits.
instance KnownNat m => Exp (Unsigned m) where
  type ExpResult (Unsigned m) n = Unsigned (m * n)

  (^) = expUnsigned#
  {-# INLINE (^) #-}

-- | Haskell definition of 'Index' exponentiation. It computes @b ^ e@ on
-- 'Integer' and converts the result back into the resized 'Index'.
--
-- Matching on 'SNat' brings @KnownNat n@ into scope. Together with
-- @KnownNat m@, this lets the KnownNat solver plugin derive the
-- @KnownNat (m ^ n)@ constraint that 'fromInteger' on the result needs.
--
-- The binding is @NOINLINE@ and annotated with 'hasBlackBox'. This keeps it
-- intact so the Clash compiler can substitute its black box for it.
expIndex#
  :: KnownNat m
  => Index m
  -> SNat n
  -> Index (m ^ n)
expIndex# b e@SNat =
  fromInteger (toInteger b P.^ snatToInteger e)
{-# NOINLINE expIndex# #-}
{-# ANN expIndex# hasBlackBox #-}

-- | Haskell definition of 'Signed' exponentiation. It computes @b ^ e@ on
-- 'Integer' and converts the result back into the resized 'Signed'.
--
-- Matching on 'SNat' brings @KnownNat n@ into scope. Together with
-- @KnownNat m@, this lets the KnownNat solver plugin derive the
-- @KnownNat (m * n)@ constraint that 'fromInteger' on the result needs.
--
-- The binding is @NOINLINE@ and annotated with 'hasBlackBox'. This keeps it
-- intact so the Clash compiler can substitute its black box for it.
expSigned#
  :: KnownNat m
  => Signed m
  -> SNat n
  -> Signed (m * n)
expSigned# b e@SNat =
  fromInteger (toInteger b P.^ snatToInteger e)
{-# NOINLINE expSigned# #-}
{-# ANN expSigned# hasBlackBox #-}

-- | Haskell definition of 'Unsigned' exponentiation. It computes @b ^ e@ on
-- 'Integer' and converts the result back into the resized 'Unsigned'.
--
-- Matching on 'SNat' brings @KnownNat n@ into scope. Together with
-- @KnownNat m@, this lets the KnownNat solver plugin derive the
-- @KnownNat (m * n)@ constraint that 'fromInteger' on the result needs.
--
-- The binding is @NOINLINE@ and annotated with 'hasBlackBox'. This keeps it
-- intact so the Clash compiler can substitute its black box for it.
expUnsigned#
  :: KnownNat m
  => Unsigned m
  -> SNat n
  -> Unsigned (m * n)
expUnsigned# b e@SNat =
  fromInteger (toInteger b P.^ snatToInteger e)
{-# NOINLINE expUnsigned# #-}
{-# ANN expUnsigned# hasBlackBox #-}