{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE EmptyDataDecls #-}
module Knead.Color where

import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp)

import qualified LLVM.Extra.Nice.Value.Storable as Storable
import qualified LLVM.Extra.Nice.Vector as NiceVector
import qualified LLVM.Extra.Nice.Value as NiceValue
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Tuple as Tuple

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Foreign.Storable.Traversable as StoreTrav
import Foreign.Storable (Storable, sizeOf, alignment, poke, peek)
import Foreign.Ptr (Ptr)

import Control.Monad ((<=<))
import Control.Applicative (Applicative, liftA3, pure, (<*>))

import Data.Traversable (Traversable, traverse)
import Data.Foldable (Foldable, foldMap)
import Data.Monoid ((<>))



-- | A colour value made of three components of type @a@.
--
-- @space@ is a phantom parameter: no constructor field mentions it, so it
-- only serves to keep colours of different colour spaces apart at the type
-- level (see 'YUV' and 'RGB').
--
-- Keep the fields lazy: 'lazyElements' rebuilds a 'C3' from components that
-- are not yet evaluated, and strict fields would force them.
data C3 space a = C3 a a a

-- | Phantom tag marking the YUV colour space.
data SpaceYUV

-- | A colour in YUV space; 'luma' reads its first component.
type YUV = C3 SpaceYUV

-- | Phantom tag marking the RGB colour space.
data SpaceRGB

-- | A colour in RGB space, components ordered red, green, blue.
type RGB = C3 SpaceRGB

-- | Applies the function to each of the three components; the colour-space
-- tag is left untouched.
instance Functor (C3 space) where
   fmap g c =
      case c of
         C3 c0 c1 c2 -> C3 (g c0) (g c1) (g c2)

-- | Folds the components in order (first, second, third), combining them
-- with '<>' (right-associated, as '<>' is @infixr@).
instance Foldable (C3 space) where
   foldMap f (C3 c0 c1 c2) = f c0 <> (f c1 <> f c2)

-- | Runs the effectful function on the components in order (first, second,
-- third) and reassembles a 'C3' with the same colour-space tag.
instance Traversable (C3 space) where
   traverse f c =
      case c of
         C3 c0 c1 c2 -> liftA3 C3 (f c0) (f c1) (f c2)

-- | Component-wise ("zip-like") applicative: 'pure' replicates a value into
-- all three slots and '<*>' applies each function to the component in the
-- same position.  'Foreign.Storable.Traversable.peekApplicative' (used by
-- the 'Storable' instance) needs this instance.
instance Applicative (C3 space) where
   pure x = C3 x x x
   (<*>) (C3 f0 f1 f2) (C3 x0 x1 x2) = C3 (f0 x0) (f1 x1) (f2 x2)


-- | Host-side memory layout of a 'C3', delegated to the generic
-- "Foreign.Storable.Traversable" helpers for the three components.
instance (Storable a) => Storable (C3 space a) where
   -- 'sizeOf' must not evaluate its argument (it is commonly called on
   -- 'undefined'); 'lazyElements' supplies a 'C3' constructor without
   -- forcing the argument before the traversal-based size computation.
   sizeOf c = StoreTrav.sizeOf (lazyElements c)
   alignment c = StoreTrav.alignment c
   peek ptr = StoreTrav.peekApplicative ptr
   poke ptr c = StoreTrav.poke ptr c

-- | Returns a 'C3' constructor immediately, without evaluating the argument.
-- Each component is a lazy selector into the argument, so it is only
-- evaluated (and only fails on an undefined argument) when that component
-- itself is demanded.
lazyElements :: C3 space a -> C3 space a
lazyElements c = C3 c0 c1 c2
   where
      -- A pattern binding in a where clause is irrefutable (lazy).
      C3 c0 c1 c2 = c

-- | Load and store a 'C3' in generated code. The pointer is reinterpreted
-- as a pointer to a three-element LLVM vector ('castVectorPtr'), and the
-- access goes through that vector's 'Storable.C' instance.
-- 'NiceValue.cast' converts between the vector and the 'C3' wrapper. The
-- constraint on @NiceVector.Repr@ makes the vector's representation the
-- same @LLVM.Value (LLVM.Vector D3 a)@ that 'C3' uses.
--
-- NOTE(review): this assumes the vector load/store touches the same bytes
-- as the host-side 'Storable' instance (three consecutive elements).
-- Confirm against 'Storable.Vector'.
instance
   (Storable.Vector a, LLVM.IsPrimitive a, LLVM.IsConst a,
    NiceVector.C a,
    NiceVector.Repr TypeNum.D3 a ~ LLVM.Value (LLVM.Vector TypeNum.D3 a)) =>
      Storable.C (C3 space a) where
   load ptr = castVectorPtr ptr >>= fmap NiceValue.cast . Storable.load
   store value = Storable.store (NiceValue.cast value) <=< castVectorPtr

-- | Reinterpret a pointer to a 'C3' as a pointer to a three-element LLVM
-- vector of the same element type (a plain 'LLVM.bitcast'; no data is
-- moved).
castVectorPtr ::
   LLVM.Value (Ptr (C3 space a)) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr (LLVM.Vector TypeNum.D3 a)))
castVectorPtr ptr = LLVM.bitcast ptr


-- | A Haskell 'C3' becomes an LLVM constant vector whose elements are the
-- three components, in order.
instance
   (LLVM.IsPrimitive a, LLVM.IsConst a) =>
      Tuple.Value (C3 space a) where
   type ValueOf (C3 space a) = LLVM.Value (LLVM.Vector TypeNum.D3 a)
   valueOf c =
      case c of
         C3 c0 c1 c2 -> LLVM.valueOf (LLVM.consVector c0 c1 c2)

