-- | Internal functions to invoke JNI methods
--
-- The functions in this module avoid using
-- 'Language.Java.Coercible' so they can be reused in interfaces which
-- use other ways to convert between Haskell and Java values.
--
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}

module Language.Java.Internal
  ( newJ
  , callToJValue
  , callStaticToJValue
  , getStaticFieldAsJValue
  , getClass
  , setGetClassFunction
  ) where

import Data.IORef
import Data.Singletons (SingI(..), SomeSing(..))
import Foreign.JNI hiding (throw)
import Foreign.JNI.Types
import qualified Foreign.JNI.String as JNI
import System.IO.Unsafe (unsafeDupablePerformIO, unsafePerformIO)

-- | Sets the function to use for loading classes.
--
-- 'findClass' is used by default.
--
-- The given function replaces the one stored in 'getClassFunctionRef', so
-- every later call to 'getClass' goes through it.
--
-- NOTE(review): the class lookups in 'newJ', 'callToJValue',
-- 'callStaticToJValue' and 'getStaticFieldAsJValue' are done inside
-- 'unsafeDupablePerformIO' and may be cached once forced; presumably this
-- function should be called before any of them runs — confirm.
setGetClassFunction
  :: (forall ty. IsReferenceType ty => Sing (ty :: JType) -> IO JClass)
  -> IO ()
setGetClassFunction f = writeIORef getClassFunctionRef $ GetClassFun f

-- | Yields a class reference. It behaves as 'findClass' unless
-- 'setGetClassFunction' is used.
--
-- The lookup function is read from 'getClassFunctionRef' on every call, so
-- a function installed with 'setGetClassFunction' takes effect immediately
-- for subsequent calls.
getClass :: IsReferenceType ty => Sing (ty :: JType) -> IO JClass
getClass s = readIORef getClassFunctionRef >>= \(GetClassFun f) -> f s

-- | Wrapper around a rank-2 class lookup function.
--
-- The polymorphic function is boxed in a newtype because it is stored as the
-- contents of an 'IORef' (see 'getClassFunctionRef'), and an 'IORef' cannot
-- hold a @forall@-type directly without impredicative instantiation.
newtype GetClassFun
  = GetClassFun
      (forall (ty :: JType). IsReferenceType ty => Sing ty -> IO JClass)

