{-# LINE 1 "System/Modbus.hsc" #-}
{- | Haskell bindings to the C modbus library https://libmodbus.org/ -}

{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleInstances #-}

module System.Modbus (
        -- * Equivalence to the C library
        -- | Functions in this module are named the same as those in the C
        -- library, but without the leading "modbus_". You may wish to import
        -- this module qualified as Modbus to make the names match up.
        --
        -- See the C library documentation for details about the use
        -- of any function. https://libmodbus.org/documentation/
        --
        -- When a function in the C library returns a special value on
        -- error, this module will instead throw an exception.
        --
        -- This module has been tested with version 3.1.4 of the C library.
        -- It may also work with other versions.

        -- * Example
        -- 
        -- | This example reads some of the registers of an Epever solar
        -- charge controller. It shows how `Data.Binary` can be used
        -- to decode the modbus registers into a Haskell data structure.
        -- 
        -- > import System.Modbus
        -- > import Data.Binary.Get
        -- > 
        -- > main = do
        -- > 	mb <- new_rtu "/dev/ttyS1" (Baud 115200) ParityNone (DataBits 8) (StopBits 1)
        -- > 	set_slave mb (DeviceAddress 1)
        -- > 	connect mb
        -- > 	regs <- mkRegisterVector 5
        -- > 	b <- read_input_registers mb (Addr 0x3100) regs
        -- > 	print $ runGet getEpever b
        -- > 
        -- > data Epever = Epever
        -- > 	{ pv_array_voltage :: Float
        -- > 	, pv_array_current :: Float
        -- > 	, pv_array_power :: Float
        -- > 	, battery_voltage :: Float
        -- > 	} deriving (Show)
        -- > 
        -- > getEpever :: Get Epever
        -- > getEpever = Epever
        -- >	<$> epeverfloat  -- register 0x3100
        -- > 	<*> epeverfloat  -- register 0x3101
        -- >	<*> epeverfloat2 -- register 0x3102 (low) and 0x3103 (high)
        -- > 	<*> epeverfloat  -- register 0x3104
        -- >  where
        -- >	 epeverfloat = decimals 2 <$> getWord16host
        -- >	 epeverfloat2 = do
        -- >	 	l <- getWord16host
        -- >	 	h <- getWord16host
        -- >	 	return (decimals 2 (fromIntegral l + fromIntegral h * 2^16 :: Integer))
        -- > 	 decimals n v = fromIntegral v / (10^n)

        -- * Core data types
        Context,
        Addr(..),

        -- * RTU Context
        Baud(..),
        Parity(..),
        DataBits(..),
        StopBits(..),
        new_rtu,
        SerialMode(..),
        rtu_get_serial_mode,
        rtu_set_serial_mode,
        RTS(..),
        rtu_get_rts,
        rtu_set_rts,
        rtu_get_rts_delay,
        rtu_set_rts_delay,

        -- * TCP (IPv4) Context
        IPAddress(..),
        Port(..),
        new_tcp,

        -- * TCP PI (IPv4 and IPv6) Context
        Node(..),
        Service(..),
        new_tcp_pi,

        -- * Configuration
        DeviceAddress(..),
        broadcastAddress,
        set_slave,
        connect,
        set_debug,
        Timeout(..),
        get_byte_timeout,
        set_byte_timeout,
        get_response_timeout,
        set_response_timeout,

        -- * Accessing registers
        RegisterVector,
        mkRegisterVector,
        RegisterData(..),
        read_registers,
        read_input_registers,
        write_registers,
        write_register,
        write_and_read_registers,

        -- * Accessing bits/coils
        BitVector,
        mkBitVector,
        BitData(..),
        Bit,
        boolBit,
        bitBool,
        read_bits,
        read_input_bits,
        write_bits,
        write_bit,
) where



import Foreign
import Foreign.C
import Data.Char
import Data.Default
import qualified Data.Vector.Storable.Mutable as VM
import qualified Data.Vector.Storable as V
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import GHC.IO.Exception

-- Raw bindings to libmodbus.
--
-- The C functions take and return C @int@. On 64-bit platforms a Haskell
-- 'Int' is wider than a C @int@, and a C @int@ result read back as an
-- 'Int' is not guaranteed to be sign-extended. An error return of -1 can
-- therefore arrive as a large positive number and slip past the @== -1@
-- checks in this module.
--
-- For that reason the C functions are imported with 'CInt' under a @c_@
-- prefix. The @modbus_*@ names used by the rest of the module are thin
-- wrappers that convert between 'CInt' and 'Int'.
--
-- Calls that talk to the device (connecting, reading, writing) are
-- imported @safe@. An @unsafe@ call blocks its capability and can stall
-- other Haskell threads (e.g. at the next garbage collection) until it
-- returns, which may take as long as a modbus timeout.

foreign import ccall unsafe "modbus.h &modbus_close" modbus_close
        :: FunPtr (Ptr () -> IO ())

foreign import ccall unsafe "modbus.h &modbus_free" modbus_free
        :: FunPtr (Ptr () -> IO ())

-- | Widen the C @int@ result of a libmodbus call to an 'Int'.
intResult :: IO CInt -> IO Int
intResult = fmap fromIntegral

-- | Narrow an 'Int' argument to the C @int@ that libmodbus expects.
-- Values outside the range of 'CInt' are truncated.
toCInt :: Int -> CInt
toCInt = fromIntegral

-- | Adapt a libmodbus call of the shape @(ctx)@ returning @int@.
ctxCall :: (Ptr () -> IO CInt) -> Ptr () -> IO Int
ctxCall f ctx = intResult (f ctx)

-- | Adapt a libmodbus call of the shape @(ctx, int)@ returning @int@.
intCall :: (Ptr () -> CInt -> IO CInt) -> Ptr () -> Int -> IO Int
intCall f ctx n = intResult $ f ctx (toCInt n)

-- | Adapt a libmodbus call of the shape @(ctx, addr, nb, buffer)@
-- returning @int@.
blockCall
        :: (Ptr () -> CInt -> CInt -> Ptr t -> IO CInt)
        -> Ptr () -> Int -> Int -> Ptr t -> IO Int
blockCall f ctx addr nb buf = intResult $ f ctx (toCInt addr) (toCInt nb) buf

-- | Present a timeout setter that takes seconds and microseconds by value
-- with the same pointer-based shape as the timeout getters, by reading
-- the values out of the pointers before calling it.
timeoutSetter
        :: (Ptr () -> Word32 -> Word32 -> IO CInt)
        -> Ptr () -> Ptr Word32 -> Ptr Word32 -> IO Int
timeoutSetter f ctx secp usecp = do
        sec <- peek secp
        usec <- peek usecp
        intResult $ f ctx sec usec

foreign import ccall unsafe "modbus.h modbus_new_rtu" c_modbus_new_rtu
        :: CString -> CInt -> CChar -> CInt -> CInt -> IO (Ptr ())

modbus_new_rtu :: CString -> Int -> CChar -> Int -> Int -> IO (Ptr ())
modbus_new_rtu dev baud parity databits stopbits =
        c_modbus_new_rtu dev (toCInt baud) parity (toCInt databits) (toCInt stopbits)

foreign import ccall unsafe "modbus.h modbus_rtu_get_serial_mode" c_modbus_rtu_get_serial_mode
        :: Ptr () -> IO CInt

modbus_rtu_get_serial_mode :: Ptr () -> IO Int
modbus_rtu_get_serial_mode = ctxCall c_modbus_rtu_get_serial_mode

foreign import ccall unsafe "modbus.h modbus_rtu_set_serial_mode" c_modbus_rtu_set_serial_mode
        :: Ptr () -> CInt -> IO CInt

modbus_rtu_set_serial_mode :: Ptr () -> Int -> IO Int
modbus_rtu_set_serial_mode = intCall c_modbus_rtu_set_serial_mode

foreign import ccall unsafe "modbus.h modbus_rtu_get_rts" c_modbus_rtu_get_rts
        :: Ptr () -> IO CInt

modbus_rtu_get_rts :: Ptr () -> IO Int
modbus_rtu_get_rts = ctxCall c_modbus_rtu_get_rts

foreign import ccall unsafe "modbus.h modbus_rtu_set_rts" c_modbus_rtu_set_rts
        :: Ptr () -> CInt -> IO CInt

modbus_rtu_set_rts :: Ptr () -> Int -> IO Int
modbus_rtu_set_rts = intCall c_modbus_rtu_set_rts

foreign import ccall unsafe "modbus.h modbus_rtu_get_rts_delay" c_modbus_rtu_get_rts_delay
        :: Ptr () -> IO CInt

modbus_rtu_get_rts_delay :: Ptr () -> IO Int
modbus_rtu_get_rts_delay = ctxCall c_modbus_rtu_get_rts_delay

foreign import ccall unsafe "modbus.h modbus_rtu_set_rts_delay" c_modbus_rtu_set_rts_delay
        :: Ptr () -> CInt -> IO CInt

modbus_rtu_set_rts_delay :: Ptr () -> Int -> IO Int
modbus_rtu_set_rts_delay = intCall c_modbus_rtu_set_rts_delay

foreign import ccall unsafe "modbus.h modbus_new_tcp" c_modbus_new_tcp
        :: CString -> CInt -> IO (Ptr ())

modbus_new_tcp :: CString -> Int -> IO (Ptr ())
modbus_new_tcp ip port = c_modbus_new_tcp ip (toCInt port)

-- This must bind the modbus_new_tcp_pi symbol: modbus_new_tcp takes an
-- int port, not a service string.
foreign import ccall unsafe "modbus.h modbus_new_tcp_pi" modbus_new_tcp_pi
        :: CString -> CString -> IO (Ptr ())

foreign import ccall unsafe "modbus.h modbus_set_slave" c_modbus_set_slave
        :: Ptr () -> CInt -> IO CInt

modbus_set_slave :: Ptr () -> Int -> IO Int
modbus_set_slave = intCall c_modbus_set_slave

-- Connecting may block on the network or serial device, hence safe.
foreign import ccall safe "modbus.h modbus_connect" c_modbus_connect
        :: Ptr () -> IO CInt

modbus_connect :: Ptr () -> IO Int
modbus_connect = ctxCall c_modbus_connect

foreign import ccall unsafe "modbus.h modbus_set_debug" c_modbus_set_debug
        :: Ptr () -> CInt -> IO CInt

modbus_set_debug :: Ptr () -> Int -> IO Int
modbus_set_debug = intCall c_modbus_set_debug

-- The timeout getters fill in seconds and microseconds through
-- out-pointers.
foreign import ccall unsafe "modbus.h modbus_get_byte_timeout" c_modbus_get_byte_timeout
        :: Ptr () -> Ptr Word32 -> Ptr Word32 -> IO CInt

modbus_get_byte_timeout :: Ptr () -> Ptr Word32 -> Ptr Word32 -> IO Int
modbus_get_byte_timeout ctx secp usecp =
        intResult $ c_modbus_get_byte_timeout ctx secp usecp

foreign import ccall unsafe "modbus.h modbus_get_response_timeout" c_modbus_get_response_timeout
        :: Ptr () -> Ptr Word32 -> Ptr Word32 -> IO CInt

modbus_get_response_timeout :: Ptr () -> Ptr Word32 -> Ptr Word32 -> IO Int
modbus_get_response_timeout ctx secp usecp =
        intResult $ c_modbus_get_response_timeout ctx secp usecp

-- In libmodbus 3.1 the timeout setters take seconds and microseconds by
-- value, unlike the getters. Declaring them with pointer arguments made
-- the C side receive pointer addresses as the timeout. The Haskell-side
-- wrappers keep the pointer-based shape and dereference before the call.
foreign import ccall unsafe "modbus.h modbus_set_byte_timeout" c_modbus_set_byte_timeout
        :: Ptr () -> Word32 -> Word32 -> IO CInt

modbus_set_byte_timeout :: Ptr () -> Ptr Word32 -> Ptr Word32 -> IO Int
modbus_set_byte_timeout = timeoutSetter c_modbus_set_byte_timeout

foreign import ccall unsafe "modbus.h modbus_set_response_timeout" c_modbus_set_response_timeout
        :: Ptr () -> Word32 -> Word32 -> IO CInt

modbus_set_response_timeout :: Ptr () -> Ptr Word32 -> Ptr Word32 -> IO Int
modbus_set_response_timeout = timeoutSetter c_modbus_set_response_timeout

-- Register and bit transfers talk to the device, hence safe.
foreign import ccall safe "modbus.h modbus_read_registers" c_modbus_read_registers
        :: Ptr () -> CInt -> CInt -> Ptr Word16 -> IO CInt

modbus_read_registers :: Ptr () -> Int -> Int -> Ptr Word16 -> IO Int
modbus_read_registers = blockCall c_modbus_read_registers

foreign import ccall safe "modbus.h modbus_read_input_registers" c_modbus_read_input_registers
        :: Ptr () -> CInt -> CInt -> Ptr Word16 -> IO CInt

modbus_read_input_registers :: Ptr () -> Int -> Int -> Ptr Word16 -> IO Int
modbus_read_input_registers = blockCall c_modbus_read_input_registers

foreign import ccall safe "modbus.h modbus_write_registers" c_modbus_write_registers
        :: Ptr () -> CInt -> CInt -> Ptr Word16 -> IO CInt

modbus_write_registers :: Ptr () -> Int -> Int -> Ptr Word16 -> IO Int
modbus_write_registers = blockCall c_modbus_write_registers

foreign import ccall safe "modbus.h modbus_write_register" c_modbus_write_register
        :: Ptr () -> CInt -> Word16 -> IO CInt

modbus_write_register :: Ptr () -> Int -> Word16 -> IO Int
modbus_write_register ctx addr val =
        intResult $ c_modbus_write_register ctx (toCInt addr) val

foreign import ccall safe "modbus.h modbus_write_and_read_registers" c_modbus_write_and_read_registers
        :: Ptr () -> CInt -> CInt -> Ptr Word16 -> CInt -> CInt -> Ptr Word16 -> IO CInt

modbus_write_and_read_registers
        :: Ptr () -> Int -> Int -> Ptr Word16 -> Int -> Int -> Ptr Word16 -> IO Int
modbus_write_and_read_registers ctx waddr wnb src raddr rnb dest =
        intResult $ c_modbus_write_and_read_registers ctx
                (toCInt waddr) (toCInt wnb) src
                (toCInt raddr) (toCInt rnb) dest

foreign import ccall safe "modbus.h modbus_read_bits" c_modbus_read_bits
        :: Ptr () -> CInt -> CInt -> Ptr Word8 -> IO CInt

modbus_read_bits :: Ptr () -> Int -> Int -> Ptr Word8 -> IO Int
modbus_read_bits = blockCall c_modbus_read_bits

foreign import ccall safe "modbus.h modbus_read_input_bits" c_modbus_read_input_bits
        :: Ptr () -> CInt -> CInt -> Ptr Word8 -> IO CInt

modbus_read_input_bits :: Ptr () -> Int -> Int -> Ptr Word8 -> IO Int
modbus_read_input_bits = blockCall c_modbus_read_input_bits

foreign import ccall safe "modbus.h modbus_write_bits" c_modbus_write_bits
        :: Ptr () -> CInt -> CInt -> Ptr Word8 -> IO CInt

modbus_write_bits :: Ptr () -> Int -> Int -> Ptr Word8 -> IO Int
modbus_write_bits = blockCall c_modbus_write_bits

foreign import ccall safe "modbus.h modbus_write_bit" c_modbus_write_bit
        :: Ptr () -> CInt -> CInt -> IO CInt

modbus_write_bit :: Ptr () -> Int -> Int -> IO Int
modbus_write_bit ctx addr status =
        intResult $ c_modbus_write_bit ctx (toCInt addr) (toCInt status)

-- | Run a libmodbus block transfer over the whole of a mutable vector.
--
-- The transfer function receives the raw context, the start address, the
-- vector length and a pointer to the vector's storage. A result of -1
-- raises the current errno as an 'IOError'. Any other result that differs
-- from the vector length raises a "short read/write" 'IOError'.
accessVector
        :: Storable t
        => Context
        -> Addr
        -> VM.IOVector t
        -> (Ptr () -> Int -> Int -> Ptr t -> IO Int)
        -> String
        -> IO ()
accessVector h (Addr addr) v action actionname =
        withContext h $ \ctx -> do
                count <- withForeignPtr buf $ action ctx addr len
                check count
  where
        (buf, len) = VM.unsafeToForeignPtr0 v
        check count
                | count == -1 = throwErrno actionname
                | count /= len = ioError $ IOError Nothing OtherError
                        actionname "short read/write" Nothing Nothing
                | otherwise = return ()

-- | A modbus device context, wrapping the C library's context pointer.
--
-- The context will automatically be closed and freed when it is
-- garbage collected.
--
-- A newtype rather than a single-constructor data type, so the wrapper
-- adds no runtime indirection.
newtype Context = Context (ForeignPtr ())

-- | Wrap a raw libmodbus context pointer so that, once it becomes
-- unreachable, it is first closed with modbus_close and then freed with
-- modbus_free.
mkContext :: Ptr () -> IO Context
mkContext raw = do
        -- modbus_free is the first finalizer registered on the pointer.
        fp <- newForeignPtr modbus_free raw
        -- A finalizer added later runs before those already registered,
        -- so the context is closed before it is freed.
        addForeignPtrFinalizer modbus_close fp
        pure (Context fp)

-- | Run an action with the raw context pointer, keeping the 'Context'
-- alive (and so not finalized) until the action returns.
withContext :: Context -> (Ptr () -> IO a) -> IO a
withContext (Context fp) act = withForeignPtr fp act

-- | Serial line speed (baud rate) passed to 'new_rtu'.
newtype Baud = Baud Int
        deriving (Eq, Show)

-- | Serial parity setting passed to 'new_rtu'.
data Parity
        = ParityNone
        | ParityEven
        | ParityOdd
        deriving (Eq, Show)

-- | Number of data bits per serial character, passed to 'new_rtu'.
newtype DataBits = DataBits Int
        deriving (Eq, Show)

-- | Number of stop bits per serial character, passed to 'new_rtu'.
newtype StopBits = StopBits Int
        deriving (Eq, Show)

-- | Create a modbus Remote Terminal Unit context.
--
-- The FilePath is the serial device to connect to. The parity is passed
-- to the C library as one of the characters 'N', 'E' or 'O'.
--
-- Throws an 'IOError' (from errno) if the C library returns NULL.
new_rtu :: FilePath -> Baud -> Parity -> DataBits -> StopBits -> IO Context
new_rtu f (Baud b) p (DataBits d) (StopBits s) = do
        raw <- withCString f $ \dev -> modbus_new_rtu dev b parityChar d s
        if raw /= nullPtr
                then mkContext raw
                else throwErrno "modbus_new_rtu"
  where
        toCChar = fromIntegral . ord
        parityChar = toCChar $ case p of
                ParityNone -> 'N'
                ParityEven -> 'E'
                ParityOdd -> 'O'

-- | IPv4 address to connect to. In server mode, use 'AnyAddress' to
-- listen on any address.
data IPAddress
        = IPAddress String
        | AnyAddress
        deriving (Eq, Show)

-- | TCP port number.
newtype Port = Port Int
        deriving (Eq, Show)

-- | Defaults to port 502.
instance Default Port where
        def = Port 502
{-# LINE 291 "System/Modbus.hsc" #-}

-- | Create a modbus TCP/IPv4 context.
--
-- With 'AnyAddress', a NULL address is passed to the C library.
-- Throws an 'IOError' (from errno) if the C library returns NULL.
new_tcp :: IPAddress -> Port -> IO Context
new_tcp ipaddr (Port port) = do
        raw <- withAddress $ \cip -> modbus_new_tcp cip port
        if raw == nullPtr
                then throwErrno "modbus_new_tcp"
                else mkContext raw
  where
        withAddress :: (CString -> IO a) -> IO a
        withAddress k = case ipaddr of
                IPAddress s -> withCString s k
                AnyAddress -> k nullPtr

-- | Host name or IP address to connect to. In server mode, use 'AnyNode'
-- to listen on any address.
data Node
        = Node String
        | AnyNode
        deriving (Eq, Show)

-- | Service name or port number to connect to.
newtype Service = Service String
        deriving (Eq, Show)

-- | Defaults to the service "502".
instance Default Service where
        def = Service (show (502 :: Int))
{-# LINE 319 "System/Modbus.hsc" #-}

-- | Create a modbus TCP protocol-independent (IPv4 and IPv6) context.
--
-- With 'AnyNode', a NULL node is passed to the C library.
-- Throws an 'IOError' (from errno) if the C library returns NULL.
new_tcp_pi :: Node -> Service -> IO Context
new_tcp_pi node (Service service) =
        withCString service $ \cservice -> do
                raw <- withNode $ \cnode -> modbus_new_tcp_pi cnode cservice
                if raw /= nullPtr
                        then mkContext raw
                        else throwErrno "modbus_new_tcp_pi"
  where
        withNode :: (CString -> IO a) -> IO a
        withNode k = case node of
                Node s -> withCString s k
                AnyNode -> k nullPtr

-- | Serial interface mode of an RTU context.
data SerialMode
        = RTU_RS232
        | RTU_RS485
        deriving (Eq, Show)

-- | Get the serial mode of an RTU context.
--
-- The C library's 0 maps to 'RTU_RS232' and 1 to 'RTU_RS485'. Any other
-- result throws an 'IOError' (from errno).
rtu_get_serial_mode :: Context -> IO SerialMode
rtu_get_serial_mode h = withContext h $ \ctx ->
        modbus_rtu_get_serial_mode ctx >>= \mode -> case mode of
                0 -> return RTU_RS232
                1 -> return RTU_RS485
                _ -> throwErrno "modbus_rtu_get_serial_mode"

-- | Set the serial mode of an RTU context.
--
-- 'RTU_RS232' is passed as 0 and 'RTU_RS485' as 1. Throws an 'IOError'
-- (from errno) unless the C library returns 0.
rtu_set_serial_mode :: Context -> SerialMode -> IO ()
rtu_set_serial_mode h m = withContext h $ \ctx -> do
        rc <- modbus_rtu_set_serial_mode ctx code
        if rc /= 0
                then throwErrno "modbus_rtu_set_serial_mode"
                else return ()
  where
        code :: Int
        code = case m of
                RTU_RS232 -> 0
                RTU_RS485 -> 1

-- | Request To Send mode of an RTU context.
data RTS
        = RTU_RTS_NONE
        | RTU_RTS_UP
        | RTU_RTS_DOWN
        deriving (Eq, Show)

-- | Get the RTS mode of an RTU context.
--
-- The C library's 0, 1 and 2 map to 'RTU_RTS_NONE', 'RTU_RTS_UP' and
-- 'RTU_RTS_DOWN'; these are the same codes 'rtu_set_rts' sends. Any other
-- result throws an 'IOError' (from errno) labelled "modbus_rtu_get_rts".
rtu_get_rts :: Context -> IO RTS
rtu_get_rts h = withContext h $ \ctx -> do
        r <- modbus_rtu_get_rts ctx
        case r of
                0 -> return RTU_RTS_NONE
                1 -> return RTU_RTS_UP
                2 -> return RTU_RTS_DOWN
                -- The label names the function actually called (it was
                -- previously a copy of the serial-mode getter's label).
                _ -> throwErrno "modbus_rtu_get_rts"

-- | Set the RTS mode of an RTU context.
--
-- 'RTU_RTS_NONE', 'RTU_RTS_UP' and 'RTU_RTS_DOWN' are passed as 0, 1 and
-- 2. Throws an 'IOError' (from errno) unless the C library returns 0.
rtu_set_rts :: Context -> RTS -> IO ()
rtu_set_rts h m = withContext h $ \ctx -> do
        rc <- modbus_rtu_set_rts ctx (rtsCode m)
        if rc /= 0
                then throwErrno "modbus_rtu_set_rts"
                else return ()
  where
        rtsCode :: RTS -> Int
        rtsCode RTU_RTS_NONE = 0
        rtsCode RTU_RTS_UP = 1
        rtsCode RTU_RTS_DOWN = 2

-- | Get the RTS delay of an RTU context.
--
-- Throws an 'IOError' (from errno) if the C library returns -1.
rtu_get_rts_delay :: Context -> IO Int
rtu_get_rts_delay h = withContext h $ \ctx -> do
        delay <- modbus_rtu_get_rts_delay ctx
        if delay == -1
                then throwErrno "modbus_rtu_get_rts_delay"
                else return delay

-- | Set the RTS delay of an RTU context.
--
-- Throws an 'IOError' (from errno) unless the C library returns 0.
rtu_set_rts_delay :: Context -> Int -> IO ()
rtu_set_rts_delay h n = withContext h $ \ctx -> do
        rc <- modbus_rtu_set_rts_delay ctx n
        if rc /= 0
                then throwErrno "modbus_rtu_set_rts_delay"
                else return ()

-- | The address of a modbus device.
newtype DeviceAddress = DeviceAddress Int
        deriving (Eq, Show)

-- | The broadcast device address, 0.
broadcastAddress :: DeviceAddress
broadcastAddress = DeviceAddress 0
{-# LINE 397 "System/Modbus.hsc" #-}

-- | Set the address of the modbus device that the Context should
-- communicate with.
--
-- Throws an 'IOError' (from errno) unless the C library returns 0.
set_slave :: Context -> DeviceAddress -> IO ()
set_slave h (DeviceAddress n) = withContext h $ \ctx -> do
        rc <- modbus_set_slave ctx n
        if rc /= 0
                then throwErrno "modbus_set_slave"
                else return ()

-- | Establish a connection using the Context's settings.
--
-- Throws an 'IOError' (from errno) unless the C library returns 0.
connect :: Context -> IO ()
connect h = withContext h $ \ctx -> do
        rc <- modbus_connect ctx
        if rc /= 0
                then throwErrno "modbus_connect"
                else return ()

-- | Turn the C library's debug mode on ('True', passed as 1) or off
-- ('False', passed as 0) for the Context.
--
-- Throws an 'IOError' (from errno) unless the C library returns 0.
set_debug :: Context -> Bool -> IO ()
set_debug h b = withContext h $ \ctx -> do
        rc <- modbus_set_debug ctx (fromEnum b)
        if rc /= 0
                then throwErrno "modbus_set_debug"
                else return ()

-- | A timeout, split into seconds and microseconds, as exchanged with the
-- C library's timeout getters and setters.
data Timeout = Timeout
        { to_sec :: Word32
        -- ^ seconds
        , to_usec :: Word32
        -- ^ microseconds
        }
        deriving (Eq, Show)

-- | Read a 'Timeout' through a getter that fills in seconds and
-- microseconds via two out-pointers.
--
-- Throws an 'IOError' (from errno) labelled with @actionname@ unless the
-- getter returns 0.
get_timeout :: (Ptr () -> Ptr Word32 -> Ptr Word32 -> IO Int) -> String -> Context -> IO Timeout
get_timeout action actionname h =
        withContext h $ \ctx ->
        alloca $ \secp ->
        alloca $ \usecp -> do
                rc <- action ctx secp usecp
                if rc /= 0
                        then throwErrno actionname
                        else Timeout <$> peek secp <*> peek usecp

-- | Apply a 'Timeout' through a setter that receives pointers to
-- temporary storage holding the seconds and microseconds.
--
-- Throws an 'IOError' (from errno) labelled with @actionname@ unless the
-- setter returns 0.
set_timeout :: (Ptr () -> Ptr Word32 -> Ptr Word32 -> IO Int) -> String -> Context -> Timeout -> IO ()
set_timeout action actionname h timeout =
        withContext h $ \ctx ->
        with (to_sec timeout) $ \secp ->
        with (to_usec timeout) $ \usecp -> do
                rc <- action ctx secp usecp
                if rc /= 0
                        then throwErrno actionname
                        else return ()

-- | Get the byte timeout of the Context.
get_byte_timeout :: Context -> IO Timeout
get_byte_timeout h =
        get_timeout modbus_get_byte_timeout "modbus_get_byte_timeout" h

-- | Set the byte timeout of the Context.
set_byte_timeout :: Context -> Timeout -> IO ()
set_byte_timeout h t =
        set_timeout modbus_set_byte_timeout "modbus_set_byte_timeout" h t

-- | Get the response timeout of the Context.
get_response_timeout :: Context -> IO Timeout
get_response_timeout h =
        get_timeout modbus_get_response_timeout "modbus_get_response_timeout" h

-- | Set the response timeout of the Context.
set_response_timeout :: Context -> Timeout -> IO ()
set_response_timeout h t =
        set_timeout modbus_set_response_timeout "modbus_set_response_timeout" h t

-- | An address on a modbus device, such as the first register or bit
-- (coil) of a transfer.
newtype Addr = Addr Int
        deriving (Eq, Show)

-- | A mutable vector that is used as a buffer when reading or writing
-- registers of a modbus device. Each element holds one 16-bit register.
type RegisterVector = VM.IOVector Word16

-- | Allocates a vector holding the contents of the specified number of
-- registers, every register starting at 0.
mkRegisterVector :: Int -> IO RegisterVector
mkRegisterVector count = VM.replicate count (0 :: Word16)

-- | Types that can hold modbus register data.
--
-- Of these, `RegisterVector` is the most efficient, because it avoids
-- allocating new memory on each read or write. It can be more useful,
-- though, to get a ByteString and use a library such as cereal or binary
-- to parse the contents of the modbus registers.
class RegisterData t where
        -- | Build a value from the contents of a register buffer.
        fromRegisterVector :: RegisterVector -> IO t
        -- | Produce a register buffer holding the value's contents.
        toRegisterVector :: t -> IO RegisterVector

-- | The buffer itself, with no copying.
instance RegisterData RegisterVector where
        fromRegisterVector mv = pure mv
        toRegisterVector mv = pure mv

-- | An immutable copy of the buffer.
instance RegisterData (V.Vector Word16) where
        fromRegisterVector mv = V.freeze mv
        toRegisterVector iv = V.thaw iv

-- | The raw bytes of the registers, as laid out in memory (each register's
-- two bytes in host byte order).
instance RegisterData B.ByteString where
        fromRegisterVector v = do
                regs <- fromRegisterVector v
                return (B.pack (V.toList (asBytes regs)))
          where
                -- Reinterpret the register storage as bytes.
                asBytes :: V.Vector Word16 -> V.Vector Word8
                asBytes = V.unsafeCast
        toRegisterVector bs = toRegisterVector (asWords (V.fromList (B.unpack bs)))
          where
                -- Reinterpret pairs of bytes as registers. If there is an
                -- odd number of bytes, the last byte is omitted.
                asWords :: V.Vector Word8 -> V.Vector Word16
                asWords = V.unsafeCast

-- | Same byte layout as the strict 'B.ByteString' instance, converted
-- through a strict ByteString.
instance RegisterData L.ByteString where
        fromRegisterVector = fmap L.fromStrict . fromRegisterVector
        toRegisterVector lbs = toRegisterVector (L.toStrict lbs)

-- | Reads the holding registers from the modbus device, starting at the
-- Addr, into the RegisterVector buffer. The buffer's length is the number
-- of registers read. The result is then produced from the buffer with
-- 'fromRegisterVector'.
read_registers :: RegisterData t => Context -> Addr -> RegisterVector -> IO t
read_registers h addr v =
        accessVector h addr v modbus_read_registers "modbus_read_registers"
                >> fromRegisterVector v

-- | Reads the input registers from the modbus device, starting at the
-- Addr, into the RegisterVector buffer. The buffer's length is the number
-- of registers read. The result is then produced from the buffer with
-- 'fromRegisterVector'.
read_input_registers :: RegisterData t => Context -> Addr -> RegisterVector -> IO t
read_input_registers h addr v =
        accessVector h addr v modbus_read_input_registers "modbus_read_input_registers"
                >> fromRegisterVector v

-- | Writes every register in the buffer to the modbus device, starting
-- at the Addr.
write_registers :: Context -> Addr -> RegisterVector -> IO ()
write_registers h addr v =
        accessVector h addr v modbus_write_registers "modbus_write_registers"

-- | Writes a single register on the modbus device at the Addr.
--
-- Throws an 'IOError' (from errno) if the C library returns -1.
write_register :: Context -> Addr -> Word16 -> IO ()
write_register h (Addr addr) val = withContext h $ \ctx -> do
        rc <- modbus_write_register ctx addr val
        if rc /= -1
                then return ()
                else throwErrno "modbus_write_register"

-- | Writes one register buffer and reads another in a single request.
--
-- The buffer lengths give the number of registers written and read.
-- Throws an 'IOError' (from errno) if the C library returns -1, and a
-- "short read" 'IOError' if it reports a count other than the read
-- buffer's length.
write_and_read_registers
        :: Context
        -> Addr
        -- ^ address to write to
        -> RegisterVector
        -- ^ data to write
        -> Addr
        -- ^ address to read from
        -> RegisterVector
        -- ^ data to read
        -> IO ()
write_and_read_registers h (Addr write_addr) write_v (Addr read_addr) read_v =
        withContext h $ \ctx -> do
                count <- withForeignPtr srcFp $ \src ->
                        withForeignPtr destFp $ \dest ->
                                modbus_write_and_read_registers ctx
                                        write_addr srcLen src
                                        read_addr destLen dest
                checkCount count
  where
        actionname = "modbus_write_and_read_registers"
        (srcFp, srcLen) = VM.unsafeToForeignPtr0 write_v
        (destFp, destLen) = VM.unsafeToForeignPtr0 read_v
        checkCount count
                | count == -1 = throwErrno actionname
                | count /= destLen = ioError $ IOError Nothing OtherError
                        actionname "short read" Nothing Nothing
                | otherwise = return ()

-- | A mutable vector that is used as a buffer when reading or writing
-- bits (coils) of a modbus device. Each element holds one bit.
type BitVector = VM.IOVector Bit

-- | Allocates a vector holding the specified number of bits.
--
-- Every bit starts out set (1).
mkBitVector :: Int -> IO BitVector
mkBitVector count = VM.replicate count 1
{-# LINE 598 "System/Modbus.hsc" #-}

-- | A single bit (coil) value, stored one per byte.
type Bit = Word8

-- | Interpret a 'Bit' as a 'Bool'. Any non-zero value is 'True'.
--
-- libmodbus stores the bits it reads as 0 or 1, so those convert exactly.
-- When writing, libmodbus treats every non-zero byte as "on". Treating
-- non-zero as 'True' keeps both directions consistent for hand-built
-- 'BitVector's (previously only 1 counted as 'True').
boolBit :: Bit -> Bool
boolBit b = b /= 0
{-# LINE 603 "System/Modbus.hsc" #-}

-- | Convert a 'Bool' to a 'Bit': 1 for 'True', 0 for 'False'.
bitBool :: Bool -> Bit
bitBool b = if b then 1 else 0
{-# LINE 607 "System/Modbus.hsc" #-}

-- | Types that can hold modbus bit data.
--
-- Of these, `BitVector` is the most efficient, because it avoids
-- allocating new memory on each read or write. It can be easier, though,
-- to use a Vector of Bool.
class BitData t where
        -- | Build a value from the contents of a bit buffer.
        fromBitVector :: BitVector -> IO t
        -- | Produce a bit buffer holding the value's contents.
        toBitVector :: t -> IO BitVector

-- | The buffer itself, with no copying.
instance BitData BitVector where
        fromBitVector mv = pure mv
        toBitVector mv = pure mv

-- | An immutable copy of the buffer.
instance BitData (V.Vector Word8) where
        fromBitVector mv = V.freeze mv
        toBitVector iv = V.thaw iv

-- | Each element converted with 'boolBit' / 'bitBool'.
instance BitData (V.Vector Bool) where
        fromBitVector mv = do
                bits <- fromBitVector mv
                return (V.map boolBit bits)
        toBitVector bools = toBitVector (V.map bitBool bools)

-- | Reads the bits (coils) from the modbus device, starting at the Addr,
-- into the BitVector. The buffer's length is the number of bits read.
-- The result is then produced from the buffer with 'fromBitVector'.
read_bits :: BitData t => Context -> Addr -> BitVector -> IO t
read_bits h addr v =
        accessVector h addr v modbus_read_bits "modbus_read_bits"
                >> fromBitVector v

-- | Reads the input bits from the modbus device, starting at the Addr,
-- into the BitVector. The buffer's length is the number of bits read.
-- The result is then produced from the buffer with 'fromBitVector'.
read_input_bits :: BitData t => Context -> Addr -> BitVector -> IO t
read_input_bits h addr v =
        accessVector h addr v modbus_read_input_bits "modbus_read_input_bits"
                >> fromBitVector v

-- | Writes every bit (coil) in the buffer to the modbus device, starting
-- at the Addr.
write_bits :: Context -> Addr -> BitVector -> IO ()
write_bits h addr v =
        accessVector h addr v modbus_write_bits "modbus_write_bits"

-- | Writes a single bit (coil) on the modbus device at the Addr.
--
-- Throws an 'IOError' (from errno) if the C library returns -1.
write_bit :: Context -> Addr -> Bit -> IO ()
write_bit h (Addr addr) val = withContext h $ \ctx -> do
        rc <- modbus_write_bit ctx addr (fromIntegral val)
        if rc /= -1
                then return ()
                else throwErrno "modbus_write_bit"