{-# LANGUAGE DataKinds, MultiParamTypeClasses, TypeFamilies #-}

-- | Module containing the boilerplate that supports
-- arithmetic operations in the Michelson language.

module Michelson.Typed.Arith
  ( ArithOp (..)
  , UnaryArithOp (..)
  , ArithError (..)
  , ArithErrorType (..)
  , Add
  , Sub
  , Mul
  , Abs
  , Neg
  , Or
  , And
  , Xor
  , Not
  , Lsl
  , Lsr
  , Compare
  , Eq'
  , Neq
  , Lt
  , Gt
  , Le
  , Ge
  ) where

import Data.Bits (complement, shift, xor, (.&.), (.|.))
import Fmt (Buildable(build))

import Michelson.Typed.CValue (CValue(..))
import Michelson.Typed.T (CT(..))
import Tezos.Core (addMutez, mulMutez, subMutez, timestampFromSeconds, timestampToSeconds)

-- | Class for binary arithmetic operations.
--
-- The @aop@ parameter is the operation marker (e.g. 'Add', 'Mul'), while
-- @n@ and @m@ are the types of the left and right operands respectively.
class ArithOp aop (n :: CT) (m :: CT) where

  -- | Type family @ArithRes@ gives the type of the result of applying
  -- operation @aop@ to operands of types @n@ and @m@.
  --
  -- For instance, adding an integer to a natural yields an integer,
  -- which is expressed by the instance @ArithRes Add CNat CInt = CInt@.
  type ArithRes aop n m :: CT

  -- | Evaluate the operation on the given operands. A failure is reported
  -- as an 'ArithError' over the operand values.
  evalOp
    :: proxy aop
    -- ^ operation marker; only its type selects the instance
    -> CValue n
    -- ^ left operand
    -> CValue m
    -- ^ right operand
    -> Either (ArithError (CValue n) (CValue m)) (CValue (ArithRes aop n m))

-- | Kind of failure that occurred in an arithmetic operation.
data ArithErrorType
  = AddOverflow
    -- ^ Addition produced a result that cannot be represented.
  | MulOverflow
    -- ^ Multiplication produced a result that cannot be represented.
  | SubUnderflow
    -- ^ Subtraction produced a result that cannot be represented.
  | LslOverflow
    -- ^ Left shift rejected because the shift amount is too large.
  | LsrUnderflow
    -- ^ Right shift rejected because the shift amount is too large.
  deriving (Eq, Ord, Show)

-- | Arithmetic error: the kind of failure together with the left (@n@)
-- and right (@m@) operands that caused it.
data ArithError n m
  = MutezArithError ArithErrorType n m
    -- ^ Failure of an operation on mutez amounts.
  | ShiftArithError ArithErrorType n m
    -- ^ Failure of a bit-shift operation.
  deriving (Eq, Ord, Show)

-- | Class for unary arithmetic operations.
--
-- The @aop@ parameter is the operation marker (e.g. 'Abs', 'Not') and
-- @n@ is the type of the single operand.
class UnaryArithOp aop (n :: CT) where
  -- | Type of the result of applying @aop@ to an operand of type @n@.
  type UnaryArithRes aop n :: CT

  -- | Evaluate the unary operation. Unlike 'evalOp', this cannot fail.
  evalUnaryArithOp
    :: proxy aop
    -- ^ operation marker; only its type selects the instance
    -> CValue n
    -- ^ operand
    -> CValue (UnaryArithRes aop n)

-- | Marker type for the binary addition operation.
data Add
-- | Marker type for the binary subtraction operation.
data Sub
-- | Marker type for the binary multiplication operation.
data Mul
-- | Marker type for the unary absolute-value operation.
data Abs
-- | Marker type for the unary negation operation.
data Neg

-- | Marker type for the binary (bitwise or boolean) disjunction.
data Or
-- | Marker type for the binary (bitwise or boolean) conjunction.
data And
-- | Marker type for the binary (bitwise or boolean) exclusive or.
data Xor
-- | Marker type for the unary (bitwise or boolean) complement.
data Not
-- | Marker type for the binary logical shift left.
data Lsl
-- | Marker type for the binary logical shift right.
data Lsr

-- | Marker type for the binary comparison yielding an integer (-1, 0 or 1).
data Compare
-- | Marker type for the unary "integer equals zero" test.
-- NOTE(review): presumably primed to avoid clashing with the 'Eq' class.
data Eq'
-- | Marker type for the unary "integer is not zero" test.
data Neq
-- | Marker type for the unary "integer is negative" test.
data Lt
-- | Marker type for the unary "integer is positive" test.
data Gt
-- | Marker type for the unary "integer is non-positive" test.
data Le
-- | Marker type for the unary "integer is non-negative" test.
data Ge

-- Addition of numbers: mixing a natural with an integer yields an integer.
instance ArithOp Add 'CNat 'CInt where
  type ArithRes Add 'CNat 'CInt = 'CInt
  evalOp _ (CvNat a) (CvInt b) = Right . CvInt $ toInteger a + b

instance ArithOp Add 'CInt 'CNat where
  type ArithRes Add 'CInt 'CNat = 'CInt
  evalOp _ (CvInt a) (CvNat b) = Right . CvInt $ a + toInteger b

instance ArithOp Add 'CNat 'CNat where
  type ArithRes Add 'CNat 'CNat = 'CNat
  evalOp _ (CvNat a) (CvNat b) = Right . CvNat $ a + b

instance ArithOp Add 'CInt 'CInt where
  type ArithRes Add 'CInt 'CInt = 'CInt
  evalOp _ (CvInt a) (CvInt b) = Right . CvInt $ a + b

-- Shifting a timestamp by a number of seconds, in either operand order.
instance ArithOp Add 'CTimestamp 'CInt where
  type ArithRes Add 'CTimestamp 'CInt = 'CTimestamp
  evalOp _ (CvTimestamp t) (CvInt secs) =
    Right . CvTimestamp . timestampFromSeconds $ timestampToSeconds t + secs

instance ArithOp Add 'CInt 'CTimestamp where
  type ArithRes Add 'CInt 'CTimestamp = 'CTimestamp
  evalOp _ (CvInt secs) (CvTimestamp t) =
    Right . CvTimestamp . timestampFromSeconds $ timestampToSeconds t + secs

-- Mutez addition is checked: a failed 'addMutez' becomes an 'AddOverflow'.
instance ArithOp Add 'CMutez 'CMutez where
  type ArithRes Add 'CMutez 'CMutez = 'CMutez
  evalOp _ lhs@(CvMutez a) rhs@(CvMutez b) =
    case a `addMutez` b of
      Just total -> Right (CvMutez total)
      Nothing -> Left (MutezArithError AddOverflow lhs rhs)

-- Subtraction of numbers always yields an integer, even for two naturals.
instance ArithOp Sub 'CNat 'CInt where
  type ArithRes Sub 'CNat 'CInt = 'CInt
  evalOp _ (CvNat a) (CvInt b) = Right . CvInt $ toInteger a - b

instance ArithOp Sub 'CInt 'CNat where
  type ArithRes Sub 'CInt 'CNat = 'CInt
  evalOp _ (CvInt a) (CvNat b) = Right . CvInt $ a - toInteger b

instance ArithOp Sub 'CNat 'CNat where
  type ArithRes Sub 'CNat 'CNat = 'CInt
  evalOp _ (CvNat a) (CvNat b) = Right . CvInt $ toInteger a - toInteger b

instance ArithOp Sub 'CInt 'CInt where
  type ArithRes Sub 'CInt 'CInt = 'CInt
  evalOp _ (CvInt a) (CvInt b) = Right . CvInt $ a - b

-- Moving a timestamp back by a number of seconds.
instance ArithOp Sub 'CTimestamp 'CInt where
  type ArithRes Sub 'CTimestamp 'CInt = 'CTimestamp
  evalOp _ (CvTimestamp t) (CvInt secs) =
    Right . CvTimestamp . timestampFromSeconds $ timestampToSeconds t - secs

-- Difference between two timestamps, in seconds.
instance ArithOp Sub 'CTimestamp 'CTimestamp where
  type ArithRes Sub 'CTimestamp 'CTimestamp = 'CInt
  evalOp _ (CvTimestamp t1) (CvTimestamp t2) =
    Right . CvInt $ timestampToSeconds t1 - timestampToSeconds t2

-- Mutez subtraction is checked: a failed 'subMutez' becomes a 'SubUnderflow'.
instance ArithOp Sub 'CMutez 'CMutez where
  type ArithRes Sub 'CMutez 'CMutez = 'CMutez
  evalOp _ lhs@(CvMutez a) rhs@(CvMutez b) =
    case a `subMutez` b of
      Just diff -> Right (CvMutez diff)
      Nothing -> Left (MutezArithError SubUnderflow lhs rhs)

-- Multiplication of numbers: mixing a natural with an integer yields an integer.
instance ArithOp Mul 'CNat 'CInt where
  type ArithRes Mul 'CNat 'CInt = 'CInt
  evalOp _ (CvNat a) (CvInt b) = Right . CvInt $ toInteger a * b

instance ArithOp Mul 'CInt 'CNat where
  type ArithRes Mul 'CInt 'CNat = 'CInt
  evalOp _ (CvInt a) (CvNat b) = Right . CvInt $ a * toInteger b

instance ArithOp Mul 'CNat 'CNat where
  type ArithRes Mul 'CNat 'CNat = 'CNat
  evalOp _ (CvNat a) (CvNat b) = Right . CvNat $ a * b

instance ArithOp Mul 'CInt 'CInt where
  type ArithRes Mul 'CInt 'CInt = 'CInt
  evalOp _ (CvInt a) (CvInt b) = Right . CvInt $ a * b

-- Scaling a mutez amount by a natural, in either operand order; a failed
-- 'mulMutez' becomes a 'MulOverflow'.
instance ArithOp Mul 'CNat 'CMutez where
  type ArithRes Mul 'CNat 'CMutez = 'CMutez
  evalOp _ lhs@(CvNat k) rhs@(CvMutez amt) =
    case amt `mulMutez` k of
      Just product' -> Right (CvMutez product')
      Nothing -> Left (MutezArithError MulOverflow lhs rhs)

