module LLVM.Extra.TuplePrivate where

import qualified LLVM.Core as LLVM

import qualified Data.FixedLength as FixedLength
import Data.Complex (Complex)

import qualified Type.Data.Num.Unary as Unary

import qualified Control.Applicative.HT as App
import Control.Applicative (Applicative, liftA2, pure)

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold

import Data.Orphans ()



-- * Class for phi nodes operating on value tuples

-- | Tuples of LLVM values for which phi nodes can be created and extended
-- component by component.
class Phi a where
   -- | Create one phi node per component. Each node gets the corresponding
   -- component of the given tuple as its incoming value from the given block.
   phi ::
      LLVM.BasicBlock -> a ->
      LLVM.CodeGenFunction r a
   -- | @addPhi bb phis new@ adds each component of @new@ as an incoming value
   -- from @bb@ to the matching phi node in @phis@ (as produced by 'phi').
   addPhi ::
      LLVM.BasicBlock -> a -> a ->
      LLVM.CodeGenFunction r ()

-- | The empty tuple has no components, so no phi nodes are emitted.
instance Phi () where
   phi _bb _unit = pure ()
   addPhi _bb _phis _new = pure ()

-- | A single LLVM value corresponds to exactly one phi node.
instance (LLVM.IsFirstClass a) => Phi (LLVM.Value a) where
   phi bb initial = LLVM.phi [(initial, bb)]
   addPhi bb node incoming = LLVM.addPhiInputs node [(incoming, bb)]

-- | Pairs: one set of phi nodes per component; inputs are added pairwise.
instance (Phi a, Phi b) => Phi (a, b) where
   phi bb pair = App.mapPair (phi bb, phi bb) pair
   addPhi bb (phiA, phiB) (newA, newB) =
      addPhi bb phiA newA >>
      addPhi bb phiB newB

-- | Triples: one set of phi nodes per component; inputs are added
-- component-wise, first to last.
instance (Phi a, Phi b, Phi c) => Phi (a, b, c) where
   phi bb triple = App.mapTriple (phi bb, phi bb, phi bb) triple
   addPhi bb (phiA, phiB, phiC) (newA, newB, newC) =
      addPhi bb phiA newA >>
      addPhi bb phiB newB >>
      addPhi bb phiC newC

-- | Quadruples: one set of phi nodes per component; inputs are added
-- component-wise, first to last.
instance (Phi a, Phi b, Phi c, Phi d) => Phi (a, b, c, d) where
   phi bb quad =
      case quad of
         (a, b, c, d) ->
            App.lift4 (,,,)
               (phi bb a) (phi bb b) (phi bb c) (phi bb d)
   addPhi bb (phiA, phiB, phiC, phiD) (newA, newB, newC, newD) =
      addPhi bb phiA newA >>
      addPhi bb phiB newB >>
      addPhi bb phiC newC >>
      addPhi bb phiD newD

-- | Complex numbers: phi nodes are created via the 'Traversable' instance
-- and extended via the 'Foldable'/'Applicative' instances.
instance (Phi a) => Phi (Complex a) where
   phi bb = phiTraversable bb
   addPhi bb = addPhiFoldable bb

-- | Fixed-length vectors: handled element-wise, like 'Complex'.
instance (Unary.Natural n, Phi a) => Phi (FixedLength.T n a) where
   phi bb = phiTraversable bb
   addPhi bb = addPhiFoldable bb

-- | Generic 'phi' for any 'Trav.Traversable' container: creates one phi
-- node per element, visiting the elements in traversal order.
phiTraversable ::
   (Phi a, Trav.Traversable f) =>
   LLVM.BasicBlock -> f a -> LLVM.CodeGenFunction r (f a)
phiTraversable bb = Trav.traverse (phi bb)