-- | 'C3' is represented in generated code by a single three-element LLVM
-- vector.
instance
   (LLVM.IsPrimitive a, LLVM.IsConst a) =>
      NiceValue.C (C3 space a) where
   type Repr (C3 space a) = LLVM.Value (LLVM.Vector TypeNum.D3 a)
   -- Constants reuse the Tuple.Value encoding (a constant vector built from
   -- the three components); its ValueOf equals this Repr.
   cons = NiceValue.Cons . Tuple.valueOf
   -- The remaining methods use the generic helpers for tuple-like
   -- representations.
   undef = NiceValue.undefTuple
   zero = NiceValue.zeroTuple
   phi = NiceValue.phiTuple
   addPhi = NiceValue.addPhiTuple


-- | Build a 'C3' expression from three scalar expressions. The arguments are
-- inserted at vector indices 0, 1 and 2 of an initially undefined LLVM
-- vector.
cons ::
   (LLVM.IsPrimitive a, NiceValue.Repr a ~ LLVM.Value a) =>
   Exp a -> Exp a -> Exp a -> Exp (C3 space a)
cons =
   Expr.liftReprM3
      (\c0 c1 c2 -> do
         vec0 <- LLVM.insertelement Tuple.undef c0 (LLVM.valueOf 0)
         vec1 <- LLVM.insertelement vec0 c1 (LLVM.valueOf 1)
         LLVM.insertelement vec1 c2 (LLVM.valueOf 2))

-- | Build a 'YUV' colour expression from its Y, U and V components, in that
-- order ('cons' fixed to the YUV space).
yuv ::
   (LLVM.IsPrimitive a, NiceValue.Repr a ~ LLVM.Value a) =>
   Exp a -> Exp a -> Exp a -> Exp (YUV a)
yuv y u v = cons y u v

-- | Build an 'RGB' colour expression from its red, green and blue
-- components, in that order ('cons' fixed to the RGB space).
rgb ::
   (LLVM.IsPrimitive a, NiceValue.Repr a ~ LLVM.Value a) =>
   Exp a -> Exp a -> Exp a -> Exp (RGB a)
rgb r g b = cons r g b

-- | Colour spaces that can reduce a 'C3' colour to one scalar brightness.
class Space space where
   -- | Scalar brightness of a colour expression. The 'SpaceYUV' instance
   --   returns the luma component; the 'SpaceRGB' instance returns a
   --   weighted sum of the channels.
   brightness ::
      (LLVM.IsPrimitive a, NiceValue.Repr a ~ LLVM.Value a,
       NiceValue.Field a, NiceValue.RationalConstant a, NiceValue.Real a) =>
      Exp (C3 space a) -> Exp a

-- | The Y (luma) component of a 'YUV' colour: element 0 of the vector
-- representation, where 'yuv' stores its first argument.
luma ::
   (LLVM.IsPrimitive a, NiceValue.Repr a ~ LLVM.Value a) =>
   Exp (YUV a) -> Exp a
luma = Expr.liftReprM (\vec -> LLVM.extractelement vec (LLVM.valueOf 0))

-- | The brightness of a YUV colour is its luma component.
instance Space SpaceYUV where
   brightness c = luma c

-- | Channel accessors for 'RGB' colours. They read vector elements 0, 1 and
-- 2, matching the argument order of 'rgb'.
red, green, blue ::
   (LLVM.IsPrimitive a, NiceValue.Repr a ~ LLVM.Value a) =>
   Exp (RGB a) -> Exp a
red   = Expr.liftReprM (\vec -> LLVM.extractelement vec (LLVM.valueOf 0))
green = Expr.liftReprM (\vec -> LLVM.extractelement vec (LLVM.valueOf 1))
blue  = Expr.liftReprM (\vec -> LLVM.extractelement vec (LLVM.valueOf 2))

-- | The brightness of an RGB colour is a weighted sum of its channels. The
-- weights are the ITU-R BT.601 luma coefficients.
instance Space SpaceRGB where
   brightness c =
      kRed * red c + kGreen * green c + kBlue * blue c
      where
         kRed = 0.299
         kGreen = 0.587
         kBlue = 0.114


-- | Run an LLVM code generator on every element of a colour's vector
-- representation (via 'Vector.map'), keeping the colour-space tag. The
-- generator must work in any 'LLVM.CodeGenFunction' context, hence the
-- rank-2 argument.
mapPlain ::
   (LLVM.IsPrimitive a, LLVM.IsPrimitive b) =>
   (forall r. LLVM.Value a -> LLVM.CodeGenFunction r (LLVM.Value b)) ->
   Exp (C3 space a) -> Exp (C3 space b)
mapPlain f c = Expr.liftReprM (Vector.map f) c

-- | Turn an 'Exp'-level function into an LLVM code generator that works
-- directly on raw representations. The argument is wrapped in
-- 'NiceValue.Cons' before the call, and the result is unwrapped after it.
exprUnliftM1 ::
   (NiceValue.Repr a ~ al, NiceValue.Repr b ~ bl) =>
   (Exp a -> Exp b) -> al -> LLVM.CodeGenFunction r bl
exprUnliftM1 f =
   fmap unCons . Expr.unliftM1 f . NiceValue.Cons
   where
      unCons (NiceValue.Cons b) = b

-- | Apply a scalar expression function to each component of a colour,
-- keeping the colour-space tag.
--
-- NOTE(review): shares its name with 'Prelude.map', so it is presumably
-- meant to be imported qualified.
map ::
   (LLVM.IsPrimitive a, NiceValue.Repr a ~ LLVM.Value a,
    LLVM.IsPrimitive b, NiceValue.Repr b ~ LLVM.Value b) =>
   (Exp a -> Exp b) -> Exp (C3 space a) -> Exp (C3 space b)
map f c = mapPlain (exprUnliftM1 f) c