instance ArithOp Mul 'CMutez 'CNat where
  type ArithRes Mul 'CMutez 'CNat = 'CMutez
  evalOp _ lhs@(CvMutez amt) rhs@(CvNat k) =
    case amt `mulMutez` k of
      Just product' -> Right (CvMutez product')
      Nothing -> Left (MutezArithError MulOverflow lhs rhs)

-- Absolute value of an integer, returned as a natural.
instance UnaryArithOp Abs 'CInt where
  type UnaryArithRes Abs 'CInt = 'CNat
  evalUnaryArithOp _ (CvInt x) = CvNat . fromInteger $ abs x

-- Negation always yields an integer, also when applied to a natural.
instance UnaryArithOp Neg 'CInt where
  type UnaryArithRes Neg 'CInt = 'CInt
  evalUnaryArithOp _ (CvInt x) = CvInt $ negate x

instance UnaryArithOp Neg 'CNat where
  type UnaryArithRes Neg 'CNat = 'CInt
  evalUnaryArithOp _ (CvNat n) = CvInt . negate $ fromIntegral n

-- Disjunction: bitwise on naturals, logical on booleans.
instance ArithOp Or 'CNat 'CNat where
  type ArithRes Or 'CNat 'CNat = 'CNat
  evalOp _ (CvNat a) (CvNat b) = Right . CvNat $ a .|. b

