{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.Multi.Class where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum


-- | Container kinds that can be either a single 'MultiValue.T' or a
-- 'MultiVector.T' of some size.  Generic code hands 'switch' one
-- implementation per container kind and the instance picks the one that
-- fits @value@.
class C value where
   -- | Type-level size of the container: 'TypeNum.D1' for 'MultiValue.T'
   -- and @n@ for @'MultiVector.T' n@.
   type Size value :: *
   -- | Select the scalar alternative (first argument) or the vector
   -- alternative (second argument) according to the instance.
   switch ::
      op MultiValue.T ->
      op (MultiVector.T (Size value)) ->
      op value

-- | Scalar containers have 'Size' 'TypeNum.D1'; 'switch' returns the
-- scalar alternative and ignores the vector one.
instance C MultiValue.T where
   type Size MultiValue.T = TypeNum.D1
   switch = const

-- | @'MultiVector.T' n@ has 'Size' @n@; 'switch' ignores the scalar
-- alternative and returns the vector one.
instance (TypeNum.Positive n) => C (MultiVector.T n) where
   type Size (MultiVector.T n) = n
   switch = const id


-- | Wraps a plain (non-function) value so it can pass through 'switch',
-- which needs the container type @value@ as the last type argument.
newtype Const a value =
   Const {getConst :: value a}

-- | 'MultiValue.undef' or 'MultiVector.undef', chosen by the 'C' instance
-- of @value@.
undef ::
   (C value, Size value ~ n, TypeNum.Positive n, MultiVector.C a) =>
   value a
undef =
   getConst (switch (Const MultiValue.undef) (Const MultiVector.undef))

-- | 'MultiValue.zero' or 'MultiVector.zero', chosen by the 'C' instance
-- of @value@.
zero ::
   (C value, Size value ~ n, TypeNum.Positive n, MultiVector.C a) =>
   value a
zero =
   getConst $ switch (Const MultiValue.zero) $ Const MultiVector.zero


-- | A nullary code generator, shaped so that 'switch' can select it by container type.
newtype Op0 r a value = Op0 {runOp0 :: LLVM.CodeGenFunction r (value a)}

-- | A unary code generator, shaped so that 'switch' can select it by container type.
newtype Op1 r a b value = Op1 {runOp1 :: value a -> LLVM.CodeGenFunction r (value b)}

-- | A binary code generator, shaped so that 'switch' can select it by container type.
newtype Op2 r a b c value =
   Op2 {runOp2 :: value a -> value b -> LLVM.CodeGenFunction r (value c)}

-- | 'A.add' lifted over both container kinds.
add ::
   (TypeNum.Positive n, MultiVector.Additive a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
add x y = runOp2 (switch (Op2 A.add) (Op2 A.add)) x y

-- | 'A.sub' lifted over both container kinds.
sub ::
   (TypeNum.Positive n, MultiVector.Additive a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
sub x y = runOp2 (switch (Op2 A.sub) (Op2 A.sub)) x y

-- | 'A.neg' lifted over both container kinds.
neg ::
   (TypeNum.Positive n, MultiVector.Additive a,
    n ~ Size value, C value) =>
   value a -> LLVM.CodeGenFunction r (value a)
neg x = runOp1 (switch (Op1 A.neg) (Op1 A.neg)) x


-- | 'A.mul' lifted over both container kinds.
mul ::
   (TypeNum.Positive n, MultiVector.PseudoRing a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
mul x y = runOp2 (switch (Op2 A.mul) (Op2 A.mul)) x y
-- | 'A.fdiv' lifted over both container kinds.
fdiv ::
   (TypeNum.Positive n, MultiVector.Field a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
fdiv x y = runOp2 (switch (Op2 A.fdiv) (Op2 A.fdiv)) x y

-- | 'A.scale' lifted over both container kinds.  The first operand holds
-- elements of type @'MultiValue.Scalar' v@ in the same container kind as
-- the second operand.
scale ::
   (TypeNum.Positive n, MultiVector.PseudoModule v,
    n ~ Size value, C value) =>
   value (MultiValue.Scalar v) -> value v -> LLVM.CodeGenFunction r (value v)
scale factor x = runOp2 (switch (Op2 A.scale) (Op2 A.scale)) factor x

-- | 'A.min' lifted over both container kinds.
min ::
   (TypeNum.Positive n, MultiVector.Real a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
min x y = runOp2 (switch (Op2 A.min) (Op2 A.min)) x y

-- | 'A.max' lifted over both container kinds.
max ::
   (TypeNum.Positive n, MultiVector.Real a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
max x y = runOp2 (switch (Op2 A.max) (Op2 A.max)) x y

-- | 'A.abs' lifted over both container kinds.
abs ::
   (TypeNum.Positive n, MultiVector.Real a,
    n ~ Size value, C value) =>
   value a -> LLVM.CodeGenFunction r (value a)
abs x = runOp1 (switch (Op1 A.abs) (Op1 A.abs)) x

-- | 'A.signum' lifted over both container kinds.
signum ::
   (TypeNum.Positive n, MultiVector.Real a,
    n ~ Size value, C value) =>
   value a -> LLVM.CodeGenFunction r (value a)
signum x = runOp1 (switch (Op1 A.signum) (Op1 A.signum)) x

-- | 'A.truncate' lifted over both container kinds.
truncate ::
   (TypeNum.Positive n, MultiVector.Fraction a,
    n ~ Size value, C value) =>
   value a -> LLVM.CodeGenFunction r (value a)
truncate x = runOp1 (switch (Op1 A.truncate) (Op1 A.truncate)) x

-- | 'A.fraction' lifted over both container kinds.
fraction ::
   (TypeNum.Positive n, MultiVector.Fraction a,
    n ~ Size value, C value) =>
   value a -> LLVM.CodeGenFunction r (value a)
fraction x = runOp1 (switch (Op1 A.fraction) (Op1 A.fraction)) x

-- | 'A.sqrt' lifted over both container kinds.
sqrt ::
   (TypeNum.Positive n, MultiVector.Algebraic a,
    n ~ Size value, C value) =>
   value a -> LLVM.CodeGenFunction r (value a)
sqrt x = runOp1 (switch (Op1 A.sqrt) (Op1 A.sqrt)) x

-- | 'A.pi' for either container kind, chosen by the 'C' instance of
-- @value@.
pi ::
   (TypeNum.Positive n, MultiVector.Transcendental a,
    n ~ Size value, C value) =>
   LLVM.CodeGenFunction r (value a)
pi = runOp0 (switch (Op0 A.pi) (Op0 A.pi))

-- | 'A.sin', 'A.cos', 'A.exp' and 'A.log' lifted over both container kinds.
sin, cos, exp, log ::
   (TypeNum.Positive n, MultiVector.Transcendental a,
    n ~ Size value, C value) =>
   value a -> LLVM.CodeGenFunction r (value a)
sin x = runOp1 (switch (Op1 A.sin) (Op1 A.sin)) x
cos x = runOp1 (switch (Op1 A.cos) (Op1 A.cos)) x
exp x = runOp1 (switch (Op1 A.exp) (Op1 A.exp)) x
log x = runOp1 (switch (Op1 A.log) (Op1 A.log)) x

-- | 'A.pow' lifted over both container kinds.
pow ::
   (TypeNum.Positive n, MultiVector.Transcendental a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
pow base expo = runOp2 (switch (Op2 A.pow) (Op2 A.pow)) base expo


-- | 'A.cmp' with the given 'LLVM.CmpPredicate', lifted over both container
-- kinds.  The result is a 'Bool' container of the same kind as the operands.
cmp ::
   (TypeNum.Positive n, MultiVector.Comparison a,
    n ~ Size value, C value) =>
   LLVM.CmpPredicate ->
   value a -> value a -> LLVM.CodeGenFunction r (value Bool)
cmp predicate x y =
   runOp2 (switch (Op2 (A.cmp predicate)) (Op2 (A.cmp predicate))) x y

-- | 'A.fcmp' with the given 'LLVM.FPPredicate', lifted over both container
-- kinds.  The result is a 'Bool' container of the same kind as the operands.
fcmp ::
   (TypeNum.Positive n, MultiVector.FloatingComparison a,
    n ~ Size value, C value) =>
   LLVM.FPPredicate ->
   value a -> value a -> LLVM.CodeGenFunction r (value Bool)
fcmp predicate x y =
   runOp2 (switch (Op2 (A.fcmp predicate)) (Op2 (A.fcmp predicate))) x y


-- | 'A.and' lifted over both container kinds.
and ::
   (TypeNum.Positive n, MultiVector.Logic a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
and x y = runOp2 (switch (Op2 A.and) (Op2 A.and)) x y

-- | 'A.or' lifted over both container kinds.
or ::
   (TypeNum.Positive n, MultiVector.Logic a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
or x y = runOp2 (switch (Op2 A.or) (Op2 A.or)) x y

-- | 'A.xor' lifted over both container kinds.
xor ::
   (TypeNum.Positive n, MultiVector.Logic a,
    n ~ Size value, C value) =>
   value a -> value a -> LLVM.CodeGenFunction r (value a)
xor x y = runOp2 (switch (Op2 A.xor) (Op2 A.xor)) x y

-- | 'A.inv' lifted over both container kinds.
inv ::
   (TypeNum.Positive n, MultiVector.Logic a,
    n ~ Size value, C value) =>
   value a -> LLVM.CodeGenFunction r (value a)
inv x = runOp1 (switch (Op1 A.inv) (Op1 A.inv)) x