-- | Global mutable cell holding the class lookup function used by
-- 'getClass'. It initially holds 'findClass' composed with
-- 'referenceTypeName', and 'setGetClassFunction' overwrites it.
--
-- The NOINLINE pragma is required by the top-level 'unsafePerformIO' idiom:
-- without it GHC could inline the definition at use sites and allocate a
-- separate 'IORef' for each, so writes would not be seen by readers.
{-# NOINLINE getClassFunctionRef #-}
getClassFunctionRef :: IORef GetClassFun
getClassFunctionRef =
    unsafePerformIO $ newIORef (GetClassFun (findClass . referenceTypeName))

-- | Invokes a constructor of the Java class @sym@.
--
-- The constructor is selected through a JNI method signature built by
-- 'methodSignature' from the argument singletons and a @void@ return type.
-- The new object is cast to @J ty@ with 'unsafeCast'.
newJ
  :: forall sym ty.
     ( ty ~ 'Class sym
     , SingI ty
     )
  => [SomeSing JType] -- ^ Singletons of argument types
  -> [JValue]         -- ^ Constructor arguments
  -> IO (J ty)
{-# INLINE newJ #-}
newJ argsings args = do
    let voidsing = sing :: Sing 'Void
        -- The class is resolved outside the caller's IO sequence and
        -- promoted to a global reference, which is never released here.
        -- The local reference returned by 'getClass' is freed right away.
        -- NOTE(review): presumably this binding is meant to be floated out
        -- by GHC after inlining so the lookup happens once per call site;
        -- with 'unsafeDupablePerformIO' concurrent first uses could each
        -- create a global reference — confirm this is acceptable.
        klass = unsafeDupablePerformIO $ do
          lk <- getClass (sing :: Sing ('Class sym))
          gk <- newGlobalRef lk
          deleteLocalRef lk
          return gk
    unsafeCast <$> newObject klass (methodSignature argsings voidsing) args

-- | Invokes an instance method on an object and wraps the result in a
-- 'JValue'.
--
-- The method is looked up by name in the class designated by the type
-- @ty1@ (obtained with 'getClass'), using the signature built by
-- 'methodSignature' from the argument and return singletons. The return
-- singleton selects the JNI call used:
--
-- * each primitive type uses the matching @call<Type>Method@ and 'JValue'
--   constructor (booleans are converted to a 0/1 'JBoolean'),
-- * @void@ runs 'callVoidMethod' and returns a 'JValue' that must not be
--   inspected (it is an 'error' thunk),
-- * any other type is returned as a 'JObject'.
callToJValue
  :: forall ty1 k. (IsReferenceType ty1, SingI ty1)
  => Sing (k :: JType) -- ^ Singleton of the return type
  -> J ty1 -- ^ Any object
  -> JNI.String -- ^ Method name
  -> [SomeSing JType] -- ^ Singletons of argument types
  -> [JValue] -- ^ Arguments
  -> IO JValue
{-# INLINE callToJValue #-}
callToJValue retsing obj mname argsings args = do
    let -- Class lookup promoted to a global reference (never released
        -- here); the local reference from 'getClass' is freed immediately.
        klass = unsafeDupablePerformIO $ do
                  lk <- getClass (sing :: Sing ty1)
                  gk <- newGlobalRef lk
                  deleteLocalRef lk
                  return gk
        method = unsafeDupablePerformIO $
          getMethodID klass mname (methodSignature argsings retsing)
    case retsing of
      SPrim "boolean" -> JBoolean . fromIntegral . fromEnum <$>
                           callBooleanMethod obj method args
      SPrim "byte" -> JByte <$> callByteMethod obj method args
      SPrim "char" -> JChar <$> callCharMethod obj method args
      SPrim "short" -> JShort <$> callShortMethod obj method args
      SPrim "int" -> JInt <$> callIntMethod obj method args
      SPrim "long" -> JLong <$> callLongMethod obj method args
      SPrim "float" -> JFloat <$> callFloatMethod obj method args
      SPrim "double" -> JDouble <$> callDoubleMethod obj method args
      SVoid -> do
        callVoidMethod obj method args
        -- The void result is not inspected.
        return (error "inspected output of method returning void")
      _ -> JObject <$> callObjectMethod obj method args

-- | Invokes a static method and wraps the result in a 'JValue'.
--
-- The class is looked up by name with 'getClass', and the method by name and
-- by the signature built by 'methodSignature' from the argument and return
-- singletons. The return singleton selects the JNI call used:
--
-- * each primitive type uses the matching @callStatic<Type>Method@ and
--   'JValue' constructor (booleans are converted to a 0/1 'JBoolean'),
-- * @void@ runs 'callStaticVoidMethod' and returns a 'JValue' that must not
--   be inspected (it is an 'error' thunk),
-- * any other type is returned as a 'JObject'.
callStaticToJValue
  :: Sing (k :: JType) -- ^ Singleton of the return type
  -> JNI.String -- ^ Class name
  -> JNI.String -- ^ Method name
  -> [SomeSing JType] -- ^ Singletons of argument types
  -> [JValue] -- ^ Arguments
  -> IO JValue
{-# INLINE callStaticToJValue #-}
callStaticToJValue retsing cname mname argsings args = do
    let -- Class lookup promoted to a global reference (never released
        -- here); the local reference from 'getClass' is freed immediately.
        klass = unsafeDupablePerformIO $ do
                  lk <- getClass (SClass (JNI.toChars cname))
                  gk <- newGlobalRef lk
                  deleteLocalRef lk
                  return gk
        method = unsafeDupablePerformIO $
          getStaticMethodID klass mname (methodSignature argsings retsing)
    case retsing of
      SPrim "boolean" -> JBoolean . fromIntegral . fromEnum <$>
                           callStaticBooleanMethod klass method args
      SPrim "byte" -> JByte <$> callStaticByteMethod klass method args
      SPrim "char" -> JChar <$> callStaticCharMethod klass method args
      SPrim "short" -> JShort <$> callStaticShortMethod klass method args
      SPrim "int" -> JInt <$> callStaticIntMethod klass method args
      SPrim "long" -> JLong <$> callStaticLongMethod klass method args
      SPrim "float" -> JFloat <$> callStaticFloatMethod klass method args
      SPrim "double" -> JDouble <$> callStaticDoubleMethod klass method args
      SVoid -> do
        callStaticVoidMethod klass method args
        -- The void result is not inspected.
        return (error "inspected output of method returning void")
      _ -> JObject <$> callStaticObjectMethod klass method args

-- | Reads a static field and wraps its value in a 'JValue'.
--
-- The class is looked up by name with 'getClass', and the field by name and
-- by the type signature of the given singleton. Primitive types use the
-- matching @getStatic<Type>Field@ and 'JValue' constructor; any other
-- reference type is returned as a 'JObject'. Asking for a field of type
-- @void@ fails in 'IO'.
getStaticFieldAsJValue
  :: Sing (ty :: JType) -- ^ Singleton of the field type
  -> JNI.String -- ^ Class name
  -> JNI.String -- ^ Static field name
  -> IO JValue
{-# INLINE getStaticFieldAsJValue #-}
getStaticFieldAsJValue retsing cname fname = do
  let -- Class lookup promoted to a global reference (never released
      -- here); the local reference from 'getClass' is freed immediately.
      klass = unsafeDupablePerformIO $ do
                lk <- getClass (SClass (JNI.toChars cname))
                gk <- newGlobalRef lk
                deleteLocalRef lk
                return gk
      field = unsafeDupablePerformIO $
        getStaticFieldID klass fname (signature retsing)
  case retsing of
    SPrim "boolean" -> JBoolean <$> getStaticBooleanField klass field
    SPrim "byte" -> JByte <$> getStaticByteField klass field
    SPrim "char" -> JChar <$> getStaticCharField klass field
    SPrim "short" -> JShort <$> getStaticShortField klass field
    SPrim "int" -> JInt <$> getStaticIntField klass field
    SPrim "long" -> JLong <$> getStaticLongField klass field
    SPrim "float" -> JFloat <$> getStaticFloatField klass field
    SPrim "double" -> JDouble <$> getStaticDoubleField klass field
    SVoid -> fail "getStaticField cannot yield an object of type void"
    _ -> JObject <$> getStaticObjectField klass field