instance ArithOp Or 'CBool 'CBool where
  type ArithRes Or 'CBool 'CBool = 'CBool
  evalOp _ (CvBool p) (CvBool q) = Right . CvBool $ p .|. q

-- Conjunction: bitwise on numbers (integer AND natural keeps integer type),
-- logical on booleans.
instance ArithOp And 'CInt 'CNat where
  type ArithRes And 'CInt 'CNat = 'CInt
  evalOp _ (CvInt a) (CvNat b) = Right . CvInt $ a .&. fromIntegral b

instance ArithOp And 'CNat 'CNat where
  type ArithRes And 'CNat 'CNat = 'CNat
  evalOp _ (CvNat a) (CvNat b) = Right . CvNat $ a .&. b

instance ArithOp And 'CBool 'CBool where
  type ArithRes And 'CBool 'CBool = 'CBool
  evalOp _ (CvBool p) (CvBool q) = Right . CvBool $ p .&. q

-- Exclusive or: bitwise on naturals, logical on booleans.
instance ArithOp Xor 'CNat 'CNat where
  type ArithRes Xor 'CNat 'CNat = 'CNat
  evalOp _ (CvNat a) (CvNat b) = Right . CvNat $ a `xor` b

instance ArithOp Xor 'CBool 'CBool where
  type ArithRes Xor 'CBool 'CBool = 'CBool
  evalOp _ (CvBool p) (CvBool q) = Right . CvBool $ p `xor` q

