{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Rank2Types #-}
module LLVM.Extra.Extension (
   T, CallArgs,
   Subtarget(Subtarget), wrap,
   intrinsic, intrinsicAttr,
   run, runWhen, runUnsafe,
   with, with2, with3,
   ) where

import qualified LLVM.Core as LLVM
import LLVM.Core
   (Value, CodeGenFunction, externFunction, call,
    addAttributes, Attribute(ReadNoneAttribute), )

import Data.Map (Map, )
import qualified Data.Map as Map

import Control.Monad.Trans.Writer (Writer, writer, runWriter, )
import qualified Control.Monad.Trans.Writer as Writer
import Control.Monad (join, )
import Control.Applicative (Applicative, pure, (<*>), )

import Prelude hiding (replicate, sum, map, zipWith, )


-- | An instruction-set extension that generated code may depend on.
data Subtarget =
   Subtarget
      { targetName :: String
        -- ^ Target family, used as the component after @llvm.@
        -- in the intrinsic names built by 'intrinsicAttr'.
      , name :: String
        -- ^ Extension name. It is the key under which the extension is
        -- registered in 'T', and the next component of intrinsic names.
      , check :: forall r. CodeGenFunction r Bool
        -- ^ Code-generation action that tells 'run' whether the extension
        -- may be used, i.e. whether the specialised code should be emitted.
      }


{- |
This is an Applicative functor that records
which extensions are needed in order to run the contained instructions.
You can escape from the functor by calling 'run'
and providing a generic implementation.

We use an applicative functor
since with a monadic interface
we would have to create the specialised code in every case,
in order to see which extensions were used
in the course of creating the instructions.

We use only one (unparameterized) type for all extensions,
since this is the simplest solution.
Alternatively we could use a type parameter
where class constraints show what extensions are needed.
This would be just like exceptions that are explicit in the type signature
as in the control-monad-exception package.
However we would still need to lift all basic LLVM instructions to the new monad.
-}
newtype T a = Cons (Writer (Map String Subtarget) a)
   -- The log maps each required extension's 'name' to its 'Subtarget'.
   -- Logs are combined by 'Map' union, so an extension that is needed
   -- several times is recorded only once.
   deriving (Applicative, Functor)

{- |
Declare that a certain plain LLVM instruction
depends on a particular extension.
This can be useful if you rely on the data layout
of a certain architecture when doing a bitcast,
or if you know that LLVM translates a certain generic operation
to something especially optimal for the declared extension.

The value itself is returned unchanged;
only the extension is added to the set of requirements.
-}
wrap :: Subtarget -> a -> T a
wrap tar x =
   Cons $ do
      -- register the extension under its name, then yield the value as is
      Writer.tell (Map.singleton (name tar) tar)
      return x


{- | Analogous to 'LLVM.FunctionArgs'

The type parameter @r@ and its functional dependency are necessary
since @g@ must be a function of the form @a -> ... -> c -> CodeGenFunction r d@
and we must ensure that the explicit @r@ and the implicit @r@ in the @g@ do match.
-}
class CallArgs g r | g -> r where
   -- | Accept the call arguments one at a time.
   -- After the last one, emit the call and attach the given attributes to it.
   buildIntrinsic :: [Attribute] -> CodeGenFunction r g -> g

-- One more argument: apply the pending call to it
-- and continue with the remaining argument types.
instance (CallArgs g r) =>
      CallArgs (Value a -> g) r where
   buildIntrinsic attrs mkCall arg =
      buildIntrinsic attrs (fmap (\f -> f arg) mkCall)

-- All arguments are supplied: run the outer action, then the resulting call,
-- and attach the attributes to the call's result value.
instance CallArgs (CodeGenFunction r (Value a)) r where
   buildIntrinsic attrs mkCall =
      join mkCall >>= \result -> do
         -- NOTE(review): attributes are added at index 0;
         -- confirm whether this is the return-value or the function index
         -- for the LLVM version in use.
         addAttributes result 0 attrs
         return result

{- |
Create an intrinsic and register the needed extension.
The call is marked with 'ReadNoneAttribute';
use 'intrinsicAttr' to choose the attributes yourself.

We cannot check immediately whether the signature matches
or whether the right extension is given.
However, when resolving intrinsics,
LLVM will not find the intrinsic if the extension is wrong,
and it also checks the signature.
-}
intrinsic ::
   (LLVM.IsFunction f, LLVM.CallArgs f g, CallArgs g r) =>
   Subtarget -> String -> T g
intrinsic tar intr =
   intrinsicAttr [ReadNoneAttribute] tar intr

{- |
Like 'intrinsic', but with an explicit list of attributes for the call.
The external function that is called is named
@llvm.@ followed by the subtarget's 'targetName', its 'name'
and the given intrinsic name, separated by dots.
The subtarget is registered as a requirement via 'wrap'.
-}
intrinsicAttr ::
   (LLVM.IsFunction f, LLVM.CallArgs f g, CallArgs g r) =>
   [Attribute] -> Subtarget -> String -> T g
intrinsicAttr attrs tar intr =
   wrap tar (buildIntrinsic attrs (fmap call (externFunction fullName)))
   where
      fullName =
         concat ["llvm.", targetName tar, ".", name tar, ".", intr]


infixl 1 `run`

{- |
@run generic specific@ generates the @specific@ code
if the required extensions are available on the host processor
and @generic@ otherwise.

The 'check' of every registered extension is executed,
and the specific code is chosen only if all of them report 'True'.
-}
run ::
   CodeGenFunction r a ->
   T (CodeGenFunction r a) ->
   CodeGenFunction r a
run generic (Cons registered) = do
   let (specific, needed) = runWriter registered
   available <- fmap and (mapM check (Map.elems needed))
   if available then specific else generic

{- |
Convenient variant of 'run':
Only run the code with extended instructions
if an additional condition is given.

If the condition is 'False', the generic code is emitted directly
and none of the extension checks are executed,
because their results would be ignored anyway.
Otherwise this behaves exactly like 'run'.
-}
runWhen ::
   Bool ->
   CodeGenFunction r a ->
   T (CodeGenFunction r a) ->
   CodeGenFunction r a
runWhen c alt m =
   if c
     then run alt m
     else alt

{- |
Return the contained value without checking
whether the registered extensions are available.
Only for debugging purposes.
-}
runUnsafe ::
   T a -> a
runUnsafe (Cons m) =
   let (x, _required) = runWriter m
   in  x


-- | 'fmap' with the arguments swapped, so the function can come last,
-- e.g. as a trailing lambda.
with :: (Functor f) => f a -> (a -> b) -> f b
with x f = fmap f x

-- | Combine two applicative values with a function given last.
with2 :: (Applicative f) => f a -> f b -> (a -> b -> c) -> f c
with2 x y f =
   fmap f x <*> y

-- | Combine three applicative values with a function given last.
with3 :: (Applicative f) => f a -> f b -> f c -> (a -> b -> c -> d) -> f d
with3 x y z f =
   fmap f x <*> y <*> z