{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeOperators #-}
module LLVM.Extra.Multi.Vector.Memory where

import qualified LLVM.Extra.Multi.Vector as MultiVector
import LLVM.Extra.MemoryPrivate (decomposeFromLoad, composeFromStore, )

import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value, )

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal ((:*:), )

import Foreign.Ptr (Ptr, )

import Control.Applicative (liftA2, )


-- | Multi-vectors of @n@ elements of type @a@ that can be moved between
-- registers and memory.
--
-- An instance must provide at least one of 'load' / 'decompose' and at least
-- one of 'store' / 'compose'. The missing method of each pair is derived
-- from the one that is given.
class
   (TypeNum.Positive n, MultiVector.C a, LLVM.IsSized (Struct n a)) =>
      C n a where
   {-# MINIMAL (load|decompose), (store|compose) #-}
   -- | LLVM type used for the in-memory representation of a
   -- @'MultiVector.T' n a@; 'load' and 'store' access it through a pointer.
   type Struct n a :: *

   -- | Read a multi-vector from the given memory location.
   --
   -- Default: load the raw 'Struct' value, then convert it with 'decompose'.
   load :: Value (Ptr (Struct n a)) -> CodeGenFunction r (MultiVector.T n a)
   load ptr = LLVM.load ptr >>= decompose

   -- | Write a multi-vector to the given memory location.
   --
   -- Default: build the raw 'Struct' value with 'compose', then store it.
   store :: MultiVector.T n a -> Value (Ptr (Struct n a)) -> CodeGenFunction r ()
   store r ptr = compose r >>= \s -> LLVM.store s ptr

   -- | Convert a raw 'Struct' value into a multi-vector.
   --
   -- Default: derived from 'load' via 'decomposeFromLoad'
   -- (see "LLVM.Extra.MemoryPrivate" for how it does this).
   decompose :: Value (Struct n a) -> CodeGenFunction r (MultiVector.T n a)
   decompose v = decomposeFromLoad load v

   -- | Convert a multi-vector into a raw 'Struct' value.
   --
   -- Default: derived from 'store' via 'composeFromStore'
   -- (see "LLVM.Extra.MemoryPrivate" for how it does this).
   compose :: MultiVector.T n a -> CodeGenFunction r (Value (Struct n a))
   compose v = composeFromStore store v

-- | A multi-vector of @n@ 'Float's is kept in memory as one LLVM vector of
-- @n@ 'Float's. 'load' and 'store' are plain LLVM loads and stores.
-- 'decompose' and 'compose' emit no instructions; they only apply the pure
-- 'MultiVector.consPrim' / 'MultiVector.deconsPrim' conversions.
--
-- NOTE(review): the @n :*: D32@ constraint presumably satisfies LLVM's
-- sizedness requirement for an @n@-element vector of 32-bit floats; confirm
-- against the llvm-tf 'LLVM.IsSized' instance for 'LLVM.Vector'.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D32)) =>
      C n Float where
   type Struct n Float = LLVM.Vector n Float
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store v ptr = LLVM.store (MultiVector.deconsPrim v) ptr
   decompose v = return (MultiVector.consPrim v)
   compose v = return (MultiVector.deconsPrim v)

-- | A multi-vector of @n@ 'Double's is kept in memory as one LLVM vector of
-- @n@ 'Double's. 'load' and 'store' are plain LLVM loads and stores.
-- 'decompose' and 'compose' emit no instructions; they only apply the pure
-- 'MultiVector.consPrim' / 'MultiVector.deconsPrim' conversions.
--
-- NOTE(review): the @n :*: D64@ constraint presumably satisfies LLVM's
-- sizedness requirement for an @n@-element vector of 64-bit floats; confirm
-- against the llvm-tf 'LLVM.IsSized' instance for 'LLVM.Vector'.
instance
   (TypeNum.Positive n, TypeNum.Positive (n :*: TypeNum.D64)) =>
      C n Double where
   type Struct n Double = LLVM.Vector n Double
   load ptr = fmap MultiVector.consPrim (LLVM.load ptr)
   store v ptr = LLVM.store (MultiVector.deconsPrim v) ptr
   decompose v = return (MultiVector.consPrim v)
   compose v = return (MultiVector.deconsPrim v)

-- | A multi-vector of pairs is kept in memory as a two-field LLVM struct.
-- Field 0 holds the representation of the first components and field 1
-- holds the representation of the second components.
-- 'load' and 'store' are not defined here, so the class defaults (built on
-- 'decompose' / 'compose') are used.
instance (C n a, C n b) => C n (a,b) where
   type Struct n (a,b) = (LLVM.Struct (Struct n a, (Struct n b, ())))

   -- Extract both struct fields (field 0 first) and decompose each one,
   -- then zip the two component multi-vectors together.
   decompose ab = do
      a <- LLVM.extractvalue ab TypeNum.d0 >>= decompose
      b <- LLVM.extractvalue ab TypeNum.d1 >>= decompose
      return (MultiVector.zip a b)

   -- Split the pair multi-vector and compose each half. Then build the
   -- struct by inserting field 0 into an undef value, followed by field 1.
   compose ab =
      case MultiVector.unzip ab of
         (a, b) ->
            compose a >>= \sa ->
            compose b >>= \sb ->
            LLVM.insertvalue (LLVM.value LLVM.undef) sa TypeNum.d0
               >>= \ra -> LLVM.insertvalue ra sb TypeNum.d1