-- Logical shift left of a natural; shift amounts above 256 are rejected.
instance ArithOp Lsl 'CNat 'CNat where
  type ArithRes Lsl 'CNat 'CNat = 'CNat
  evalOp _ lhs@(CvNat a) rhs@(CvNat b)
    | b > 256 = Left $ ShiftArithError LslOverflow lhs rhs
    | otherwise =
        Right . CvNat . fromInteger $ shift (toInteger a) (fromIntegral b)

-- Logical shift right of a natural (a negative 'shift' amount shifts right);
-- shift amounts above 256 are rejected.
instance ArithOp Lsr 'CNat 'CNat where
  type ArithRes Lsr 'CNat 'CNat = 'CNat
  evalOp _ lhs@(CvNat a) rhs@(CvNat b)
    | b > 256 = Left $ ShiftArithError LsrUnderflow lhs rhs
    | otherwise =
        Right . CvNat . fromInteger $ shift (toInteger a) (negate (fromIntegral b))

-- Bitwise complement on numbers (always yielding an integer) and logical
-- negation on booleans.
instance UnaryArithOp Not 'CInt where
  type UnaryArithRes Not 'CInt = 'CInt
  evalUnaryArithOp _ (CvInt x) = CvInt $ complement x

instance UnaryArithOp Not 'CNat where
  type UnaryArithRes Not 'CNat = 'CInt
  evalUnaryArithOp _ (CvNat n) = CvInt . complement $ toInteger n

instance UnaryArithOp Not 'CBool where
  type UnaryArithRes Not 'CBool = 'CBool
  evalUnaryArithOp _ (CvBool b) = CvBool $ not b

-- | Compare two values and encode the outcome as an integer: @-1@ if the
-- first value is smaller, @0@ if both are equal and @1@ if the first value
-- is greater.
--
-- The mapping is spelled out explicitly rather than derived from the
-- 'Enum' instance of 'Ordering', so it does not depend on constructor order.
compareToCInt :: Ord a => a -> a -> CValue 'CInt
compareToCInt x y = CvInt $ case compare x y of
  LT -> -1
  EQ -> 0
  GT -> 1

-- Every comparable type compares with itself and yields an integer.
instance ArithOp Compare 'CBool 'CBool where
  type ArithRes Compare 'CBool 'CBool = 'CInt
  evalOp _ (CvBool i) (CvBool j) = Right $ compareToCInt i j
