{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeApplications #-}

module HOCD.Monad
  ( OCDT
  , runOCDT
  , MonadOCD(..)
  , halt
  , halt'
  , readMem
  , readMem32
  , readMemCount
  , writeMem
  , writeMem32
  ) where

import Control.Monad.Catch (MonadCatch, MonadMask, MonadThrow)
import Control.Monad.Except (MonadError, throwError)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Reader (MonadReader, ask)
import Control.Monad.Trans (MonadTrans, lift)
import Control.Monad.Trans.Except (ExceptT, runExceptT)
import Control.Monad.Trans.Reader (ReaderT, runReaderT)
import Data.Bits (FiniteBits(..))
import Data.ByteString (ByteString)
import Data.Word (Word32)
import HOCD.Command
  ( Command(..)
  , Capture(..)
  , Halt(..)
  , ReadMemory(..)
  , WriteMemory(..)
  , subChar
  )
import HOCD.Error (OCDError(..))
import HOCD.Types (MemAddress)
import Network.Socket (Socket)
import Text.Printf (PrintfArg)

import qualified Data.ByteString.Char8
import qualified Network.Socket.ByteString

-- | Monad transformer layering 'ExceptT' 'OCDError' over
-- 'ReaderT' 'Socket' on top of the base monad @m@.
--
-- All instances are obtained from the underlying stack
-- via @GeneralizedNewtypeDeriving@.
newtype OCDT m a = OCDT
  { _unOCDT
      :: ExceptT OCDError
          (ReaderT Socket m) a
    -- ^ Unwrap to the underlying transformer stack
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadReader Socket
    , MonadError OCDError
    , MonadCatch
    , MonadMask
    , MonadThrow
    , MonadIO
    )

-- | Lift an action of the base monad into 'OCDT'.
instance MonadTrans OCDT where
  -- Two lifts are needed: the inner one goes through 'ReaderT',
  -- the outer one through 'ExceptT', before wrapping in 'OCDT'.
  lift = OCDT . lift . lift

-- | Run OCDT transformer
--
-- Unwraps the newtype, runs the 'ExceptT' layer and finally
-- supplies the given 'Socket' to the 'ReaderT' layer, yielding
-- either an 'OCDError' or the result in the base monad.
runOCDT
  :: Monad m
  => Socket -- ^ Socket made available to the computation
  -> OCDT m a -- ^ Computation to run
  -> m (Either OCDError a)
runOCDT sock =
    (`runReaderT` sock)
  . runExceptT
  . _unOCDT

-- | Monads that can perform 'IO', fail with 'OCDError'
-- and provide a 'Socket'.
class (MonadIO m, MonadError OCDError m) => MonadOCD m where
  -- | Retrieve the 'Socket' (used by @rpc@ to send requests
  -- and receive replies).
  getSocket :: m Socket

-- | 'OCDT' provides the socket from its 'ReaderT' environment.
instance MonadIO m => MonadOCD (OCDT m) where
  getSocket = ask

-- | Perform RPC call
--
-- The serialized request ('request') is terminated with 'subChar'
-- and sent over the socket obtained from 'getSocket'. Chunks are
-- then received until one ends with 'subChar'; the concatenation
-- of all received chunks is decoded with 'reply'. A decoding
-- failure is rethrown via 'throwError'.
--
-- If the socket yields an empty chunk before the terminator was
-- seen (the peer closed the connection), an 'IOError' is raised
-- in 'IO' instead of crashing on a partial @last@.
rpc
  :: ( MonadOCD m
     , Command req
     )
  => req -- ^ Command to send
  -> m (Reply req)
rpc cmd = do
  sock <- getSocket
  liftIO $
    Network.Socket.ByteString.sendAll
      sock
      (rpcCmd $ request cmd)
  raw <- liftIO $ recvTillSub sock
  either throwError pure $ reply cmd raw
  where
    -- Receive chunks until one is terminated by 'subChar'.
    -- Chunks are collected in reverse order and concatenated once
    -- at the end, so each received byte is copied only once.
    recvTillSub :: Socket -> IO ByteString
    recvTillSub s = go []
      where
        go chunks = do
          msg <- Network.Socket.ByteString.recv s 1024
          case Data.ByteString.Char8.unsnoc msg of
            -- Empty chunk: no more data will arrive on this socket
            Nothing ->
              ioError
                $ userError
                    "HOCD.Monad.rpc: connection closed before reply terminator was received"
            Just (_, lastChar)
              | lastChar == subChar ->
                  pure $ mconcat $ reverse (msg : chunks)
              | otherwise ->
                  go (msg : chunks)

    -- Terminate with \SUB
    rpcCmd :: ByteString -> ByteString
    rpcCmd =
      (<> Data.ByteString.Char8.singleton subChar)

-- | Halt target
--
-- Issues the 'Halt' command wrapped in 'Capture' and returns
-- the reply as a 'ByteString'.
halt
  :: MonadOCD m
  => m ByteString
halt = rpc $ Capture Halt

-- | Halt target, discarding reply
--
-- Same as 'halt' but the returned 'ByteString' is dropped.
halt'
  :: MonadOCD m
  => m ()
halt' = halt >> pure ()

-- | Read multiple memory segments from @MemAddress@
-- according to count argument. Segment size depends
-- on passed in Word type.
--
-- Implemented as a 'ReadMemory' RPC call; errors from
-- the call are reported as 'OCDError' via 'MonadError'.
readMemCount
  :: forall a m
   . ( MonadOCD m
     , FiniteBits a
     , Integral a
     )
  => MemAddress -- ^ Memory address to read from
  -> Int -- ^ Count
  -> m [a]
readMemCount ma c =
  rpc
    ReadMemory
      { readMemoryAddr = ma
      , readMemoryCount = c
      }

-- | Read single memory segment from @MemAddress@
-- Segment size depends on passed in Word type.
--
-- Calls 'readMemCount' with a count of 1. Any result that is
-- not exactly one element (including an empty one) fails
-- with 'OCDError_ExpectedOneButGotMore'.
readMem
  :: forall a m
   . ( MonadOCD m
     , FiniteBits a
     , Integral a
     )
  => MemAddress -- ^ Memory address to read from
  -> m a
readMem ma =
  readMemCount ma 1
  >>= \case
        [one] -> pure one
        _ -> throwError OCDError_ExpectedOneButGotMore

-- | Shorthand for reading @Word32@ sized segment
--
-- This is 'readMem' specialised to 'Word32'.
readMem32
  :: MonadOCD m
  => MemAddress -- ^ Memory address to read from
  -> m Word32
readMem32 = readMem @Word32

-- | Write multiple memory segments to @MemAddress@
--
-- Implemented as a 'WriteMemory' RPC call; errors from
-- the call are reported as 'OCDError' via 'MonadError'.
writeMem
  :: forall a m
   . ( MonadOCD m
     , FiniteBits a
     , PrintfArg a
     , Integral a
     )
  => MemAddress -- ^ Memory address to write to
  -> [a] -- ^ Data to write
  -> m ()
writeMem ma xs =
  rpc
    WriteMemory
      { writeMemoryAddr = ma
      , writeMemoryData = xs
      }

-- | Shorthand for writing @Word32@ sized segment
--
-- This is 'writeMem' specialised to 'Word32'.
writeMem32
  :: MonadOCD m
  => MemAddress -- ^ Memory address to write to
  -> [Word32] -- ^ Data to write
  -> m ()
writeMem32 = writeMem @Word32