-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ForeignFunctionInterface #-}

module TensorFlow.Internal.Raw where

#include "third_party/tensorflow/c/c_api.h"

import Foreign
import Foreign.C

-- Haskell mirrors of the C enums; values are converted to/from C integers
-- at each call site with 'fromEnum' / 'toEnum'.
{#enum TF_DataType as DataType {} deriving (Show, Eq) #}
{#enum TF_Code as Code {} deriving (Show, Eq) #}


-- Status.

{#pointer *TF_Status as Status newtype #}

-- | Allocate a status object with @TF_NewStatus@.
newStatus :: IO Status
newStatus = {# call TF_NewStatus as ^ #}

-- | Free a status object with @TF_DeleteStatus@.
deleteStatus :: Status -> IO ()
deleteStatus status = {# call TF_DeleteStatus as ^ #} status

-- | Store a code and message in a status with @TF_SetStatus@.
-- The 'Code' is passed to C as its integral enum value.
setStatus :: Status -> Code -> CString -> IO ()
setStatus status code msg =
    {# call TF_SetStatus as ^ #} status (fromIntegral (fromEnum code)) msg

-- | Read the code stored in a status (@TF_GetCode@), decoded back into 'Code'.
getCode :: Status -> IO Code
getCode status =
    fmap (toEnum . fromIntegral) ({# call TF_GetCode as ^ #} status)

-- | The message stored in a status, as returned by @TF_Message@.
message :: Status -> IO CString
message status = {# call TF_Message as ^ #} status


-- Buffer.

-- | Opaque tag type for the @TF_Buffer@ struct; only used behind 'BufferPtr'.
data Buffer
{#pointer *TF_Buffer as BufferPtr -> Buffer #}

-- | Read the @data@ field of a @TF_Buffer@.
getBufferData :: BufferPtr -> IO (Ptr ())
getBufferData buf = {#get TF_Buffer->data #} buf

-- | Read the @length@ field of a @TF_Buffer@.
getBufferLength :: BufferPtr -> IO CULong
getBufferLength buf = {#get TF_Buffer->length #} buf


-- Tensor.

{#pointer *TF_Tensor as Tensor newtype #}

-- 'Tensor' is a newtype over a pointer, so it is stored, sized and aligned
-- exactly like the wrapped pointer; peek/poke just cast the address.
instance Storable Tensor where
    sizeOf (Tensor raw) = sizeOf raw
    alignment (Tensor raw) = alignment raw
    peek addr = Tensor <$> peek (castPtr addr)
    poke addr (Tensor raw) = poke (castPtr addr) raw

-- A synonym for the int64_t type, which is used in the TensorFlow API.
-- On some platforms it's `long`; on others (e.g., Mac OS X) it's `long long`;
-- and as far as Haskell is concerned, those are distinct types (`CLong` vs
-- `CLLong`).
type CInt64 = {#type int64_t #}

-- | Build a tensor over an existing data buffer with @TF_NewTensor@.
-- The 'DataType' is converted to its integral enum value; every other
-- argument is forwarded to the C function unchanged.
newTensor :: DataType
          -> Ptr CInt64                                    -- dimensions array
          -> CInt                                          -- num dimensions
          -> Ptr ()                                        -- data
          -> CULong                                        -- data len
          -> FunPtr (Ptr () -> CULong -> Ptr () -> IO ())  -- deallocator
          -> Ptr ()                                        -- deallocator arg
          -> IO Tensor
newTensor dtype = {# call TF_NewTensor as ^ #} (fromIntegral (fromEnum dtype))

-- | Free a tensor with @TF_DeleteTensor@.
deleteTensor :: Tensor -> IO ()
deleteTensor tensor = {# call TF_DeleteTensor as ^ #} tensor

-- | Element type reported by @TF_TensorType@, decoded into 'DataType'.
tensorType :: Tensor -> IO DataType
tensorType tensor =
    fmap (toEnum . fromIntegral) ({# call TF_TensorType as ^ #} tensor)

-- | Thin wrapper over @TF_NumDims@.
numDims :: Tensor -> IO CInt
numDims tensor = {# call TF_NumDims as ^ #} tensor

-- | Thin wrapper over @TF_Dim@ (query dimension number @axis@).
dim :: Tensor -> CInt -> IO CInt64
dim tensor axis = {# call TF_Dim as ^ #} tensor axis

-- | Thin wrapper over @TF_TensorByteSize@.
tensorByteSize :: Tensor -> IO CULong
tensorByteSize tensor = {# call TF_TensorByteSize as ^ #} tensor

-- | Thin wrapper over @TF_TensorData@.
tensorData :: Tensor -> IO (Ptr ())
tensorData tensor = {# call TF_TensorData as ^ #} tensor


-- Session Options.

{# pointer *TF_SessionOptions as SessionOptions newtype #}

-- | Allocate session options with @TF_NewSessionOptions@.
newSessionOptions :: IO SessionOptions
newSessionOptions = {# call TF_NewSessionOptions as ^ #}

-- | Thin wrapper over @TF_SetTarget@.
setTarget :: SessionOptions -> CString -> IO ()
setTarget opts target = {# call TF_SetTarget as ^ #} opts target

-- | Thin wrapper over @TF_SetConfig@: hands a raw (pointer, length) byte
-- buffer to the options; the 'Status' is passed through to the C call.
setConfig :: SessionOptions -> Ptr () -> CULong -> Status -> IO ()
setConfig opts buf len status =
    {# call TF_SetConfig as ^ #} opts buf len status

-- | Free session options with @TF_DeleteSessionOptions@.
deleteSessionOptions :: SessionOptions -> IO ()
deleteSessionOptions opts = {# call TF_DeleteSessionOptions as ^ #} opts


-- Session.
{# pointer *TF_DeprecatedSession as Session newtype #}

-- | Thin wrapper over @TF_NewDeprecatedSession@.
newSession :: SessionOptions -> Status -> IO Session
newSession opts status = {# call TF_NewDeprecatedSession as ^ #} opts status

-- | Thin wrapper over @TF_CloseDeprecatedSession@.
closeSession :: Session -> Status -> IO ()
closeSession session status =
    {# call TF_CloseDeprecatedSession as ^ #} session status

-- | Thin wrapper over @TF_DeleteDeprecatedSession@.
deleteSession :: Session -> Status -> IO ()
deleteSession session status =
    {# call TF_DeleteDeprecatedSession as ^ #} session status

-- | Thin wrapper over @TF_ExtendGraph@: passes a raw (pointer, length)
-- byte buffer to the session.
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
extendGraph session buf len status =
    {# call TF_ExtendGraph as ^ #} session buf len status

-- | Thin wrapper over @TF_Run@. Each group of inputs, outputs and targets
-- is passed as parallel C arrays together with an element count.
run :: Session
    -> BufferPtr                          -- RunOptions proto.
    -> Ptr CString -> Ptr Tensor -> CInt  -- Input (names, tensors, count).
    -> Ptr CString -> Ptr Tensor -> CInt  -- Output (names, tensors, count).
    -> Ptr CString -> CInt                -- Target nodes (names, count).
    -> BufferPtr                          -- RunMetadata proto.
    -> Status
    -> IO ()
run = {# call TF_Run as ^ #}


-- FFI helpers.

-- | Haskell-side shape of the deallocator callback accepted by 'newTensor'.
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()

-- | Turn a Haskell 'TensorDeallocFn' into a C function pointer.
-- NOTE: per the Haskell FFI, pointers created by a "wrapper" import are not
-- reclaimed automatically; release them with 'freeHaskellFunPtr' when done.
foreign import ccall "wrapper"
    wrapTensorDealloc :: TensorDeallocFn -> IO (FunPtr TensorDeallocFn)

-- | Get the OpList of all OpDefs defined in this address space.
-- Returns a BufferPtr, ownership of which is transferred to the caller
-- (release it through 'deleteBuffer').
--
-- The data in the buffer will be the serialized OpList proto for ops registered
-- in this address space.
getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #}

-- | Address of the C function @TF_DeleteBuffer@. This is a 'FunPtr', not a
-- directly callable action: use it e.g. as a finalizer with 'newForeignPtr'.
foreign import ccall "&TF_DeleteBuffer"
    deleteBuffer :: FunPtr (BufferPtr -> IO ())