instance ArithOp Compare 'CAddress 'CAddress where
  type ArithRes Compare 'CAddress 'CAddress = 'CInt
  evalOp _ (CvAddress i) (CvAddress j) = Right $ compareToCInt i j
instance ArithOp Compare 'CNat 'CNat where
  type ArithRes Compare 'CNat 'CNat = 'CInt
  evalOp _ (CvNat i) (CvNat j) = Right $ compareToCInt i j
instance ArithOp Compare 'CInt 'CInt where
  type ArithRes Compare 'CInt 'CInt = 'CInt
  evalOp _ (CvInt i) (CvInt j) = Right $ compareToCInt i j
instance ArithOp Compare 'CString 'CString where
  type ArithRes Compare 'CString 'CString = 'CInt
  evalOp _ (CvString i) (CvString j) = Right $ compareToCInt i j
instance ArithOp Compare 'CBytes 'CBytes where
  type ArithRes Compare 'CBytes 'CBytes = 'CInt
  evalOp _ (CvBytes i) (CvBytes j) = Right $ compareToCInt i j
instance ArithOp Compare 'CTimestamp 'CTimestamp where
  type ArithRes Compare 'CTimestamp 'CTimestamp = 'CInt
  evalOp _ (CvTimestamp i) (CvTimestamp j) = Right $ compareToCInt i j
instance ArithOp Compare 'CMutez 'CMutez where
  type ArithRes Compare 'CMutez 'CMutez = 'CInt
  evalOp _ (CvMutez i) (CvMutez j) = Right $ compareToCInt i j
instance ArithOp Compare 'CKeyHash 'CKeyHash where
  type ArithRes Compare 'CKeyHash 'CKeyHash = 'CInt
  evalOp _ (CvKeyHash i) (CvKeyHash j) = Right $ compareToCInt i j

-- Zero tests on an integer, each producing a boolean.
instance UnaryArithOp Eq' 'CInt where
  type UnaryArithRes Eq' 'CInt = 'CBool
  evalUnaryArithOp _ (CvInt x) = CvBool $ 0 == x

instance UnaryArithOp Neq 'CInt where
  type UnaryArithRes Neq 'CInt = 'CBool
  evalUnaryArithOp _ (CvInt x) = CvBool $ 0 /= x

instance UnaryArithOp Lt 'CInt where
  type UnaryArithRes Lt 'CInt = 'CBool
  evalUnaryArithOp _ (CvInt x) = CvBool $ 0 > x

instance UnaryArithOp Gt 'CInt where
  type UnaryArithRes Gt 'CInt = 'CBool
  evalUnaryArithOp _ (CvInt x) = CvBool $ 0 < x

instance UnaryArithOp Le 'CInt where
  type UnaryArithRes Le 'CInt = 'CBool
  evalUnaryArithOp _ (CvInt x) = CvBool $ 0 >= x

instance UnaryArithOp Ge 'CInt where
  type UnaryArithRes Ge 'CInt = 'CBool
  evalUnaryArithOp _ (CvInt x) = CvBool $ 0 <= x


-- | Short human-readable name of each arithmetic failure kind.
instance Buildable ArithErrorType where
  build AddOverflow = "add overflow"
  build MulOverflow = "mul overflow"
  -- Rendered as "underflow" to agree with the 'SubUnderflow' constructor.
  build SubUnderflow = "sub underflow"
  build LslOverflow = "lsl overflow"
  build LsrUnderflow = "lsr underflow"

-- | Render an arithmetic error as its failure kind followed by both
-- operands; mutez failures get an extra @Mutez @ prefix.
instance (Show n, Show m) => Buildable (ArithError n m) where
  build err =
    case err of
        MutezArithError errType n m -> "Mutez " <> describe errType n m
        ShiftArithError errType n m -> describe errType n m
    where
      -- Shared "<kind> with <lhs>, <rhs>" formatting for both constructors.
      describe kind lhs rhs =
        build kind <> " with " <> show lhs <> ", " <> show rhs