-- | Generic 'addPhi' for containers that are 'Applicative' and 'Fold.Foldable'.
-- 'liftA2' pairs each phi node in the first container with the element of
-- the second container at the same position. The resulting actions are then
-- run in fold order, and their results are discarded.
addPhiFoldable ::
   (Phi a, Fold.Foldable f, Applicative f) =>
   LLVM.BasicBlock -> f a -> f a -> LLVM.CodeGenFunction r ()
addPhiFoldable bb phis incoming =
   Fold.sequenceA_ $ liftA2 (addPhi bb) phis incoming


-- * Class for tuples of undefined values

-- | Tuples (and containers) whose components can all be set to LLVM's
-- @undef@ value.
class Undefined a where
   -- | A value in which every component is undefined.
   undef ::
      a

-- | The empty tuple has no components; it is its own undefined value.
instance Undefined () where
   undef = ()

-- | A runtime value: the undefined constant (the 'LLVM.ConstValue' instance
-- below) lifted with 'LLVM.value'.
instance (LLVM.IsFirstClass a) => Undefined (LLVM.Value a) where
   undef = LLVM.value undef

-- | A constant: LLVM's @undef@ constant.
instance (LLVM.IsFirstClass a) => Undefined (LLVM.ConstValue a) where
   undef = LLVM.undef

-- | Pairs: every component undefined.
instance (Undefined a, Undefined b) => Undefined (a, b) where
   undef = (,) undef undef

-- | Triples: every component undefined.
instance (Undefined a, Undefined b, Undefined c) => Undefined (a, b, c) where
   undef = (,,) undef undef undef

-- | Quadruples: every component undefined.
instance
   (Undefined a, Undefined b, Undefined c, Undefined d) =>
      Undefined (a, b, c, d) where
   undef = (,,,) undef undef undef undef

-- | Complex numbers: built with 'undefPointed', i.e. @'pure' 'undef'@.
instance (Undefined a) => Undefined (Complex a) where
   undef = undefPointed

-- | Fixed-length vectors: built with 'undefPointed', i.e. @'pure' 'undef'@.
instance
   (Unary.Natural n, Undefined a) =>
      Undefined (FixedLength.T n a) where
   undef = undefPointed

-- | Wrap the element-level 'undef' into an 'Applicative' container with
-- 'pure'.
undefPointed ::
   (Undefined a, Applicative f) =>
   f a
undefPointed = pure undef


-- * Class for tuples of zero values

-- | Tuples (and containers) whose components can all be set to LLVM's
-- zero value.
class Zero a where
   -- | A value in which every component is zero.
   zero ::
      a

-- | The empty tuple has no components; it is its own zero value.
instance Zero () where
   zero = ()

-- | A runtime value: the zero constant (the 'LLVM.ConstValue' instance
-- below) lifted with 'LLVM.value'.
instance (LLVM.IsFirstClass a) => Zero (LLVM.Value a) where
   zero = LLVM.value zero

-- | A constant: LLVM's zero constant.
instance (LLVM.IsFirstClass a) => Zero (LLVM.ConstValue a) where
   zero = LLVM.zero

-- | Pairs: every component zero.
instance (Zero a, Zero b) => Zero (a, b) where
   zero = (,) zero zero

-- | Triples: every component zero.
instance (Zero a, Zero b, Zero c) => Zero (a, b, c) where
   zero = (,,) zero zero zero

-- | Quadruples: every component zero.
instance
   (Zero a, Zero b, Zero c, Zero d) =>
      Zero (a, b, c, d) where
   zero = (,,,) zero zero zero zero

-- | Complex numbers: built with 'zeroPointed', i.e. @'pure' 'zero'@.
instance (Zero a) => Zero (Complex a) where
   zero = zeroPointed

-- | Fixed-length vectors: built with 'zeroPointed', i.e. @'pure' 'zero'@.
instance
   (Unary.Natural n, Zero a) =>
      Zero (FixedLength.T n a) where
   zero = zeroPointed

-- | Wrap the element-level 'zero' into an 'Applicative' container with
-- 'pure'.
zeroPointed ::
   (Zero a, Applicative f) =>
   f a
zeroPointed = pure zero