{-# LANGUAGE CPP #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
{- | Conversions between R's RDS/RDA format and haskell data types.

tested with R 3.0.1


Missing:

* Data.Map

* better error reporting when the format is bad?

* more tests

-}
module RlangQQ.Binary (
    -- * functions to serialize many variables
    -- $rdaFmt
    toRDA, fromRDA,

    -- * serializing a single variable
    ToRDS(..), FromRDS(..),

    -- * types / internal
    FromRDA, ToRDSRecord, RDSHLIST, RDA, IxSize(..),

    module Data.HList.CommonMain, ) where

import Control.Lens
import System.Process
import Unsafe.Coerce
import Control.Applicative
import qualified Data.ByteString.Lazy.UTF8 as B
import qualified Data.ByteString.UTF8 as BS (fromString, toString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as B
import qualified Data.ByteString.Lazy.Char8 as B8

import Data.Int
import Data.HList.CommonMain
import qualified Data.Vector as V
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Generic as VG
import qualified Data.Map as M

import Data.Binary
import Data.Binary.Get
import Data.Binary.Builder
import Data.Binary.Put
import qualified Data.Binary
import Data.Binary.IEEE754 (doubleToWord, wordToDouble)

import qualified Data.Text as T
import Data.Text.Encoding as E

import Control.Monad.Identity

#if REPA
import qualified Data.Array.Repa as R
#endif

import GHC.TypeLits

import qualified Data.Array as A
import qualified Language.Haskell.TH as TH

import qualified Codec.Compression.GZip as GZip


-- | Wrapper around an 'HList' whose elements carry their own labels: each
-- label is written next to its variable. Compare with the instance for
-- 'Record' / 'LST', which collects the labels and saves them as an attribute
-- called \"names\".
newtype RDA a = RDA (HList a)
-- Template Haskell splice; presumably lens' @makeWrapped@, which is what
-- allows @_Wrapping RDA@ in 'fromRDA' -- TODO confirm
makeWrapped ''RDA

-- | Emits the constant 32-bit word @262153@. It is written before every
-- string of a character vector and as the last word of a list header.
-- The original author observed this value with R 3.0.1 and treated it as a
-- magic number.
--
-- NOTE(review): the value looks like a string-header flag word rather than
-- a format version -- confirm against the R serialization documentation.
putVersion :: Put
putVersion = put magicWord
  where
    magicWord :: Int32
    magicWord = 262153

-- | Reads the 32-bit word that 'putVersion' writes. The value is returned
-- as-is; it is not compared against the expected constant.
getVersion :: Get Int32
getVersion = get

-- | Serialization half of the RDS codec. Plays the same role as 'Binary''s
-- @put@, but the output should be compatible with R's @saveRDS@ binary mode,
-- which stores a single object.
class ToRDS a where
    -- | Produce the 'Put' action that writes the value in RDS layout.
    toRDS :: a -> Put
-- | Parsing half of the RDS codec: the counterpart of 'ToRDS'.
class FromRDS a where
    -- | Parser for one value in RDS layout. The result type selects which
    -- layout is expected.
    fromRDS :: Get a

-- | Writes a 'Double' as its IEEE754 bit pattern (converted with
-- 'doubleToWord' and then 'put'). This is used instead of the 'Binary'
-- instance of 'Double' which, per the original note, produces 25 bytes
-- rather than the 8 bytes of ieee754; alternatives to binary (bytes, cereal
-- etc.) may do better.
putDouble :: Double -> Put
putDouble d = put (doubleToWord d)
-- | Reads one word with 'get' and reinterprets it as a 'Double' via
-- 'wordToDouble'; the inverse of 'putDouble'.
getDouble :: Get Double
getDouble = wordToDouble <$> get

-- | Parses a numeric vector: the tag word 14, an 'Int32' element count and
-- then that many doubles. Works for any "Data.Vector.Generic" vector type.
-- A different tag word makes the parser fail.
getVectorDouble = do
  14 <- get :: Get Int32
  count <- fromIntegral <$> (get :: Get Int32)
  VG.replicateM count getDouble

-- | Writes a numeric vector: the tag word 14, the length as an 'Int32' and
-- then every element with 'putDouble'.
putVectorDouble vec = do
  put (14 :: Int32)
  put (fromIntegral (VG.length vec) :: Int32)
  VG.forM_ vec putDouble

-- | Parses an integer vector: the tag word 13, an 'Int32' element count and
-- then that many elements read with 'get'. A different tag word makes the
-- parser fail.
getVectorInt = do
  13 <- get :: Get Int32
  count <- fromIntegral <$> (get :: Get Int32)
  VG.replicateM count get

-- | Writes an integer vector: the tag word 13, the length as an 'Int32' and
-- then every element with 'put'.
putVectorInt vec = do
  put (13 :: Int32)
  put (fromIntegral (VG.length vec) :: Int32)
  VG.forM_ vec put

-- | boxed vector of 'Double': becomes a numeric vector
instance ToRDS (V.Vector Double) where
    toRDS = putVectorDouble

-- | storable vector of 'Double': becomes a numeric vector
instance ToRDS (VS.Vector Double) where
    toRDS = putVectorDouble

-- | numeric vector read into a boxed vector
instance FromRDS (V.Vector Double) where
    fromRDS = getVectorDouble

-- | numeric vector read into a storable vector
instance FromRDS (VS.Vector Double) where
    fromRDS = getVectorDouble

-- | boxed vector of 'Int32': becomes an integer vector
instance ToRDS (V.Vector Int32) where
    toRDS = putVectorInt

-- | integer vector read into a boxed vector
instance FromRDS (V.Vector Int32) where
    fromRDS = getVectorInt

-- | storable vector of 'Int32': becomes an integer vector
instance ToRDS (VS.Vector Int32) where
    toRDS = putVectorInt

-- | integer vector read into a storable vector
instance FromRDS (VS.Vector Int32) where
    fromRDS = getVectorInt

-- | character vector @c(\'ab\',\'cd\')@
--
-- Layout: the tag word 16, the number of strings, then for each string the
-- 'putVersion' word, the UTF-8 byte length and the UTF-8 bytes.
instance ToRDS (V.Vector T.Text) where
    toRDS strs = do
        put (16 :: Int32)
        put (fromIntegral (V.length strs) :: Int32)
        V.forM_ strs putOne
      where
        -- a single string: header word, byte count, payload
        putOne txt = do
            putVersion
            let bytes = E.encodeUtf8 txt
            put (fromIntegral (BS.length bytes) :: Int32)
            putByteString bytes

-- | Parses a character vector: the tag word 16, the number of strings, then
-- for each string a header word (ignored), a byte length and that many bytes
-- of UTF-8.
--
-- Bytes that are not valid UTF-8 make the parser fail with a message,
-- instead of raising an exception from pure code when the text is forced.
instance FromRDS (V.Vector T.Text) where
    fromRDS = do
        16 <- get :: Get Int32
        nstr <- get :: Get Int32
        V.replicateM (fromIntegral nstr) getOne
      where
        getOne = do
            _ <- get :: Get Int32 -- per-string header word, not checked
            len <- get :: Get Int32
            bs <- getByteString (fromIntegral len)
            case E.decodeUtf8' bs of
                Right txt -> return txt
                Left err -> fail ("fromRDS: invalid UTF-8 in character vector: " ++ show err)

-- | A single 'T.Text' is written as a character vector of length one:
-- @T.pack \"abc\"@ => @c(\'abc\')@
instance ToRDS T.Text where
    toRDS = toRDS . V.singleton

-- | Reads a character vector and returns its first element; any further
-- elements are dropped. An empty vector is a parse failure (it used to be an
-- irrefutable-pattern crash).
instance FromRDS T.Text where
    fromRDS = do
        strs <- fromRDS
        case V.toList strs of
            s : _ -> return s
            [] -> fail "fromRDS: expected a character vector with at least one element, got an empty one"

-- | @[\"abc\",\"def\"]@ => @c(\'abc\',\'def\')@
--
-- Goes through the @V.Vector T.Text@ instance.
instance ToRDS [String] where
    toRDS = toRDS . V.fromList . map T.pack

-- | Reads a character vector (via the @V.Vector T.Text@ instance) as a list
-- of 'String's.
instance FromRDS [String] where
    fromRDS = do
        texts <- fromRDS
        return (map T.unpack (V.toList texts))

-- | A single 'String' is written as a character vector of length one:
-- @\"abc\"@ => @c(\'abc\')@
instance ToRDS String where
    toRDS str = toRDS [str]
-- | Reads a character vector and returns its first element; any further
-- elements are dropped. An empty vector is a parse failure (it used to be an
-- irrefutable-pattern crash).
instance FromRDS String where
    fromRDS = do
        strs <- fromRDS :: Get [String]
        case strs of
            s : _ -> return s
            [] -> fail "fromRDS: expected a character vector with at least one element, got an empty one"


-- | Function symbol for HList's 'ApplyAB': applied to a value with a 'ToRDS'
-- instance it gives that value's 'Put' action. Used to serialize every
-- element of an 'HList' in the @ToRDS (LST xs)@ instance.
data FToRDS = FToRDS
-- the result type is left as a variable and pinned with @putm ~ Put@;
-- presumably to help type inference -- TODO confirm
instance (ToRDS a, putm ~ Put) => ApplyAB FToRDS a putm where
    applyAB FToRDS x = toRDS x

-- | Function symbol for HList's 'ApplyAB': ignores its (unit) argument and
-- gives the 'fromRDS' parser for the requested result type. Used with
-- @hReplicateF@ to build one parser per list element.
data FFromRDS = FFromRDS
-- argument and result are type variables pinned by equalities; presumably
-- to help type inference -- TODO confirm
instance (FromRDS b, Get b ~ getB, a ~ ()) => ApplyAB FFromRDS a getB where
    applyAB FFromRDS _ = fromRDS


-- | Constraints needed to serialize every element of the type-level list
-- @xs@: its length must be convertible to a number and the elements must be
-- foldable with 'FToRDS'. The first parameter @xs'@ is not used on the
-- right-hand side.
type RDSHLIST xs' xs = (HNat2Integral (HLength xs), HFoldr (HSeq FToRDS) Put xs Put)

-- | Writes the tag word 531, a length, every element in order (each through
-- its own 'ToRDS' instance) and the terminator word 254.
--
-- The length written is the number of elements minus 2. This matches how
-- the @ToRDS (Record xs)@ instance builds its 'LST': the record values
-- followed by two extra elements, a 'ListStart' and the \"names\" entry.
--
-- NOTE(review): the @FromRDS (LST l)@ instance compares the stored length
-- with the full element count (no @- 2@), so the two instances do not look
-- like inverses of each other -- confirm whether that is intended.
instance  (RDSHLIST xs' xs) => ToRDS (LST xs) where
    toRDS (LST xs) = do
        put (531::Int32)
        let len = hNat2Integral (Proxy :: Proxy (HLength xs))
        put (fromIntegral (len - 2 :: Int) :: Int32)
        hFoldr (HSeq FToRDS) (return () :: Put) xs :: Put
        put (254::Int32)

-- | Constraints needed to parse an 'LST' with element types @l@: a list @b@
-- of per-element parsers is built by replicating 'FFromRDS' and is then
-- sequenced in 'Get' to give the elements @l@.
type RDSHLIST2 b l = (HSequence Get b l,
      HReplicateF (HLength l) FFromRDS () b,
      HNat2Integral (HLength l),
      HLengthEq l (HLength l))

-- | Parses the tag word 531, a length that must equal the number of elements
-- of @l@, the elements themselves and the terminator word 254.
--
-- A length mismatch is reported with 'fail', so it is an ordinary parse
-- failure (position included, recoverable with @runGetOrFail@) like every
-- other format error in this module, rather than a call to 'error'.
instance (RDSHLIST2 __ l) => FromRDS (LST l) where
    fromRDS = withSelf $ \(self) -> do
        531 <- get :: Get Int32
        -- only the type of @self@ is used; its value is never forced
        let len = hNat2Integral (hLength self)
        len2 <- get :: Get Int32
        when (len /= len2) $ fail $ "fromRDS expected length: " ++ show len ++ " rds file has length: " ++ show len2
        r <- hSequence $ hReplicateF (hLength self) FFromRDS ()
        254 <- get :: Get Int32
        return (LST r)
      where
        -- hands the parser a dummy value that has the type of the list being
        -- parsed, so that its length can be computed from the type
        withSelf :: forall (a::[*]) m. (HList a -> m (LST a)) -> m (LST a)
        withSelf x = x (error "RlangQQ.Binary.LST.self")



-- | An 'HList' that is serialized with the list layout of the
-- @ToRDS (LST xs)@ instance (tag word 531). Not to be constructed outside
-- this module.
newtype LST (a :: [*]) = LST (HList a)

-- | @lab .=. val .*. 'emptyRecord'@ => @list(lab= ('toRDS' val) )@.
--
-- The record's values are followed by a 'ListStart' marker and a
-- \"names\" entry holding the record's labels as strings; the whole list is
-- then written through the 'LST' instance.
--
-- The type variables with underscores should be hidden
instance ToRDSRecord __ ___ xs => ToRDS (Record xs) where
    toRDS (Record xs) = toRDS $ LST $
                recordValues (Record xs) `hAppend`
                (ListStart `HCons` (Label :: Label "names") .=. (recordLabelsString (Proxy :: Proxy (LabelsOf xs))) `HCons` HNil)

-- | Constraints for the @ToRDS (Record xs)@ instance. @___@ is the list
-- that is actually serialized: the record's values with a 'ListStart' and
-- the \"names\" entry appended. @__@ is only passed on to 'RDSHLIST'.
type ToRDSRecord __ ___ xs = (RDSHLIST __ ___, ToRDS (LST ___),
            RecordValues xs,
            HList ___ ~ (HList (RecordValuesR xs) `HAppendR` HList '[ListStart, Tagged "names" [String]]),
            RecordLabelsString (LabelsOf xs),
            HAppend (HList (RecordValuesR xs)) (HList '[ ListStart, Tagged "names" [String]]))



-- | Constraints for the @FromRDS (Record d)@ instance. This signature
-- shouldn't be necessary... it just repeats the constraints of all the
-- functions called there. @c@ is the record's field list and @a@ the list
-- of per-field parsers built from 'FFromRDS'.
type FromRDSRec a c = (
                HReplicateF (HLength c) FFromRDS () a,
                HSequence Get a (RecordValuesR c),
                Unlabeled' c,
                RecordLabelsString (LabelsOf c),
                HNat2Integral (HLength c))

{- | R lists become HList records. @list(x=1,y='b')@ parses to something
like

> let x = Label :: Label "x"
>     y = Label :: Label "y"
> in    x .=. 1
>   .*. y .=. "b"
>   .*. emptyRecord

You have to get the result type right
(ie. @Record [Tagged \"x\" Double, Tagged \"y\" String]@) for things to parse

The parser reads the tag word 531, the number of fields, one value per field
of @d@, then a list header, the string \"names\", the list of field names
and the terminator word 254. The stored length and the stored names must
agree with the requested record type; a mismatch is reported with 'fail'
(a parse failure like every other format error here) rather than 'error'.
-}
instance FromRDSRec b d => FromRDS (Record d) where
    fromRDS = do
        531 <- get :: Get Int32
        let lenN = Proxy :: Proxy (HLength d)
            len = hNat2Integral lenN
        len2 <- get :: Get Int32
        when (len /= len2) $ fail $ "fromRDS expected length: " ++ show len ++ " rds file has length: " ++ show len2

        -- one parser per field, run in order
        c <- hSequence (hReplicateF lenN FFromRDS ())

        getListHdr
        "names" <- getString
        names :: [String] <- fromRDS

        -- the labels the requested record type says the file should have
        let names' = recordLabelsString (Proxy :: Proxy (LabelsOf d))

        unless (names == names') $ fail $ "fromRDS expected names( ): " ++ show names'
                    ++ " rds file has names attribute : " ++ show names

        254 <- get :: Get Int32
        return (c ^. from unlabeled)

-- | Turns a type-level list of labels into the list of their names as
-- 'String's, in the same order.
class RecordLabelsString (a :: [*]) where
    recordLabelsString :: Proxy a -> [String]

-- | no labels, no names
instance RecordLabelsString '[] where
    recordLabelsString = const []

-- | the name of the first label ('showLabel') followed by the names of the
-- remaining ones
instance (ShowLabel x, RecordLabelsString xs)
        => RecordLabelsString (Label x ': xs) where
    recordLabelsString _ = firstName : otherNames
      where
        firstName = showLabel (Label :: Label x)
        otherNames = recordLabelsString (Proxy :: Proxy xs)



-- | a list of 'Double' is written as a numeric vector (via the boxed
-- vector instance)
instance ToRDS [Double] where
    toRDS xs = toRDS (V.fromList xs)

-- | a numeric vector read back as a list
instance FromRDS [Double] where
    fromRDS = V.toList <$> fromRDS

-- | a single 'Double' is written as a one-element list
instance ToRDS Double where
    toRDS d = toRDS [d]

-- | a single 'Int32' is written as a one-element list
instance ToRDS Int32 where
    toRDS i = toRDS [i]

-- | converted to 'Int32' with 'fromIntegral' first
instance ToRDS Integer where
    toRDS n = toRDS (fromIntegral n :: Int32)

-- | converted to 'Int32' with 'fromIntegral' first
instance ToRDS Int where
    toRDS n = toRDS (fromIntegral n :: Int32)

-- | Reads an integer vector that must hold exactly one element. Any other
-- length is a parse failure whose message states the length found.
instance FromRDS Int32 where
    fromRDS = do
        xs <- fromRDS
        case xs of
            [x] -> return x
            _ -> fail $ "fromRDS: expected an integer vector of length 1, got length " ++ show (length xs)

-- | Reads a numeric vector that must hold exactly one element. Any other
-- length is a parse failure whose message states the length found.
instance FromRDS Double where
    fromRDS = do
        xs <- fromRDS
        case xs of
            [x] -> return x
            _ -> fail $ "fromRDS: expected a numeric vector of length 1, got length " ++ show (length xs)

-- | read as an 'Int32' and widened with 'fromIntegral'
instance FromRDS Int where
    fromRDS = do
        i <- fromRDS :: Get Int32
        return (fromIntegral i)

-- | read as an 'Int32' and widened with 'fromIntegral'
instance FromRDS Integer where
    fromRDS = do
        i <- fromRDS :: Get Int32
        return (fromIntegral i)


-- | a list of 'Int32' is written as an integer vector (via the boxed vector
-- instance)
instance ToRDS [Int32] where
    toRDS xs = toRDS (V.fromList xs)

-- | an integer vector read back as a list
instance FromRDS [Int32] where
    fromRDS = V.toList <$> fromRDS

-- | Every element is converted to an 'Int32' with 'fromIntegral', which is
-- bad on 64 bit systems where @maxBound :: Int@ is a bigger number than
-- @maxBound :: Int32@.
instance ToRDS [Int] where
    toRDS ints = toRDS (V.fromList [fromIntegral i :: Int32 | i <- ints])

-- | an integer vector read back as a list of 'Int'
instance FromRDS [Int] where
    fromRDS = do
        vec <- fromRDS
        return [fromIntegral (i :: Int32) :: Int | i <- V.toList vec]

-- | every element is converted to an 'Int32' with 'fromIntegral' first
instance ToRDS [Integer] where
    toRDS ints = toRDS (V.fromList [fromIntegral i :: Int32 | i <- ints])

-- | an integer vector read back as a list of 'Integer'
instance FromRDS [Integer] where
    fromRDS = do
        vec <- fromRDS
        return [fromIntegral (i :: Int32) :: Integer | i <- V.toList vec]


-- | Writes a bare string: the length of its UTF-8 encoding as an 'Int32',
-- then the encoded bytes. No tag word is written.
putString :: String -> Put
putString str = put byteCount >> putByteString encoded
  where
    encoded = BS.fromString str
    byteCount = fromIntegral (BS.length encoded) :: Int32
-- | Reads a bare string as written by 'putString': an 'Int32' byte count
-- followed by that many bytes, decoded with @BS.toString@.
getString :: Get String
getString = do
        byteCount <- get :: Get Int32
        BS.toString <$> getByteString (fromIntegral byteCount)

-- | Reads a string with 'getString' and checks that it is the expected one.
-- A different string is reported with 'fail', i.e. as a parse failure like
-- the other format checks in this module, rather than with 'error'.
confirmString s = do
        s' <- getString
        unless (s' == s) $ fail $ "expected "++ s ++ ", got " ++ s'

-- | Marker element standing for a list header ('putListHdr' /
-- 'getListHdr'). It carries no data of its own.
data ListStart = ListStart

-- | writes the list header
instance ToRDS ListStart where
    toRDS = const putListHdr

-- | consumes (and checks) the list header
instance FromRDS ListStart where
    fromRDS = getListHdr >> return ListStart


-- | Writes the three-word list header: 1026, 1 and the 'putVersion' word.
putListHdr :: Put
putListHdr = mapM_ (put :: Int32 -> Put) [1026, 1] >> putVersion

-- | Consumes the header written by 'putListHdr'. The first two words must
-- be 1026 and 1, otherwise the parser fails; the third word is read but not
-- checked.
getListHdr :: Get ()
getListHdr = do
        [1026, 1] <- replicateM 2 (get :: Get Int32)
        getVersion >> return ()


-- | probably internal
--
-- Writes the label's name with 'putString', then the value through its own
-- 'ToRDS' instance. The label proxy is the 'Label' constructor (as in
-- 'recordLabelsString') rather than 'undefined'.
instance forall t l. (ToRDS t, ShowLabel l) => ToRDS (Tagged l t) where
    toRDS (Tagged x) = do
        putString (showLabel (Label :: Label l))
        toRDS x
-- | probably internal
--
-- Reads a name with 'getString', checks that it equals the label @l@ and
-- then parses the value. A different name is a parse failure; the message
-- quotes both names without stray spaces inside the quotes.
instance forall t l. (FromRDS t, ShowLabel l) => FromRDS (Tagged l t) where
    fromRDS = do
        varName <- getString

        -- the label proxy is the 'Label' constructor rather than 'undefined'
        let s = showLabel (Label :: Label l)
        unless (varName == s) $ fail $ concat ["FromRDS: expecting label `", s, "', but got: `", varName, "'"]

        x <- fromRDS
        return (Tagged x)

-- | internal
--
-- One variable of an RDA file: a list header, the labelled value, then the
-- remaining variables.
instance forall rs l2 t. (ToRDS t, ToRDS (RDA rs), ShowLabel l2) => ToRDS (RDA (Tagged l2 t ': rs)) where
    toRDS (RDA (HCons firstVar otherVars)) =
        putListHdr >> toRDS firstVar >> toRDS (RDA otherVars)

-- | internal
--
-- One variable of an RDA file: a list header, the labelled value, then the
-- remaining variables.
instance forall rs l2 t. (FromRDS t, FromRDS (RDA rs), ShowLabel l2) => FromRDS (RDA (Tagged l2 t ': rs)) where
    fromRDS = do
        getListHdr
        firstVar <- fromRDS :: Get (Tagged l2 t)
        RDA otherVars <- fromRDS :: Get (RDA rs)
        return (RDA (HCons firstVar otherVars))

-- | internal
--
-- The end of the variable list is the single word 254.
instance ToRDS (RDA '[]) where
    toRDS = const (put terminator)
      where
        terminator :: Int32
        terminator = 254

-- | internal
--
-- Expects the word 254 that ends the variable list.
instance FromRDS (RDA '[]) where
    fromRDS = do
        254 <- get :: Get Int32
        return (RDA HNil)


-- | Given the 'A.bounds' of an array, produce a list of how many elements
-- are in each dimension. For example, a 3x2 array produces [3,2].
--
-- A single instance for \"linear\" indices would look like:
--
-- > instance (A.Ix i, Num i) => IxSize i where
-- >     ixSize x = [fromIntegral (A.rangeSize x)]
-- >     fromIxSize [n] = (0, n-1)
--
-- But to avoid overlapping instances, all monomorphic index types likely
-- to be used are just repeated here. 'fromIxSize' produces 0-based indexes
-- for instances of 'Num' ('Word', 'Int', 'Integer'), while 'minBound' is
-- used for other types.
--
-- R supports a dimnames attribute. This could be used but it is not so far.
class A.Ix i => IxSize i where
    -- | sizes of the dimensions described by a pair of array bounds
    ixSize :: (i,i) -> [Int32]
    -- | rebuilds array bounds from the sizes of the dimensions
    fromIxSize :: [Int32] -> (i,i) -- ^ with 0-based indexes

-- Template Haskell splice generating one single-dimension 'IxSize' instance
-- per listed type. Each entry is a triple (lower bound, conversion from
-- 'Int32', type name):
--
-- * for the numeric types the lower bound is @0@ and the upper bound is
--   @fromIntegral (x-1)@
--
-- * for 'Ordering', 'Char', 'Bool' and @()@ the lower bound is 'minBound'
--   and the upper bound is @toEnum (fromIntegral (x-1))@
--
-- In every instance 'ixSize' is the 'A.rangeSize' of the bounds, and
-- 'fromIxSize' is only defined for a one-element list.
fmap concat $ forM
 (map (\n -> ([| 0 |], [| fromIntegral |], n)) [''Word8, ''Word64, ''Word32, ''Word16, ''Word,
 ''Int, ''Int8, ''Int16, ''Int32, ''Int64, ''Integer] ++
  map (\n -> ([| minBound |], [| toEnum . fromIntegral |], n)) [''Ordering, ''Char, ''Bool, ''()]) $ \ (zero, fi, name) ->
    let ty = TH.conT name in
 [d| instance IxSize $ty where
        ixSize x = [fromIntegral (A.rangeSize x)]
        fromIxSize [x] = ($zero , $fi (x-1)) |]

-- | Pair indices: the dimension sizes of the components, in component
-- order. 'fromIxSize' needs exactly two sizes; anything else raises an
-- error that names the dimensions received (instead of an anonymous
-- pattern-match failure).
instance (IxSize a, IxSize b) => IxSize (a,b) where
    ixSize ((a,b),(a',b')) = ixSize (a,a') ++ ixSize (b,b')
    fromIxSize [n1,n2] =
        let (a,a') = fromIxSize [n1]
            (b,b') = fromIxSize [n2]
        in ((a,b),(a',b'))
    fromIxSize dims = error $ "RlangQQ.Binary.fromIxSize: expected 2 dimensions, got " ++ show dims

-- | Triple indices: the dimension sizes of the components, in component
-- order. 'fromIxSize' needs exactly three sizes; anything else raises an
-- error that names the dimensions received (instead of an anonymous
-- pattern-match failure).
instance (IxSize a, IxSize b, IxSize c) => IxSize (a,b,c) where
    ixSize ((a,b,c),(a',b',c')) = ixSize (a,a') ++ ixSize (b,b') ++ ixSize (c,c')
    fromIxSize [n1,n2,n3] =
        let (a,a') = fromIxSize [n1]
            (b,b') = fromIxSize [n2]
            (c,c') = fromIxSize [n3]
        in ((a,b,c),(a',b',c'))
    fromIxSize dims = error $ "RlangQQ.Binary.fromIxSize: expected 3 dimensions, got " ++ show dims

-- | 4-tuple indices: the dimension sizes of the components, in component
-- order. 'fromIxSize' needs exactly four sizes; anything else raises an
-- error that names the dimensions received (instead of an anonymous
-- pattern-match failure).
instance (IxSize a, IxSize b, IxSize c, IxSize d) => IxSize (a,b,c,d) where
    ixSize ((a,b,c,d),(a',b',c',d')) = ixSize (a,a') ++ ixSize (b,b') ++ ixSize (c,c') ++ ixSize (d,d')
    fromIxSize [n1,n2,n3,n4] =
        let (a,a') = fromIxSize [n1]
            (b,b') = fromIxSize [n2]
            (c,c') = fromIxSize [n3]
            (d,d') = fromIxSize [n4]
        in ((a,b,c,d),(a',b',c',d'))
    fromIxSize dims = error $ "RlangQQ.Binary.fromIxSize: expected 4 dimensions, got " ++ show dims

-- | 5-tuple indices: the dimension sizes of the components, in component
-- order. 'fromIxSize' needs exactly five sizes; anything else raises an
-- error that names the dimensions received (instead of an anonymous
-- pattern-match failure).
instance (IxSize a, IxSize b, IxSize c, IxSize d, IxSize e) => IxSize (a,b,c,d,e) where
    ixSize ((a,b,c,d,e),(a',b',c',d',e')) = ixSize (a,a') ++ ixSize (b,b') ++ ixSize (c,c') ++ ixSize (d,d') ++ ixSize (e,e')
    fromIxSize [n1,n2,n3,n4,n5] =
        let (a,a') = fromIxSize [n1]
            (b,b') = fromIxSize [n2]
            (c,c') = fromIxSize [n3]
            (d,d') = fromIxSize [n4]
            (e,e') = fromIxSize [n5]
        in ((a,b,c,d,e),(a',b',c',d',e'))
    fromIxSize dims = error $ "RlangQQ.Binary.fromIxSize: expected 5 dimensions, got " ++ show dims


-- | @"Data.Array".Array@ of 'Double' become arrays in R. The elements are
-- written in 'A.elems' order and the dimensions come from 'ixSize'.
instance (IxSize i) => ToRDS (A.Array i Double) where
    toRDS arr = toRDSArray True nElts putElts dims
      where
        bnds = A.bounds arr
        nElts = fromIntegral (A.rangeSize bnds)
        putElts = mapM_ putDouble (A.elems arr)
        dims = ixSize bnds

-- | Shared array writer: a tag word (526 for doubles, 524 for ints), the
-- element count, the elements, then a list header, the string \"dim\", the
-- dimension sizes as an integer list and the terminator word 254.
toRDSArray :: Bool -- ^ is a double array (otherwise Int)
    -> Int32 -- ^  number of elements
    -> Put -- ^ put the elements
    -> [Int32] -- ^ bounds
    -> Put
toRDSArray isDouble size putElts bnds = do
        put typeTag
        put size
        putElts
        putListHdr
        putString "dim"
        toRDS bnds
        put (254 :: Int32)
  where
    typeTag :: Int32
    typeTag
        | isDouble = 526
        | otherwise = 524


-- | @"Data.Array".Array@ of 'Int32' become integer arrays in R. The
-- elements are written in 'A.elems' order and the dimensions come from
-- 'ixSize'.
instance (IxSize i) => ToRDS (A.Array i Int32) where
    toRDS arr = toRDSArray False nElts putElts dims
      where
        bnds = A.bounds arr
        nElts = fromIntegral (A.rangeSize bnds)
        putElts = mapM_ put (A.elems arr)
        dims = ixSize bnds






-- | note indices become 0-based (see 'IxSize')
--
-- Expects the layout written by 'toRDSArray' with the tag word 526. The
-- product of the stored dimensions must equal the stored element count,
-- otherwise the parser fails; without that check 'A.listArray' could
-- silently build an array with undefined or dropped elements.
instance (IxSize i) => FromRDS (A.Array i Double) where
    fromRDS = do
        (526 :: Int32) <- get
        (nel :: Int32) <- get
        els <- replicateM (fromIntegral nel) getDouble
        getListHdr
        "dim" <- getString
        bds <- fromRDS
        (254 :: Int32) <- get
        -- compared as 'Integer' so that the product cannot overflow
        unless (product (map toInteger bds) == toInteger nel) $
            fail $ "fromRDS: dim attribute " ++ show bds
                ++ " does not match element count " ++ show nel
        return (A.listArray (fromIxSize bds) els)

-- | note indices become 0-based (see 'IxSize')
--
-- Expects the layout written by 'toRDSArray' with the tag word 524. The
-- product of the stored dimensions must equal the stored element count,
-- otherwise the parser fails; without that check 'A.listArray' could
-- silently build an array with undefined or dropped elements.
instance (IxSize i) => FromRDS (A.Array i Int32) where
    fromRDS = do
        (524 :: Int32) <- get
        (nel :: Int32) <- get
        els <- replicateM (fromIntegral nel) get
        getListHdr
        "dim" <- getString
        bds <- fromRDS
        (254 :: Int32) <- get
        -- compared as 'Integer' so that the product cannot overflow
        unless (product (map toInteger bds) == toInteger nel) $
            fail $ "fromRDS: dim attribute " ++ show bds
                ++ " does not match element count " ++ show nel
        return (A.listArray (fromIxSize bds) els)

#if REPA
-- | Shared writer for repa arrays: hands 'toRDSArray' the element count
-- ('R.size' of the extent), an action that writes the elements in linear
-- index order with @putFn@, and the extent as a list ('R.listOfShape').
toRDSRepaArr b putFn arr = toRDSArray b (fromIntegral nel) putElts dims
  where
    sh = R.extent arr
    nel = R.size sh
    putElts = forM_ [0 .. nel - 1] (putFn . R.linearIndex arr)
    dims = map fromIntegral (R.listOfShape sh)

-- | repa arrays of 'Double', written as a double array
instance (R.Source r Double, R.Shape sh) => ToRDS (R.Array r sh Double) where
    toRDS arr = toRDSRepaArr True putDouble arr

-- | repa arrays of 'Int32', written as an integer array
instance (R.Source r Int32, R.Shape sh) => ToRDS (R.Array r sh Int32) where
    toRDS arr = toRDSRepaArr False put arr


-- | Shared reader for unboxed repa arrays, used after the tag word has been
-- consumed: the element count, the elements (read with @getElt@), a list
-- header, the string \"dim\", the dimension sizes and the terminator word
-- 254. The array is built with 'R.fromListUnboxed' from the stored
-- dimensions.
fromRDSRepa getElt = do
        (count :: Int32) <- get
        elts <- replicateM (fromIntegral count) getElt
        getListHdr
        "dim" <- getString
        dims :: [Int32] <- fromRDS
        (254 :: Int32) <- get
        return $ R.fromListUnboxed (R.shapeOfList [fromIntegral d | d <- dims]) elts


-- | repa: unboxed array of 'Double'; expects the tag word 526
instance R.Shape sh => FromRDS (R.Array R.U sh Double) where
    fromRDS = expectTag >> fromRDSRepa getDouble
      where
        expectTag = do
            (526 :: Int32) <- get
            return ()

-- | repa: unboxed array of 'Int32'; expects the tag word 524
instance R.Shape sh => FromRDS (R.Array R.U sh Int32) where
    fromRDS = expectTag >> fromRDSRepa get
      where
        expectTag = do
            (524 :: Int32) <- get
            return ()
#endif


-- | Existential wrapper hiding the type of any value that has a 'ToRDS'
-- instance, so that values of different types can be stored together (for
-- example in a 'M.Map').
data AnyRDS where AnyRDS :: (ToRDS a) => a -> AnyRDS
-- | serializes the wrapped value with its own instance
instance ToRDS AnyRDS where toRDS (AnyRDS x) = toRDS x
-- need a witness, which isn't being passed in just yet,
-- or a parser into some intermediate format
-- instance FromRDS AnyRDS where fromRDS = AnyRDS `fmap` fromRDS

-- | @M.fromList [(\"a\",AnyRDS 1),(\"b\", AnyRDS 2)]@ becomes @list(a=1, b=2)@
--
-- Layout: the tag word 531, the number of entries, the values in key order,
-- a list header, the string \"names\", the keys and the terminator word 254.
instance ToRDS (M.Map String AnyRDS) where
    toRDS entries = do
        let (names, values) = unzip (M.toList entries)
        put (531::Int32)
        put (fromIntegral (M.size entries) :: Int32)
        mapM_ toRDS values
        putListHdr
        putString "names"
        toRDS names
        put (254 :: Int32)


-- | Serializes a record of variables as a gzip-compressed RDA file: the
-- header bytes @RDX2\\nX\\n@, the three words 2, 196609 and 131840, and then
-- the record's values through the 'RDA' instances.
toRDA x = GZip.compress (runPut payload)
  where
    payload = do
        mapM_ put "RDX2\nX\n"
        mapM_ (put :: Int32 -> Put) [2, 196609, 131840]
        toRDS (RDA (recordValues x))


{- $rdaFmt

A typical type would be

> Record '[Tagged l1 (Tagged m1 x), Tagged l2 (Tagged m2 x)]

The outer labels (@l1@, @l2@) are the ones used on the Haskell side. The inner labels
(@m1@, @m2@) are the names of the variables on the R side.


-}


-- | Parses a gzip-compressed RDA file into a record. The header bytes must
-- be @RDX2\\nX\\n@ (otherwise the parser fails with \"wrong header\"); the
-- three words that follow are read but not checked. The result type selects
-- which variables, and of which types, are expected.
fromRDA :: FromRDA r => B.ByteString -> Record r
fromRDA bytes = flip runGet (GZip.decompress bytes) $ do
    let magic = "RDX2\nX\n"
    found <- BS.toString <$> getByteString (BS.length (BS.fromString magic))
    when (magic /= found) $ fail "wrong header"
    [{- 2 , 196609, 131840 -} _, _, _ :: Int32 ] <- replicateM 3 get
    view ( _Wrapping RDA . from unlabeled) <$> fromRDS

-- | Constraints on the record type that 'fromRDA' can produce: its values
-- must be parseable through the 'RDA' instances and convertible back into a
-- labelled record.
type FromRDA r = (Unlabeled' r, FromRDS (RDA (RecordValuesR r)))

-- * tests

-- Template Haskell splice; presumably declares the labels @x@, @abc@ and
-- @lab2@ used by the samples below -- TODO confirm against HList's
-- @makeLabels6@
makeLabels6 (words "x abc lab2")

-- | Sample record with two variables. Each field holds a labelled pair
-- again: the outer label is the Haskell-side one, the inner one the R
-- variable name. Are these records necessarily this noisy?
sampV1 = x .=. newLVPair x [1,2,3,4 :: Double] .*.
        abc .=. newLVPair abc (V.fromList [4 :: Double]) .*.
        emptyRecord

-- | writes 'sampV1' to @/tmp/foo2@
testPut = B.writeFile "/tmp/foo2" $ toRDA sampV1

-- | 'True' when decoding the encoded 'sampV1' and encoding it again gives
-- the same bytes
roundtrip = toRDA ((fromRDA $ toRDA sampV1) `asTypeOf` sampV1) == toRDA sampV1

-- | the values of 'sampV1' re-wrapped as a record
sampleList = Record (recordValues sampV1)

-- | writes a file containing one variable whose value is 'sampleList'
testPut2 = B.writeFile "/tmp/foo3" $ toRDA (abc .=. (abc .=. sampleList) .*. emptyRecord)

-- | sample with 'sampleList' nested under the labels @x@ and @abc@
sampV2 = x .=. x .=. (x .=. sampleList) .*.
            abc .=. (abc .=. sampleList) .*. emptyRecord

-- | sample mixing a list of doubles, a string and an array
sampV3 = x .=. newLVPair x [1,2,3,4 :: Double] .*.
        abc .=. newLVPair abc "hi" .*.
        lab2 .=. newLVPair lab2 sampArr .*.
        emptyRecord

-- | 3x3 array holding 1 to 9
sampArr :: A.Array (Int,Int) Double
sampArr = A.listArray ((0,0),(2,2)) [1 .. 9]

-- | sample map with values of different types
sampMap = M.fromList [("x", AnyRDS [2 :: Double]), ("y", AnyRDS sampArr) ]

-- | writes 'sampV3' to @/tmp/foo3@ and asks R to load it; needs an @R@
-- executable on the path
testPut3 = do
    B.writeFile "/tmp/foo3" $ toRDA sampV3
    readProcess "R" ["--no-save"] "load('/tmp/foo3')"

-- | writes 'sampMap' (as the variable @x@) to @/tmp/foo3@ and asks R to
-- load it; needs an @R@ executable on the path
testPut4 = do
    B.writeFile "/tmp/foo3" $ toRDA (x .=. (x .=. sampMap) .*. emptyRecord)
    readProcess "R" ["--no-save"] "load('/tmp/foo3')"