{-# LANGUAGE MagicHash, UnboxedTuples #-}
{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances, FunctionalDependencies #-}
{-# LANGUAGE DataKinds, PolyKinds, TypeFamilies #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE CPP #-}

#include "MachDeps.h"

module Language.Asm.Inline
( defineAsmFun
, defineAsmFunM
, Unit(..)
) where

import qualified Data.ByteString as BS
import Control.Monad
import Control.Monad.Primitive
import Data.Generics.Uniplate.Data
import Data.List
import Foreign.Ptr
import GHC.Int
import GHC.Prim
import GHC.Ptr
import GHC.Types hiding (Type)
import GHC.Word
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import System.IO.Unsafe

import Language.Asm.Inline.AsmCode
import Language.Asm.Inline.Util

-- | Types that can cross the boundary into an inline-assembly function.
--
-- @a@ is the ordinary boxed Haskell type and @u@ is its primitive
-- counterpart, whose runtime representation is @r@. Both @r@ and @u@ are
-- fixed by @a@.
class AsmArg a (r :: RuntimeRep) (u :: TYPE r) | a -> r, a -> u where
  -- | Strip the box, giving the primitive value passed to the assembly.
  unbox :: a -> u
  -- | Wrap a primitive value produced by the assembly back into @a@.
  rebox :: u -> a

-- | Single-valued placeholder type. It is passed to assembly as the constant
-- @0#@, and any 'Int#' coming back in its position is discarded.
data Unit
  = Unit

instance AsmArg Unit 'IntRep Int# where
  -- Both directions ignore their argument without forcing it.
  unbox = \_ -> 0#
  rebox = \_ -> Unit

-- | 'Int' travels as the 'Int#' stored in its 'I#' constructor.
instance AsmArg Int 'IntRep Int# where
  unbox (I# n) = n
  rebox n = I# n

-- | 'Int8' travels as the 'Int#' stored in its 'I8#' constructor.
instance AsmArg Int8 'IntRep Int# where
  unbox (I8# n) = n
  rebox n = I8# n

-- | 'Int16' travels as the 'Int#' stored in its 'I16#' constructor.
instance AsmArg Int16 'IntRep Int# where
  unbox (I16# n) = n
  rebox n = I16# n

-- | 'Int32' travels as the 'Int#' stored in its 'I32#' constructor.
instance AsmArg Int32 'IntRep Int# where
  unbox (I32# n) = n
  rebox n = I32# n

-- 'Int64' shares 'Int#' when the machine word is wider than 32 bits;
-- otherwise it uses the dedicated 'Int64#' representation.
#if WORD_SIZE_IN_BITS > 32
instance AsmArg Int64 'IntRep Int# where
#else
instance AsmArg Int64 'Int64Rep Int64# where
#endif
  unbox (I64# n) = n
  rebox n = I64# n

-- | 'Word' travels as the 'Word#' stored in its 'W#' constructor.
instance AsmArg Word 'WordRep Word# where
  unbox (W# u) = u
  rebox u = W# u

-- | 'Word8' travels as the 'Word#' stored in its 'W8#' constructor.
instance AsmArg Word8 'WordRep Word# where
  unbox (W8# u) = u
  rebox u = W8# u

-- | 'Word16' travels as the 'Word#' stored in its 'W16#' constructor.
instance AsmArg Word16 'WordRep Word# where
  unbox (W16# u) = u
  rebox u = W16# u

-- | 'Word32' travels as the 'Word#' stored in its 'W32#' constructor.
instance AsmArg Word32 'WordRep Word# where
  unbox (W32# u) = u
  rebox u = W32# u

-- 'Word64' shares 'Word#' when the machine word is wider than 32 bits;
-- otherwise it uses the dedicated 'Word64#' representation.
#if WORD_SIZE_IN_BITS > 32
instance AsmArg Word64 'WordRep Word# where
#else
instance AsmArg Word64 'Word64Rep Word64# where
#endif
  unbox (W64# u) = u
  rebox u = W64# u

-- | 'Double' travels as the 'Double#' stored in its 'D#' constructor.
instance AsmArg Double 'DoubleRep Double# where
  unbox (D# x) = x
  rebox x = D# x

-- | 'Float' travels as the 'Float#' stored in its 'F#' constructor.
instance AsmArg Float 'FloatRep Float# where
  unbox (F# x) = x
  rebox x = F# x

-- | A 'Ptr' travels as its raw 'Addr#'.
instance AsmArg (Ptr a) 'AddrRep Addr# where
  unbox (Ptr addr) = addr
  rebox addr = Ptr addr

-- | @replace what with str@ replaces every non-overlapping occurrence of
-- @what@ in @str@ with @with@, scanning left to right.
--
-- An empty @what@ leaves @str@ unchanged. An empty string is a prefix of
-- every string, so without this guard the scan would never advance and
-- would loop forever on any non-empty input.
replace :: String -> String -> String -> String
replace [] _ = id
replace what with = go
  where
    -- Computed once rather than on every match.
    needleLen = length what

    go [] = []
    go str@(s:ss)
      | what `isPrefixOf` str = with <> go (drop needleLen str)
      | otherwise = s : go ss

-- | Shape of the generated Haskell wrapper.
data FunKind
  = Pure     -- ^ an ordinary function of the boxed arguments
  | Monadic  -- ^ an action in any 'PrimMonad', run via 'primitive'

-- | Shared implementation of 'defineAsmFun' and 'defineAsmFunM'.
--
-- Registers an assembly foreign source defining the global symbol
-- @<name>_unlifted@. Its body is the user's code with every @RET_HASK@
-- replaced by the return sequence, and one more return sequence appended.
-- Then returns four declarations:
--
-- * a @foreign import prim@ of that symbol at the unlifted type,
-- * the type signature of the wrapper @<name>@,
-- * the wrapper definition itself,
-- * an @INLINE@ pragma for the wrapper.
defineAsmFunImpl :: AsmCode tyAnn code => FunKind -> String -> tyAnn -> code -> Q [Dec]
defineAsmFunImpl kind name tyAnn asmCode = do
  addForeignSource LangAsm asmSource
  funTy <- toTypeQ tyAnn
  (importedTy, sigTy) <- case kind of
    Pure -> pure (funTy, funTy)
    Monadic -> do
      -- Unlifted (imported) type first, then the lifted wrapper type.
      unliftedTy <- stateifyUnlifted funTy
      liftedTy <- stateifyLifted funTy
      pure (unliftedTy, liftedTy)
  wrapperFunD <- mkFunD kind name importedName funTy
  let importD = ForeignD $ ImportF Prim Safe asmName importedName (unliftType importedTy)
      inlineD = PragmaD $ InlineP wrapperName Inline FunLike AllPhases
  pure [importD, SigD wrapperName sigTy, wrapperFunD, inlineD]
  where
    wrapperName = mkName name
    asmName = name <> "_unlifted"
    importedName = mkName asmName
    -- Jump through the word stored at (%rbp). x86-64 syntax; presumably
    -- this resumes the Haskell continuation on GHC's stack - confirm.
    retToHask = "jmp *(%rbp)"
    asmSource = unlines
      [ ".global " <> asmName
      , asmName <> ":"
      , replace "RET_HASK" retToHask (codeToString tyAnn asmCode)
      , retToHask
      ]

-- | Declare a pure Haskell function @name@ implemented by the given assembly
-- code. The result is a list of declarations to splice at top level.
defineAsmFun :: AsmCode tyAnn code => String -> tyAnn -> code -> Q [Dec]
defineAsmFun name tyAnn asmCode = defineAsmFunImpl Pure name tyAnn asmCode

-- | Like 'defineAsmFun', but the generated function returns its result in
-- any 'PrimMonad' instead of being pure.
defineAsmFunM :: AsmCode tyAnn code => String -> tyAnn -> code -> Q [Dec]
defineAsmFunM name tyAnn asmCode = defineAsmFunImpl Monadic name tyAnn asmCode

-- | Moves the result of the wrapped (boxed) function type into a
-- 'PrimMonad': given @Ty1 -> Ty2 -> Ret@ it produces
-- @forall m. PrimMonad m => Ty1 -> Ty2 -> m Ret@.
stateifyLifted :: Type -> Q Type
stateifyLifted ty = do
  monadVar <- newName "m"
  let monadT = VarT monadVar
      -- Walk down the arrow spine and apply @m@ to the final result.
      wrapResult (AppT (AppT ArrowT argTy) restTy) =
        AppT (AppT ArrowT argTy) <$> wrapResult restTy
      wrapResult resTy = pure (AppT monadT resTy)
  body <- wrapResult ty
  pure $ ForallT [PlainTV monadVar] [AppT (ConT ''PrimMonad) monadT] body

-- | Turns the unwrapped (unlifted) function type into a state-threading
-- primitive: given @Ty1# -> Ty2# -> Ret#@ it produces
-- @forall s. Ty1# -> Ty2# -> State# s -> (# State# s, Ret# #)@.
stateifyUnlifted :: Type -> Q Type
stateifyUnlifted ty = do
  stateVar <- newName "s"
  let stateT = VarT stateVar
      -- Walk down the arrow spine; replace the final result with a
      -- token-in / (token, result)-out function.
      threadState (AppT (AppT ArrowT argTy) restTy) =
        AppT (AppT ArrowT argTy) <$> threadState restTy
      threadState resTy =
        [t| State# $(pure stateT) -> (# State# $(pure stateT), $(pure resTy) #) |]
  ForallT [PlainTV stateVar] [] <$> threadState ty

-- | Builds the wrapper definition @funName arg1 ... argN = ...@ around the
-- imported unlifted symbol @importedName@, where @N@ is 'countArgs' of
-- @funTy@.
--
-- Each argument is 'unbox'ed before the call. A 'BS.ByteString' argument is
-- passed as two values: its address ('getBSAddr') and its length. The result
-- is 'rebox'ed. When 'detectRetTuple' finds a tuple result, the call returns
-- an unboxed tuple and each component is reboxed into a boxed tuple. For
-- 'Monadic' the call also threads a state token and the whole body is run
-- through 'primitive'.
mkFunD :: FunKind -> String -> Name -> Type -> Q Dec
mkFunD kind funName importedName funTy = do
  token <- newName "token"
  argNames <- replicateM (countArgs funTy) (newName "arg")
  callE <- foldM applyArg (VarE importedName) (zip (map VarE argNames) (getArgs funTy))
  fullCallE <- case kind of
    Pure -> pure callE
    Monadic -> [e| $(pure callE) $(pure (VarE token)) |]
  body <- maybe (reboxSingle fullCallE) (reboxTuple fullCallE) (detectRetTuple funTy)
  wrappedBody <- case kind of
    Pure -> pure body
    Monadic -> [e| primitive (\ $(pure (VarP token)) -> $(pure body)) |]
  pure $ FunD (mkName funName) [Clause (map VarP argNames) (NormalB wrappedBody) []]
  where
    -- Apply the call built so far to one more (unboxed) argument.
    applyArg :: Exp -> (Exp, Type) -> Q Exp
    applyArg acc (argE, argTy)
      | argTy == ConT ''BS.ByteString =
          [e| $(pure acc) (unbox $ getBSAddr $(pure argE)) (unbox $ BS.length $(pure argE)) |]
      | otherwise = [e| $(pure acc) (unbox $(pure argE)) |]

    -- The result is a single value.
    reboxSingle :: Exp -> Q Exp
    reboxSingle appE = case kind of
      Pure -> [e| rebox $(pure appE) |]
      Monadic -> [e| case $(pure appE) of
                       (# token', res #) -> (# token', rebox res #)
                   |]

    -- The result is an n-tuple, returned unboxed by the primitive.
    reboxTuple :: Exp -> Int -> Q Exp
    reboxTuple appE n = do
      retNames <- replicateM n (newName "ret")
      boxing <- forM retNames $ \retName -> Just <$> [e| rebox $(pure (VarE retName)) |]
      let tupP = UnboxedTupP (map VarP retNames)
          tupE = TupE boxing
      case kind of
        Pure -> [e| case $(pure appE) of
                      $(pure tupP) -> $(pure tupE)
                  |]
        Monadic -> [e| case $(pure appE) of
                         (# token', $(pure tupP) #) -> (# token', $(pure tupE) #)
                     |]

-- | Rewrites the boxed type of an asm function into the unlifted type used
-- for its @foreign import prim@:
--
-- * a 'BS.ByteString' argument becomes two arguments, @Addr# -> Int#@;
-- * @Ptr a@ becomes 'Addr#';
-- * 'Word', 'Word8' .. 'Word64' become 'Word#';
-- * 'Int', 'Int8' .. 'Int64' and 'Unit' become 'Int#';
-- * 'Double' becomes 'Double#' and 'Float' becomes 'Float#';
-- * boxed tuples become unboxed tuples.
--
-- The passes run in that order, each bottom-up over the whole type.
unliftType :: Type -> Type
unliftType = transformBi unliftTuple
           . transformBi unliftBaseTy
           . transformBi unliftPtrs
           . transformBi unliftBS
  where
    unliftBaseTy :: Name -> Name
    unliftBaseTy x
      | x `elem` [''Word, ''Word8, ''Word16, ''Word32, ''Word64] = ''Word#
      | x `elem` [''Int, ''Int8, ''Int16, ''Int32, ''Int64] = ''Int#
      | x == ''Double = ''Double#
      | x == ''Float = ''Float#
      | x == ''Unit = ''Int#
      | otherwise = x

    unliftPtrs :: Type -> Type
    unliftPtrs (AppT (ConT name) _) | name == ''Ptr = ConT ''Addr#
    unliftPtrs x = x

    -- The replacement type is built directly. A @[t| Addr# -> Int# -> r |]@
    -- quote would need 'unsafePerformIO' to escape 'Q' in this pure function.
    unliftBS :: Type -> Type
    unliftBS (AppT (AppT ArrowT (ConT bs)) rhs)
      | bs == ''BS.ByteString = primArrow ''Addr# (primArrow ''Int# rhs)
    unliftBS x = x

    -- @primArrow n r@ is the type @n -> r@.
    primArrow :: Name -> Type -> Type
    primArrow argName = AppT (AppT ArrowT (ConT argName))

    unliftTuple :: Type -> Type
    unliftTuple (TupleT n) = UnboxedTupleT n
    unliftTuple x = x

-- | Arity of the function's result if that result is a boxed tuple,
-- otherwise 'Nothing'. Follows the arrow spine to the result, then peels
-- type applications down to the head constructor.
detectRetTuple :: Type -> Maybe Int
detectRetTuple ty = case ty of
  AppT (AppT ArrowT _) resTy -> detectRetTuple resTy
  AppT headTy _ -> detectRetTuple headTy
  TupleT arity -> Just arity
  _ -> Nothing

-- | Argument types of a function type, in order.
--
-- Only the top-level arrow spine is inspected, looking through an outer
-- @forall@ and parentheses. Arrows nested inside an argument type (e.g. the
-- argument @Int -> Int@ in @(Int -> Int) -> Bool@) or inside kinds are not
-- mistaken for extra arguments.
getArgs :: Type -> [Type]
getArgs (ForallT _ _ ty) = getArgs ty
getArgs (ParensT ty) = getArgs ty
getArgs (AppT (AppT ArrowT argTy) rest) = argTy : getArgs rest
getArgs _ = []

-- | Number of arguments of a function type. Defined via 'getArgs' so the
-- two always agree; 'mkFunD' zips their results together.
countArgs :: Type -> Int
countArgs = length . getArgs