{-# LANGUAGE GADTs         #-}
{-# LANGUAGE LambdaCase    #-}
{-# LANGUAGE RankNTypes    #-}
{-# LANGUAGE TypeFamilies  #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.CodeGen.IR
-- Copyright   : [2015..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.CodeGen.IR (

  Operands(..),
  IROP(..),

) where

import LLVM.AST.Type.Name
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation

import Data.Array.Accelerate.Error
import Data.Primitive.Vec

import qualified Data.ByteString.Short                              as B


-- We use a data family to represent sequences of LLVM (scalar) operands
-- representing a single Accelerate type. Using a data family rather than a type
-- family means that Operands is injective: the type @Operands e@ uniquely
-- determines @e@, which a plain type family would not guarantee.
--
-- Shape of the instances below:
--
--   * unit carries no operand at all ('OP_Unit');
--   * every scalar type, 'Bool', and every SIMD vector type wraps exactly one
--     'Operand' of the same type;
--   * a pair is the operand sequences of its two components side by side.
--
data family Operands e :: *
data instance Operands ()         = OP_Unit
data instance Operands Int        = OP_Int     (Operand Int)
data instance Operands Int8       = OP_Int8    (Operand Int8)
data instance Operands Int16      = OP_Int16   (Operand Int16)
data instance Operands Int32      = OP_Int32   (Operand Int32)
data instance Operands Int64      = OP_Int64   (Operand Int64)
data instance Operands Word       = OP_Word    (Operand Word)
data instance Operands Word8      = OP_Word8   (Operand Word8)
data instance Operands Word16     = OP_Word16  (Operand Word16)
data instance Operands Word32     = OP_Word32  (Operand Word32)
data instance Operands Word64     = OP_Word64  (Operand Word64)
data instance Operands Half       = OP_Half    (Operand Half)
data instance Operands Float      = OP_Float   (Operand Float)
data instance Operands Double     = OP_Double  (Operand Double)
data instance Operands Bool       = OP_Bool    (Operand Bool)
data instance Operands (Vec n a)  = OP_Vec     (Operand (Vec n a))
data instance Operands (a,b)      = OP_Pair    (Operands a) (Operands b)


-- | Given some evidence @dict a@ describing the type @a@, a value can be
-- converted between the 'Operands' and 'Operand' representations.
--
-- * 'op' extracts the single @'Operand' a@ held by an @'Operands' a@.
-- * 'ir' wraps a single @'Operand' a@ as the corresponding @'Operands' a@.
--
-- Both methods carry a 'HasCallStack' constraint so that a failing conversion
-- can report where it was requested.
--
class IROP dict where
  op :: HasCallStack => dict a -> Operands a -> Operand a
  ir :: HasCallStack => dict a -> Operand a -> Operands a

-- Conversions driven by an LLVM 'Type' witness.
--
-- The void type has no operand to carry. 'ir' therefore ignores its argument
-- and yields 'OP_Unit', while 'op' ignores its argument and produces a
-- void-typed local reference with an empty name, since there is no operand to
-- extract from 'OP_Unit'. Primitive types are delegated to the witness wrapped
-- by 'PrimType'.
--
instance IROP Type where
  ir VoidType     _ = OP_Unit
  ir (PrimType t) x = ir t x

  op VoidType     _ = LocalReference VoidType (Name B.empty)
  op (PrimType t) x = op t x

-- Conversions driven by a 'PrimType' witness.
--
-- Scalar types are delegated to the wrapped scalar witness and booleans are
-- handled directly via 'OP_Bool'. Any other primitive type has no conversion
-- here and is reported through 'internalError', with the offending type
-- included in the message.
--
instance IROP PrimType where
  op (ScalarPrimType t) = op t
  -- A lambda-case is used so that every equation of 'op' takes the same
  -- number of arguments.
  op BoolPrimType       = \case OP_Bool x -> x
  op t                  = internalError ("unhandled type: " ++ show t)

  ir (ScalarPrimType t) = ir t
  ir BoolPrimType       = OP_Bool
  ir t                  = internalError ("unhandled type: " ++ show t)

-- Conversions driven by a 'ScalarType' witness: dispatch to the wrapped
-- single-value or SIMD-vector witness.
--
instance IROP ScalarType where
  op (SingleScalarType t) = op t
  op (VectorScalarType t) = op t

  ir (SingleScalarType t) = ir t
  ir (VectorScalarType t) = ir t

-- Conversions driven by a 'SingleType' witness: the only case handled here is
-- a numeric type, which is delegated to the wrapped numeric witness.
--
instance IROP SingleType where
  op (NumSingleType t) = op t

  ir (NumSingleType t) = ir t

-- Conversions driven by a 'VectorType' witness.
--
-- The first field of the 'VectorType' constructor is ignored; the element
-- witness alone selects the conversion. The local helpers walk the element
-- witness down to its leaf constructor (single -> num -> integral/floating).
-- Every leaf performs the same step: 'op' unwraps 'OP_Vec' and 'ir' applies
-- 'OP_Vec'.
--
instance IROP VectorType where
  op (VectorType _ v) = single v
    where
      single :: SingleType t -> Operands (Vec n t) -> Operand (Vec n t)
      single (NumSingleType t) = num t

      num :: NumType t -> Operands (Vec n t) -> Operand (Vec n t)
      num (IntegralNumType t) = integral t
      num (FloatingNumType t) = floating t

      integral :: IntegralType t -> Operands (Vec n t) -> Operand (Vec n t)
      integral TypeInt    (OP_Vec x) = x
      integral TypeInt8   (OP_Vec x) = x
      integral TypeInt16  (OP_Vec x) = x
      integral TypeInt32  (OP_Vec x) = x
      integral TypeInt64  (OP_Vec x) = x
      integral TypeWord   (OP_Vec x) = x
      integral TypeWord8  (OP_Vec x) = x
      integral TypeWord16 (OP_Vec x) = x
      integral TypeWord32 (OP_Vec x) = x
      integral TypeWord64 (OP_Vec x) = x

      floating :: FloatingType t -> Operands (Vec n t) -> Operand (Vec n t)
      floating TypeHalf   (OP_Vec x) = x
      floating TypeFloat  (OP_Vec x) = x
      floating TypeDouble (OP_Vec x) = x

  ir (VectorType _ v) = single v
    where
      single :: SingleType t -> Operand (Vec n t) -> Operands (Vec n t)
      single (NumSingleType t) = num t

      num :: NumType t -> Operand (Vec n t) -> Operands (Vec n t)
      num (IntegralNumType t) = integral t
      num (FloatingNumType t) = floating t

      integral :: IntegralType t -> Operand (Vec n t) -> Operands (Vec n t)
      integral TypeInt    = OP_Vec
      integral TypeInt8   = OP_Vec
      integral TypeInt16  = OP_Vec
      integral TypeInt32  = OP_Vec
      integral TypeInt64  = OP_Vec
      integral TypeWord   = OP_Vec
      integral TypeWord8  = OP_Vec
      integral TypeWord16 = OP_Vec
      integral TypeWord32 = OP_Vec
      integral TypeWord64 = OP_Vec

      floating :: FloatingType t -> Operand (Vec n t) -> Operands (Vec n t)
      floating TypeHalf   = OP_Vec
      floating TypeFloat  = OP_Vec
      floating TypeDouble = OP_Vec

-- Conversions driven by a 'NumType' witness: dispatch to the wrapped integral
-- or floating-point witness.
--
instance IROP NumType where
  op (IntegralNumType t) = op t
  op (FloatingNumType t) = op t

  ir (IntegralNumType t) = ir t
  ir (FloatingNumType t) = ir t

-- Conversions driven by an 'IntegralType' witness.
--
-- Matching on the witness fixes the concrete integer type, which in turn
-- selects the matching 'Operands' constructor: 'op' removes that constructor
-- and 'ir' applies it.
--
instance IROP IntegralType where
  op TypeInt     (OP_Int     x) = x
  op TypeInt8    (OP_Int8    x) = x
  op TypeInt16   (OP_Int16   x) = x
  op TypeInt32   (OP_Int32   x) = x
  op TypeInt64   (OP_Int64   x) = x
  op TypeWord    (OP_Word    x) = x
  op TypeWord8   (OP_Word8   x) = x
  op TypeWord16  (OP_Word16  x) = x
  op TypeWord32  (OP_Word32  x) = x
  op TypeWord64  (OP_Word64  x) = x
  --
  ir TypeInt     = OP_Int
  ir TypeInt8    = OP_Int8
  ir TypeInt16   = OP_Int16
  ir TypeInt32   = OP_Int32
  ir TypeInt64   = OP_Int64
  ir TypeWord    = OP_Word
  ir TypeWord8   = OP_Word8
  ir TypeWord16  = OP_Word16
  ir TypeWord32  = OP_Word32
  ir TypeWord64  = OP_Word64

-- Conversions driven by a 'FloatingType' witness.
--
-- Matching on the witness fixes the concrete floating-point type, which in
-- turn selects the matching 'Operands' constructor: 'op' removes that
-- constructor and 'ir' applies it.
--
instance IROP FloatingType where
  op TypeHalf   (OP_Half   x) = x
  op TypeFloat  (OP_Float  x) = x
  op TypeDouble (OP_Double x) = x
  --
  ir TypeHalf   = OP_Half
  ir TypeFloat  = OP_Float
  ir TypeDouble = OP_Double