{-# LANGUAGE MagicHash #-}
{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances, FunctionalDependencies #-}
{-# LANGUAGE DataKinds, PolyKinds, TypeFamilies #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE CPP #-}

module Language.Asm.Inline(defineAsmFun) where

import qualified Data.ByteString as BS
import Control.Monad
import Data.Generics.Uniplate.Data
import Data.List
import Foreign.Ptr
import GHC.Prim
import GHC.Ptr
import GHC.Types hiding (Type)
import GHC.Word
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import System.IO.Unsafe

import Language.Asm.Inline.AsmCode
import Language.Asm.Inline.Util

-- | Two-way conversion between a boxed Haskell type and the unboxed value
-- exchanged with a @foreign import prim@ call.
--
-- The generated wrappers apply @unbox@ to arguments and @rebox@ to results.
-- Both the runtime representation @rep@ and the unboxed type @raw@ are
-- fixed by the boxed type through the functional dependencies.
class AsmArg boxed (rep :: RuntimeRep) (raw :: TYPE rep) | boxed -> rep, boxed -> raw where
  -- | Extract the unboxed payload of a boxed value.
  unbox :: boxed -> raw
  -- | Rebuild the boxed value from its unboxed payload.
  rebox :: raw -> boxed

-- | @Int@ is exchanged as a raw @Int#@ (IntRep).
instance AsmArg Int 'IntRep Int# where
  unbox (I# n) = n
  rebox n = I# n

-- | @Word@ is exchanged as a raw @Word#@ (WordRep).
instance AsmArg Word 'WordRep Word# where
  unbox (W# raw) = raw
  rebox raw = W# raw

-- | @Word8@ is exchanged as a full machine word (@Word#@, WordRep), matching
-- the @Word8@ to @Word#@ mapping used when the foreign type is unlifted.
--
-- Starting with base 4.16 (GHC 9.2) the @W8#@ constructor holds a @Word8#@
-- rather than a @Word#@, so the payload is widened with @word8ToWord#@ on
-- the way in and narrowed with @wordToWord8#@ on the way out. Older bases
-- store a @Word#@ directly and need no conversion.
instance AsmArg Word8 'WordRep Word# where
#if MIN_VERSION_base(4, 16, 0)
  unbox (W8# w) = word8ToWord# w
  rebox w = W8# (wordToWord8# w)
#else
  unbox (W8# w) = w
  rebox = W8#
#endif

-- | @Double@ is exchanged as a raw @Double#@ (DoubleRep).
instance AsmArg Double 'DoubleRep Double# where
  rebox x = D# x
  unbox (D# x) = x

-- | @Float@ is exchanged as a raw @Float#@ (FloatRep).
instance AsmArg Float 'FloatRep Float# where
  rebox x = F# x
  unbox (F# x) = x

-- | A @Ptr a@ is exchanged as its bare @Addr#@ (AddrRep); the type
-- parameter @a@ has no runtime counterpart in the address.
instance AsmArg (Ptr a) 'AddrRep Addr# where
  unbox (Ptr addr) = addr
  rebox addr = Ptr addr

-- | @replace what with str@ replaces every occurrence of @what@ in @str@ by
-- @with@. Occurrences are found left to right and do not overlap.
--
-- An empty @what@ leaves @str@ unchanged. Without that guard the scan never
-- terminates: the empty string is a prefix of every suffix, and dropping
-- zero characters never advances the input.
replace :: String -> String -> String -> String
replace [] _ = id
replace what with = go
  where
    -- Length of the pattern, computed once instead of on every match.
    patLen = length what

    go [] = []
    go str@(c:rest)
      | what `isPrefixOf` str = with <> go (drop patLen str)
      | otherwise = c : go rest

-- | Emit the assembly for @asmCode@ as a foreign source file and declare a
-- Haskell function named @name@ that calls it.
--
-- The generated declarations are, in order:
--
--   * a @foreign import prim@ of the symbol @name <> "_unlifted"@, typed
--     with the unlifted form of the annotated type;
--   * a type signature for @name@ using the annotated (lifted) type;
--   * the wrapper definition built by @mkFunD@;
--   * an INLINE pragma for the wrapper.
--
-- The assembly text is the user code with every @RET_HASK@ placeholder
-- replaced by @jmp *(%rbp)@, and the same instruction is appended after it.
-- NOTE(review): the AT&T-syntax @%rbp@ register ties this to x86-64.
defineAsmFun :: AsmCode tyAnn code => String -> tyAnn -> code -> Q [Dec]
defineAsmFun name tyAnn asmCode = do
  addForeignSource LangAsm asmText
  liftedTy <- toTypeQ tyAnn
  wrapperDec <- mkFunD name primName liftedTy
  let primImport = ImportF Prim Safe symbol primName (unliftType liftedTy)
      inlineWrapper = InlineP wrapperName Inline ConLike AllPhases
  pure [ ForeignD primImport
       , SigD wrapperName liftedTy
       , wrapperDec
       , PragmaD inlineWrapper
       ]
  where
    symbol = name <> "_unlifted"
    primName = mkName symbol
    wrapperName = mkName name
    returnToHaskell = "jmp *(%rbp)"
    asmText = unlines
      [ ".global " <> symbol
      , symbol <> ":"
      , replace "RET_HASK" returnToHaskell (codeToString tyAnn asmCode)
      , returnToHaskell
      ]

-- | Build the wrapper definition @funName arg1 ... argN = ...@ for the
-- function type @funTy@, where the body calls the prim import
-- @importedName@.
--
-- One fresh argument is bound per argument counted by @countArgs@. Each
-- argument is passed through @unbox@. A @ByteString@ argument is instead
-- passed as two values: the unboxed result of @getBSAddr@ and the unboxed
-- result of @BS.length@.
--
-- The call result goes through @rebox@. When @detectRetTuple@ reports a
-- tuple result of arity @n@, the call is matched against an unboxed
-- @n@-tuple and every component is reboxed into an ordinary tuple.
mkFunD :: String -> Name -> Type -> Q Dec
mkFunD funName importedName funTy = do
  params <- replicateM (countArgs funTy) (newName "arg")
  call <- foldM passArg (VarE importedName) (zip (VarE <$> params) (getArgs funTy))
  body <- maybe (reboxSingle call) (reboxTuple call) (detectRetTuple funTy)
  pure $ FunD (mkName funName) [Clause (VarP <$> params) (NormalB body) []]
  where
    -- Apply the call built so far to the unboxed form of one more argument.
    passArg :: Exp -> (Exp, Type) -> Q Exp
    passArg acc (argE, argTy)
      | argTy == ConT ''BS.ByteString =
          [e| $(pure acc)
                (unbox $ getBSAddr $(pure argE))
                (unbox $ BS.length $(pure argE)) |]
      | otherwise = [e| $(pure acc) (unbox $(pure argE)) |]

    -- Non-tuple result: rebox the whole call.
    reboxSingle :: Exp -> Q Exp
    reboxSingle call = [e| rebox $(pure call) |]

    -- Tuple result of the given arity: destructure the unboxed tuple and
    -- rebox each field.
    reboxTuple :: Exp -> Int -> Q Exp
    reboxTuple call arity = do
      rets <- replicateM arity (newName "ret")
#if MIN_VERSION_template_haskell(2, 16, 0)
      fields <- forM rets $ \r -> Just <$> [e| rebox $(pure (VarE r)) |]
#else
      fields <- forM rets $ \r -> [e| rebox $(pure (VarE r)) |]
#endif
      [e| case $(pure call) of
            $(pure (UnboxedTupP (VarP <$> rets))) -> $(pure (TupE fields))
        |]

-- | Turn the lifted, user-facing function type into the type used for the
-- @foreign import prim@ declaration.
--
-- The rewrites are applied in this order, each over the whole type:
--
--   1. a @ByteString@ argument becomes two arguments, @Addr#@ then @Int#@;
--   2. any @Ptr t@ becomes @Addr#@;
--   3. every occurrence of the names @Word@ and @Word8@ becomes @Word#@,
--      @Int@ becomes @Int#@, @Double@ becomes @Double#@ and @Float@
--      becomes @Float#@;
--   4. every boxed tuple type constructor becomes the unboxed tuple
--      constructor of the same arity.
unliftType :: Type -> Type
unliftType = transformBi unliftTuple
           . transformBi unliftBaseTy
           . transformBi unliftPtrs
           . transformBi unliftBS
  where
    unliftBaseTy :: Name -> Name
    unliftBaseTy x
      | x == ''Word   = ''Word#
      | x == ''Word8  = ''Word#
      | x == ''Int    = ''Int#
      | x == ''Double = ''Double#
      | x == ''Float  = ''Float#
      | otherwise     = x

    unliftPtrs :: Type -> Type
    unliftPtrs (AppT (ConT name) _) | name == ''Ptr = ConT ''Addr#
    unliftPtrs x = x

    -- The replacement is built straight from TH constructors. The type is
    -- closed and pure, so no quote, runQ or unsafePerformIO is needed.
    unliftBS :: Type -> Type
    unliftBS (AppT (AppT ArrowT (ConT bs)) rhs)
      | bs == ''BS.ByteString = ConT ''Addr# `arrow` (ConT ''Int# `arrow` rhs)
    unliftBS x = x

    unliftTuple :: Type -> Type
    unliftTuple (TupleT n) = UnboxedTupleT n
    unliftTuple x = x

    -- @arrow a b@ is the function type from @a@ to @b@.
    arrow :: Type -> Type -> Type
    arrow a b = AppT (AppT ArrowT a) b

-- | Arity of the tuple at the head of a function type's final result, if
-- any.
--
-- Arrows are followed into their result. Any other type application is
-- followed into its head, so a result such as @AppT (AppT (TupleT 2) a) b@
-- yields @Just 2@. Every other shape yields @Nothing@.
detectRetTuple :: Type -> Maybe Int
detectRetTuple ty = case ty of
  AppT (AppT ArrowT _) result -> detectRetTuple result
  AppT headTy _               -> detectRetTuple headTy
  TupleT arity                -> Just arity
  _                           -> Nothing

-- | Collect the left operand of every @->@ occurring anywhere in the type,
-- in the pre-order produced by @universeBi@. For a plain curried
-- first-order function type, this is its argument types in order.
getArgs :: Type -> [Type]
getArgs = concatMap argumentOf . universeBi
  where
    argumentOf :: Type -> [Type]
    argumentOf (AppT ArrowT argTy) = [argTy]
    argumentOf _                   = []

-- | Number of arguments of a function type.
--
-- This is defined as the length of @getArgs@. The fresh argument names that
-- @mkFunD@ creates from this count are zipped with @getArgs@, so deriving
-- one from the other keeps names and argument types in one-to-one
-- correspondence. The old version counted bare @ArrowT@ nodes separately,
-- which could disagree, for example with an unapplied @(->)@.
countArgs :: Type -> Int
countArgs = length . getArgs