{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}

{- | Conversions between R's RDS/RDA format and Haskell data types.

Tested with R 3.0.1.

Possible improvements:

* Data.Map

* better error reporting when the format is bad?

* more tests
-}

module RlangQQ.Binary (
    -- * functions to serialize many variables
    -- $rdaFmt
    toRDA, fromRDA,

    -- * serializing a single variable
    ToRDS(..), FromRDS(..),

    -- * types / internal
    FromRDA, ToRDSRecord, RDSHLIST, RDA, IxSize(..),

    module Data.HList.CommonMain, ) where

import System.Process
import Unsafe.Coerce
import Control.Applicative
import qualified Data.ByteString.Lazy.UTF8 as B
import qualified Data.ByteString.UTF8 as BS (fromString, toString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as B
import qualified Data.ByteString.Lazy.Char8 as B8

import Data.Int
import Data.HList.CommonMain
import qualified Data.Vector as V
import qualified Data.Map as M

import Data.Binary
import Data.Binary.Get
import Data.Binary.Builder
import Data.Binary.Put
import qualified Data.Binary

import qualified Data.Text as T
import Data.Text.Encoding as E

import Control.Monad.Identity
import qualified Data.Array.Repa as R

import GHC.TypeLits

import qualified Data.Array as A
import qualified Language.Haskell.TH as TH

import qualified Codec.Compression.GZip as GZip

import HListExtras

-- | Function symbol for HList maps: wraps each value in an 'LVPair'. The
-- label is not chosen here; it is fixed by the expected result type @b@.
data FLVPair = FLVPair
instance (LVPair l a ~ b) => ApplyAB FLVPair a b where
    applyAB FLVPair = LVPair

-- this should not be necessary
-- | Type-level inverse of mapping 'FLVPair': strips the 'LVPair' wrapper
-- (and its label) from every element of the list.
-- NOTE(review): not referenced anywhere else in this module.
type family UnHMapFLVPair (a :: [*]) :: [*]
type instance UnHMapFLVPair (LVPair l a ': as) = a ': UnHMapFLVPair as
type instance UnHMapFLVPair '[] = '[]

-- | The 4-byte word that precedes a string's length and bytes. 'getListHdr'
-- expects it as the third word of a pairlist-node header, and the
-- character-vector reader skips one such word per element.
--
-- Despite the name, this is not R's version number (that is written
-- separately by 'toRDA'). 262153 is @0x40009@: in R's serialization flags,
-- SEXP type 9 (CHARSXP) with the ASCII encoding bit (64) in the
-- general-purpose field, which is shifted left by 12. The name is kept
-- because existing code calls it.
putVersion :: Put
putVersion = put (262153 :: Int32)

-- | Read the 4-byte word written by 'putVersion'. The value is returned
-- unchecked.
getVersion :: Get Int32
getVersion = get

-- | same as Binary but should be compatible with R's @saveRDS@
-- binary mode, which is for single objects
--
-- 'toRDS' writes only the serialized object: a type/flags word followed by
-- its payload (e.g. 14 and a length for a numeric vector). It does not
-- write the file header that 'toRDA' adds.
class ToRDS a where
    toRDS :: a -> Put
-- | Inverse of 'ToRDS': parse one serialized object. Instances match the
-- expected type/flags word with patterns such as @14 <- get@, so an
-- unexpected word makes the 'Get' action fail.
class FromRDS a where
    fromRDS :: Get a

-- | Write a 'Double' as 8 big-endian IEEE-754 bytes. This is the element
-- encoding of R numeric vectors in XDR serialization.
--
-- Data.Binary's own 'put' for 'Double' is unsuitable: it writes a
-- mantissa/exponent pair (25 bytes), not the raw IEEE-754 bits.
-- The bit cast used to go through @unsafeCoerce :: Double -> Int64@, which
-- is not one of the coercions "Unsafe.Coerce" documents as safe.
-- 'putDoublebe' (binary >= 0.8.4) does the same conversion safely and
-- keeps NaN payloads, so R's NA_real_ survives.
putDouble :: Double -> Put
putDouble = putDoublebe

-- | Read 8 big-endian IEEE-754 bytes; the inverse of 'putDouble'.
getDouble :: Get Double
getDouble = getDoublebe

-- | becomes a numeric vector: type word 14, the element count as an
-- 'Int32', then every element written with 'putDouble'
instance ToRDS (V.Vector Double) where
    toRDS v = do
        let n = fromIntegral (V.length v) :: Int32
        put (14 :: Int32)
        put n
        V.mapM_ putDouble v
-- | reads the layout written above: type word 14, an 'Int32' count, then
-- that many doubles
instance FromRDS (V.Vector Double) where
    fromRDS = do
        14 <- get :: Get Int32
        n <- get :: Get Int32
        V.replicateM (fromIntegral n) getDouble

-- | integer vector: type word 13, the element count as an 'Int32', then
-- every element as a 4-byte 'Int32'
instance ToRDS (V.Vector Int32) where
    toRDS x = do
        put (13 :: Int32)
        put (fromIntegral $ V.length x :: Int32)
        V.mapM_ put x
-- | reads the layout written above
instance FromRDS (V.Vector Int32) where
    fromRDS = do
        13 <- get :: Get Int32
        -- The count is a 4-byte Int32. Without this annotation its type was
        -- inferred as Int (from the vector length argument), and
        -- Data.Binary's Int instance reads 8 bytes, misaligning the stream.
        len <- get :: Get Int32
        V.replicateM (fromIntegral len) get

-- | character vector @c(\'ab\',\'cd\')@
--
-- Layout: type word 16, the element count, then for each element a
-- CHARSXP header word, the UTF-8 byte length and the bytes. The reader
-- below skips that header word, so the writer must emit it.
-- Pure-ASCII strings use 'putVersion' (the ASCII-flagged header). Strings
-- with any byte >= 128 are flagged UTF-8 instead (32777 = type 9 plus the
-- UTF-8 bit 8 shifted left by 12), so R does not read them in its native
-- encoding.
instance ToRDS (V.Vector T.Text) where
    toRDS strs = do
        put (16 :: Int32)
        put (fromIntegral (V.length strs) :: Int32)
        V.mapM_ putElem strs
      where
        putElem t = do
            let bytes = E.encodeUtf8 t
            if BS.all (< 128) bytes then putVersion else put (32777 :: Int32)
            put (fromIntegral (BS.length bytes) :: Int32)
            putByteString bytes

-- | reads the layout written above and decodes each element as UTF-8
instance FromRDS (V.Vector T.Text) where
    fromRDS = do
        16 <- get :: Get Int32
        nstr <- get :: Get Int32
        V.replicateM (fromIntegral nstr) $ do
            _ <- get :: Get Int32 -- CHARSXP header (encoding flags), not checked
            len <- get :: Get Int32
            E.decodeUtf8 <$> getByteString (fromIntegral len)

-- | @T.pack \"abc\"@ => @c(\'abc\')@, i.e. a character vector of length one
instance ToRDS T.Text where
    toRDS x = toRDS (V.singleton x)

-- | Reads a character vector and returns its first element. An empty vector
-- makes the 'Get' action fail, rather than crashing later with a lazy
-- pattern-match error when the result is forced.
instance FromRDS T.Text where
    fromRDS = do
        strs <- fromRDS
        if V.null strs
            then fail "FromRDS Text: empty character vector"
            else return (V.head strs)

-- | @[\"abc\",\"def\"]@ => @c(\'abc\',\'def\')@
--
-- Converts to and from the 'T.Text' character-vector instances.
instance ToRDS [String] where
    toRDS = toRDS . V.fromList . map T.pack
instance FromRDS [String] where
    fromRDS = fmap (map T.unpack . V.toList) fromRDS

-- | @\"abc\"@ => @c(\'abc\')@, i.e. a character vector of length one
instance ToRDS String where
    toRDS s = toRDS [s]

-- | Reads a character vector and returns its first element. An empty vector
-- makes the 'Get' action fail instead of leaving a partial pattern match.
instance FromRDS String where
    fromRDS = do
        strs <- fromRDS
        case strs of
            (s:_) -> return s
            []    -> fail "FromRDS String: empty character vector"

-- | Function symbol that writes one HList element with 'toRDS'.
data FToRDS = FToRDS
instance (ToRDS a, putm ~ Put) => ApplyAB FToRDS a putm where
    applyAB FToRDS = toRDS

-- | Function symbol that ignores its @()@ argument and yields a 'fromRDS'
-- parser; the element type is fixed by the expected result type.
data FFromRDS = FFromRDS
instance (FromRDS b, Get b ~ getB, a ~ ()) => ApplyAB FFromRDS a getB where
    applyAB FFromRDS = const fromRDS

-- | Constraints for writing an @'LST' xs@:
--
-- * the length of @xs@ must be computable at the type level
--   ('HNat2Integral' of 'HLength');
--
-- * @xs@ must be foldable with @HSeq FToRDS@, which presumably requires a
--   'ToRDS' instance for every element.
--
-- The first parameter, @xs'@, is not used.
type RDSHLIST xs' xs = (HNat2Integral (HLength xs), HFoldr (HSeq FToRDS) Put xs Put)

-- | Writes an R list that carries attributes: header word 531, the number of
-- values, every element of @xs@ via 'toRDS', then the terminating 254.
--
-- The ToRDS Record instance builds @xs@ as the values followed by a
-- 'ListStart' marker and the @names@ attribute. Those last two elements
-- are not list values, so they are excluded from the length.
instance  (RDSHLIST xs' xs) => ToRDS (LST xs) where
    toRDS (LST xs) = do
        put (531::Int32)
        -- Take the length from the type directly. This replaces parsing
        -- 'show' output of the proxy with 'reads' (a partial pattern) and
        -- matches how FromRDS (Record bs') obtains its length.
        put (hNat2Integral (proxy :: Proxy (HLength xs)) - 2 :: Int32)
        hFoldr (HSeq FToRDS) (return () :: Put) xs :: Put
        put (254::Int32)

-- | Constraints needed by the FromRDS (LST l) instance. One 'FFromRDS' is
-- replicated per element of @l@ and mapped with 'FApply' (giving @bs'@).
-- @bs'@ is sequenced in the @((->) ())@ monad (giving @b@), and @b@ is
-- sequenced in 'Get' to produce the elements @l@.
type RDSHLIST2 b bs' l = (HSequence Get b l, HSequence ((->) ()) bs' b,
      HReplicate (HLength l) FFromRDS,
      HMapAux FApply (HReplicateR (HLength l) FFromRDS) bs',
      SameLength (HReplicateR (HLength l) FFromRDS) bs',
      SameLength bs' (HReplicateR (HLength l) FFromRDS),
      HNat2Integral (HLength l))

-- | Reads what ToRDS (LST xs) writes: header 531, the value count, every
-- element of @l@ in order (including the trailing 'ListStart' marker and
-- attribute), then 254.
instance (RDSHLIST2 ___ __ l) => FromRDS (LST l) where
    fromRDS = withSelf $ \(self) -> do
        531 <- get :: Get Int32
        -- ToRDS (LST xs) writes the length minus the 'ListStart' marker and
        -- the names attribute; compare against the same count.
        let len = hNat2Integral (proxy :: Proxy (HLength l)) - 2 :: Int32
        len2 <- get :: Get Int32
        when (len /= len2) $ error $ "fromRDS expected length: " ++ show len ++ " rds file has length: " ++ show len2
        r <- hSequence (hSequence (hMap FApply $ hReplicate (hLength self) FFromRDS) ())
        254 <- get :: Get Int32
        return (LST r)
      where
        -- gives the body a (never evaluated) HList whose type is the
        -- result's element list, so hLength can be taken from it
        withSelf :: forall (a::[*]) m. (HList a -> m (LST a)) -> m (LST a)
        withSelf x = x (error "RlangQQ.Binary.LST.self")

-- | not to be constructed outside here
--
-- An R list, written with header word 531, whose elements are the members
-- of the wrapped 'HList'. The ToRDS Record instance appends a 'ListStart'
-- marker and the @names@ attribute as the last two members.
newtype LST (a :: [*]) = LST (HList a)

-- | @lab .=. val .*. 'emptyRecord'@ => @list(lab= ('toRDS' val) )@.
-- The type variables with underscores should be hidden
--
-- The record's values become the list elements. They are followed by a
-- 'ListStart' marker and an LVPair labelled @names@ that holds the
-- record's labels (from 'recordLabelsString'); everything is written as
-- one 'LST'.
instance ToRDSRecord __ ___ xs => ToRDS (Record xs) where
    toRDS (Record xs) = toRDS $ LST $
                recordValues (Record xs) `hAppend`
                (ListStart `HCons` (Label :: Label "names") .=. (recordLabelsString (error "recLabs" :: Proxy (RecordLabels xs))) `HCons` HNil)

-- | Constraints for ToRDS (Record xs). @___@ is the element list of the
-- 'LST' that is written: the record values followed by 'ListStart' and the
-- @names@ LVPair. @__@ only fills RDSHLIST's unused first parameter.
type ToRDSRecord __ ___ xs = (RDSHLIST __ ___, ToRDS (LST ___),
            RecordValues xs,
            HList ___ ~ (HList (RecordValuesR xs) `HAppendR` HList '[ListStart, LVPair "names" [String]]),
            RecordLabelsString (RecordLabels xs),
            HAppend (HList (RecordValuesR xs)) (HList '[ ListStart, LVPair "names" [String]]))

-- | this signature shouldn't be necessary... it just repeats all the
-- functions called
--
-- Type parameters:
--
-- * @bs'@: the record's fields;
--
-- * @as'2@: one 'FFromRDS' per field (from 'HReplicate2');
--
-- * @a@: @as'2@ mapped with 'FApply';
--
-- * @b@: @a@ sequenced in the @((->) ())@ monad;
--
-- * @as'@: the values produced by sequencing @b@ in 'Get', which are then
--   wrapped by 'FLVPair' into @bs'@.
type FromRDSRec a b as' as'2 bs' = (HSequence Get b as',
                HSequence ((->) ()) a b,
                -- HMap1 FLVPair as' bs',
                HMapAux FLVPair as' bs',
                SameLength as' bs',
                SameLength bs' as',

                HMapAux FApply as'2 a,
                SameLength as'2 a,
                SameLength a as'2,

                HMapOut (HComp FShowLabel FLabelLVPair) bs' String,
                RecordLabelsString (RecordLabels bs'),
                HNat2Integral (HLength bs'),
                HReplicate2 (HLength bs') FFromRDS as'2)

-- | Reads an R list with a @names@ attribute into a record. The layout is:
-- header 531, the field count, one value per field, the attribute
-- pairlist node tagged @names@ (whose character vector must equal the
-- record's labels), then 254.
instance FromRDSRec a b as' as'2 bs' => FromRDS (Record bs') where
    fromRDS = do
        531 <- get :: Get Int32
        let len = hNat2Integral (proxy :: Proxy (HLength bs'))
        len2 <- get :: Get Int32
        when (len /= len2) $ error $ "fromRDS expected length: " ++ show len ++ " rds file has length: " ++ show len2

        -- One FFromRDS per field, mapped with FApply (as'2 -> a). The inner
        -- hSequence runs in the ((->) ()) monad (a -> b) and is applied to
        -- (); the outer one runs the resulting Get actions (b -> as'). These
        -- are the two steps the FromRDSRec constraints describe, done the
        -- same way as the FromRDS (LST l) instance; a single hSequence
        -- does not fit them.
        r <- hSequence
                (hSequence
                    (hMap FApply
                        (hReplicate2 (proxy :: Proxy (HLength bs')) FFromRDS))
                    ())

        -- the names attribute is a tagged pairlist node; consume its header
        -- before the tag's text
        getListHdr
        "names" <- getString
        names :: [String] <- fromRDS

        let names' = recordLabelsString (error "recLabs" :: Proxy (RecordLabels bs'))

        unless (names == names') $ error $ "fromRDS expected names( ): " ++ show names'
                    ++ " rds file has names attribute : " ++ show names

        254 <- get :: Get Int32
        return (Record (hMap FLVPair (r :: HList as')))

-- | Turns a type-level list of labels into the list of their printed names,
-- in the same order. The 'Proxy' argument is never inspected.
class RecordLabelsString (a :: [Symbol]) where
    recordLabelsString :: Proxy a -> [String]

instance RecordLabelsString '[] where
    recordLabelsString = const []

instance (ShowLabel x, RecordLabelsString xs)
        => RecordLabelsString (x ': xs) where
    recordLabelsString _ = firstLabel : remainingLabels
      where
        firstLabel = showLabel (Label :: Label x)
        remainingLabels = recordLabelsString (proxy :: Proxy xs)

-- same as recordLabelsString, but less lazy: the label names are read off
-- the record's fields (each LVPair's label, then shown) instead of coming
-- from a type-level label list
recLabs (Record fields) = hMapOut showFieldLabel fields
  where showFieldLabel = HComp FShowLabel FLabelLVPair

-- | Function symbol mapping an 'LVPair' to its 'Label'.
data FLabelLVPair = FLabelLVPair
instance(LVPair l a ~ x, y ~ Label l) => ApplyAB FLabelLVPair x y where
    applyAB FLabelLVPair pair = labelLVPair pair

-- | Function symbol mapping a 'Label' to its printed name.
data FShowLabel = FShowLabel
instance (string ~ String, ShowLabel l, ll ~ Label l) => ApplyAB FShowLabel ll string
    where applyAB _ lbl = showLabel lbl

-- | Haskell lists of doubles go through the numeric-vector instances.
instance ToRDS [Double] where
    toRDS xs = toRDS (V.fromList xs)

instance FromRDS [Double] where
    fromRDS = V.toList <$> fromRDS

-- | A single 'Double' or 'Int32' is written as a vector of length one.
instance ToRDS Double where toRDS d = toRDS [d]
instance ToRDS Int32  where toRDS i = toRDS [i]

-- | Reading a scalar requires a vector with exactly one element.
instance FromRDS Int32 where
    fromRDS = do
        [x] <- fromRDS
        return x

instance FromRDS Double where
    fromRDS = do
        [x] <- fromRDS
        return x

-- | Haskell lists of 'Int32' go through the integer-vector instances.
instance ToRDS [Int32] where
    toRDS xs = toRDS (V.fromList xs)
instance FromRDS [Int32] where
    fromRDS = V.toList <$> fromRDS

-- | Write a string as a 4-byte big-endian 'Int32' byte count followed by its
-- UTF-8 bytes, with no header word and no terminator.
putString :: String -> Put
putString s = do
        let bytes = BS.fromString s
        put (fromIntegral $ BS.length bytes :: Int32)
        putByteString bytes

-- | Inverse of 'putString': read a 4-byte count, then that many bytes,
-- decoded as UTF-8.
getString :: Get String
getString = do
        len <- get :: Get Int32
        bytes <- getByteString (fromIntegral len)
        return (BS.toString bytes)

-- | Read a string with 'getString' and fail unless it equals @s@.
-- Failing in 'Get', rather than calling 'error', keeps the mismatch
-- catchable by decoders such as @runGetOrFail@.
confirmString :: String -> Get ()
confirmString s = do
        s' <- getString
        unless (s' == s) $ fail $ "expected "++ s ++ ", got " ++ s'

-- | Marker element that separates a list's values from its attribute
-- pairlist (see the ToRDS Record instance). It carries no data of its own.
-- Writing it emits the attribute node header with 'putListHdr'; reading it
-- must consume the same header with 'getListHdr', so that the following
-- @names@ tag is read from the right position.
data ListStart = ListStart
instance ToRDS ListStart where
    toRDS _ = putListHdr
instance FromRDS ListStart where
    fromRDS = ListStart <$ getListHdr

-- | Header of a tagged pairlist node, as used for an attribute (such as
-- @names@) or for a variable in an @.rda@ file:
--
-- * 1026: LISTSXP (2) with the has-tag flag (1024);
--
-- * 1: SYMSXP, the type of the tag;
--
-- * 'putVersion': the CHARSXP header of the tag's print name.
--
-- The print name itself ('putString') follows. 'getListHdr' reads all three
-- words, so the writer must emit all three; it previously stopped after two.
--
-- NOTE(review): R writes a symbol it has already serialized as a back
-- reference instead of SYMSXP, which 'getListHdr' does not handle.
putListHdr :: Put
putListHdr = do
        put (1026 :: Int32)
        put (1 :: Int32)
        putVersion

-- | Read and check the header written by 'putListHdr'. The third word may be
-- any CHARSXP header (low byte 9), whatever its encoding flags, so tags
-- marked as UTF-8 or native encoding are accepted as well as ASCII ones.
getListHdr :: Get ()
getListHdr = do
        ws <- replicateM 3 (get :: Get Int32)
        case ws of
            [1026, 1, charHdr] | charHdr `mod` 256 == 9 -> return ()
            _ -> fail ("getListHdr: unexpected pairlist node header " ++ show ws)

-- | probably internal
--
-- Writes the label's name with 'putString', then the value.
instance forall t l. (ToRDS t, ShowLabel l) => ToRDS (LVPair l t) where
    toRDS (LVPair v) = putString labelName >> toRDS v
      where labelName = showLabel (undefined :: Label l)
-- | probably internal
--
-- Reads a name with 'getString', fails unless it is this pair's label, then
-- reads the value.
instance forall t l. (FromRDS t, ShowLabel l) => FromRDS (LVPair l t) where
    fromRDS = do
        found <- getString

        let expected = showLabel (undefined :: Label l)
        when (found /= expected) $
            fail $ unwords ["FromRDS: expecting label  `", expected , "', but got: `" , found , "'"]

        LVPair <$> fromRDS

-- | labels are stored with the variables here. compare with the instance for 'Record' / 'LST' which collects
-- the labels and saves them as an attribute called \"names\"
--
-- Wraps the list of (R-side name, value) 'LVPair's that make up an @.rda@
-- file. The RDA instances below write them one after another and finish
-- with 254.
newtype RDA a = RDA (HList a)

-- | internal
--
-- In an @.rda@ file every variable is one tagged pairlist node. The node
-- header is LISTSXP with the has-tag flag (1026), SYMSXP (1) for the tag,
-- and the CHARSXP header of the tag's print name ('putVersion'); these are
-- the three words 'getListHdr' reads. The 'LVPair' instance then writes
-- the name and the value.
instance forall rs l2 t. (ToRDS t, ToRDS (RDA rs), ShowLabel l2) => ToRDS (RDA (LVPair l2 t ': rs)) where
    toRDS (RDA (x `HCons` xs)) = do
        put (1026 :: Int32)
        put (1 :: Int32)
        putVersion
        toRDS x
        toRDS (RDA xs)

-- | internal: consumes the node header written above, then the pair
instance forall rs l2 t. (FromRDS t, FromRDS (RDA rs), ShowLabel l2) => FromRDS (RDA (LVPair l2 t ': rs)) where
    fromRDS = do
        getListHdr
        x <- fromRDS :: Get (LVPair l2 t)
        RDA xs <- fromRDS :: Get (RDA rs)
        return (RDA (x `HCons` xs))

-- | internal: 254 terminates the pairlist of variables
instance ToRDS (RDA '[]) where
    toRDS _ = put (254::Int32)

-- | internal
instance FromRDS (RDA '[]) where
    fromRDS = do
        254 <- get :: Get Int32
        return (RDA HNil)

-- | given 'A.bounds' of an array, produce a list of how many elements
-- are in each dimension. For example, a 3x2 array produces [3,2].
-- A single instance for \"linear\" indices would look like:
--
-- > instance (A.Ix i, Num i) => IxSize i where
-- >     ixSize x = [fromIntegral (A.rangeSize x)]
-- >     fromIxSize [n] = (0, n-1)
--
-- But to avoid overlapping instances all monomorphic index types likely
-- to be used are just repeated here. fromIxSize produces 0-based indexes
-- for instances of 'Num' ('Word', 'Int', 'Integer'), while 'minBound' is
-- used for other types.
--
-- 'fromIxSize' goes the other way: from per-dimension counts (as read from
-- R's @dim@ attribute) back to bounds. The instances in this module accept
-- only a list with exactly one count per dimension; any other length is a
-- pattern-match failure.
--
-- R supports a dimnames attribute. This could be used but it is not so far.
class A.Ix i => IxSize i where
    ixSize :: (i,i) -> [Int32]
    fromIxSize :: [Int32] -> (i,i) -- ^ with 0-based indexes

-- Generate the one-dimensional 'IxSize' instances. The integral types get
-- bounds (0, n-1) through fromIntegral. Ordering, Char, Bool and () get
-- (minBound, toEnum (n-1)). ixSize is the same for all of them: the number
-- of elements in the range.
fmap concat $ forM
 (map (\n -> ([| 0 |], [| fromIntegral |], n)) [''Word8, ''Word64, ''Word32, ''Word16, ''Word,
 ''Int, ''Int8, ''Int16, ''Int32, ''Int64, ''Integer] ++
  map (\n -> ([| minBound |], [| toEnum . fromIntegral |], n)) [''Ordering, ''Char, ''Bool, ''()]) $ \ (zero, fi, name) ->
    let ty = TH.conT name in
 [d| instance IxSize $ty where
        ixSize x = [fromIntegral (A.rangeSize x)]
        fromIxSize [x] = ($zero , $fi (x-1)) |]

-- | Tuple indices: 'ixSize' concatenates the per-component counts, and
-- 'fromIxSize' splits the counts back into one bounds pair per component.
instance (IxSize a, IxSize b) => IxSize (a,b) where
    ixSize ((a,b),(a',b')) = concat [ixSize (a,a'), ixSize (b,b')]
    fromIxSize [n1,n2] = ((lo1,lo2),(hi1,hi2))
      where
        (lo1,hi1) = fromIxSize [n1]
        (lo2,hi2) = fromIxSize [n2]

instance (IxSize a, IxSize b, IxSize c) => IxSize (a,b,c) where
    ixSize ((a,b,c),(a',b',c')) = concat [ixSize (a,a'), ixSize (b,b'), ixSize (c,c')]
    fromIxSize [n1,n2,n3] = ((lo1,lo2,lo3),(hi1,hi2,hi3))
      where
        (lo1,hi1) = fromIxSize [n1]
        (lo2,hi2) = fromIxSize [n2]
        (lo3,hi3) = fromIxSize [n3]

instance (IxSize a, IxSize b, IxSize c, IxSize d) => IxSize (a,b,c,d) where
    ixSize ((a,b,c,d),(a',b',c',d')) =
        concat [ixSize (a,a'), ixSize (b,b'), ixSize (c,c'), ixSize (d,d')]
    fromIxSize [n1,n2,n3,n4] = ((lo1,lo2,lo3,lo4),(hi1,hi2,hi3,hi4))
      where
        (lo1,hi1) = fromIxSize [n1]
        (lo2,hi2) = fromIxSize [n2]
        (lo3,hi3) = fromIxSize [n3]
        (lo4,hi4) = fromIxSize [n4]

instance (IxSize a, IxSize b, IxSize c, IxSize d, IxSize e) => IxSize (a,b,c,d,e) where
    ixSize ((a,b,c,d,e),(a',b',c',d',e')) =
        concat [ixSize (a,a'), ixSize (b,b'), ixSize (c,c'), ixSize (d,d'), ixSize (e,e')]
    fromIxSize [n1,n2,n3,n4,n5] = ((lo1,lo2,lo3,lo4,lo5),(hi1,hi2,hi3,hi4,hi5))
      where
        (lo1,hi1) = fromIxSize [n1]
        (lo2,hi2) = fromIxSize [n2]
        (lo3,hi3) = fromIxSize [n3]
        (lo4,hi4) = fromIxSize [n4]
        (lo5,hi5) = fromIxSize [n5]

-- | @"Data.Array".Array@ become arrays in R
--
-- NOTE(review): 'A.elems' lists elements in 'A.range' order, where the last
-- tuple component varies fastest. R stores arrays with the first index
-- varying fastest, so a multi-dimensional array appears permuted on the R
-- side (transposed for 2-D). The FromRDS instances below read elements in
-- the same order, so Haskell round trips are unaffected. Confirm whether R
-- users expect this.
instance (IxSize i) => ToRDS (A.Array i Double) where
    toRDS arr = toRDSArray True
        (fromIntegral (A.rangeSize (A.bounds arr)))
        (mapM_ putDouble (A.elems arr))
        (ixSize (A.bounds arr))

-- | Write an R vector that carries a @dim@ attribute.
toRDSArray :: Bool -- ^ is a double array (otherwise Int)
    -> Int32 -- ^  number of elements
    -> Put -- ^ put the elements
    -> [Int32] -- ^ bounds
    -> Put
toRDSArray isDouble size putElts bnds = do
        -- REALSXP (14) or INTSXP (13), each with the has-attributes flag (512)
        put (if isDouble then 526 else 524 :: Int32)
        put size
        putElts
        -- attribute pairlist node tagged "dim": LISTSXP with the has-tag flag
        -- (1026), SYMSXP (1), then the CHARSXP header and text of the tag;
        -- these are the three words getListHdr reads before the tag text
        put (1026 :: Int32)
        put (1 :: Int32)
        putVersion
        putString "dim"
        toRDS bnds
        -- end of the attribute pairlist
        put (254 :: Int32)

instance (IxSize i) => ToRDS (A.Array i Int32) where
    toRDS arr = toRDSArray False
        (fromIntegral (A.rangeSize (A.bounds arr)))
        (mapM_ put (A.elems arr))
        (ixSize (A.bounds arr))

-- | note indices become 0-based (see 'IxSize')
instance (IxSize i) => FromRDS (A.Array i Double) where
    fromRDS = do
        (526 :: Int32) <- get
        (nel :: Int32) <- get
        els <- replicateM (fromIntegral nel) getDouble
        getListHdr
        "dim" <- getString
        bds <- fromRDS
        (254 :: Int32) <- get
        return (A.listArray (fromIxSize bds) els)

-- | note indices become 0-based (see 'IxSize')
instance (IxSize i) => FromRDS (A.Array i Int32) where
    fromRDS = do
        (524 :: Int32) <- get
        (nel :: Int32) <- get
        els <- replicateM (fromIntegral nel) get
        getListHdr
        "dim" <- getString
        bds <- fromRDS
        (254 :: Int32) <- get
        return (A.listArray (fromIxSize bds) els)

-- | Shared body of the repa 'ToRDS' instances. Passes 'toRDSArray' the
-- element count, an action that writes every element in linear-index order
-- with @putFn@, and the dimensions as given by 'R.listOfShape'.
-- @b@ selects the R type ('True' for double, 'False' for integer).
toRDSRepaArr b putFn arr =
    let nel = R.size (R.extent arr) in
        toRDSArray b
            (fromIntegral nel)
            (forM_ [0 .. nel - 1] (putFn . R.linearIndex arr))
            (map fromIntegral (R.listOfShape (R.extent arr)))

-- | repa
instance (R.Source r Double, R.Shape sh) => ToRDS (R.Array r sh Double) where
    toRDS = toRDSRepaArr True putDouble

-- | repa
instance (R.Source r Int32, R.Shape sh) => ToRDS (R.Array r sh Int32) where
    toRDS = toRDSRepaArr False put

-- | Shared body of the repa 'FromRDS' instances, run after the caller has
-- matched the type word. It reads the element count, the elements (with
-- @getElt@), the @dim@ attribute (pairlist node header, tag, integer
-- vector) and the terminating 254.
fromRDSRepa getElt = do
        (nel :: Int32) <- get
        els <- replicateM (fromIntegral nel) getElt
        getListHdr
        "dim" <- getString
        bds :: [Int32] <- fromRDS
        (254 :: Int32) <- get
        return (R.fromListUnboxed (R.shapeOfList (map fromIntegral bds)) els)

-- | repa
instance R.Shape sh => FromRDS (R.Array R.U sh Double) where
    fromRDS = do
        (526 :: Int32) <- get
        fromRDSRepa getDouble

-- | repa
instance R.Shape sh => FromRDS (R.Array R.U sh Int32) where
    fromRDS = do
        (524 :: Int32) <- get
        fromRDSRepa get

-- | Existential wrapper that holds any value with a 'ToRDS' instance, so
-- values of different types can share one container (for example the
-- @M.Map String AnyRDS@ instance). Writing an 'AnyRDS' writes the wrapped
-- value.
--
-- NOTE(review): this GADT-style declaration needs GADTSyntax and
-- ExistentialQuantification (or GADTs). Neither is in this file's LANGUAGE
-- pragmas; confirm the build configuration enables them.
data AnyRDS where AnyRDS :: (ToRDS a) => a -> AnyRDS
instance ToRDS AnyRDS where toRDS (AnyRDS x) = toRDS x
-- need a witness, which isn't being passed in just yet,
-- or a parser into some intermediate format
-- instance FromRDS AnyRDS where fromRDS = AnyRDS `fmap` fromRDS

-- | @M.fromList [(\"a\",AnyRDS 1),(\"b\", AnyRDS 2)]@ becomes @list(a=1, b=2)@
--
-- Written as an R list with a @names@ attribute, in this order: header 531,
-- the number of entries, each value, the attribute pairlist node tagged
-- @names@ holding the keys, and the terminating 254. Values and keys are
-- both emitted in ascending key order ('M.elems' / 'M.keys').
instance ToRDS (M.Map String AnyRDS) where
    toRDS xs = do
        put (531::Int32)
        put (fromIntegral (M.size xs) :: Int32)
        mapM_ toRDS (M.elems xs)
        -- attribute pairlist node header: LISTSXP with the has-tag flag
        -- (1026), SYMSXP (1), and the tag name's CHARSXP header; these are
        -- the three words getListHdr reads before the tag text
        put (1026 :: Int32)
        put (1 :: Int32)
        putVersion
        putString "names"
        toRDS (M.keys xs)
        put (254 :: Int32)

-- | Serialize a record into the gzip-compressed bytes of an R @.rda@ file.
-- The record's values (inner LVPairs named on the R side) become the
-- variables.
--
-- Header: the magic text RDX2 and X, each followed by a newline, then three
-- 4-byte words: 2, 196609 (= 3*65536 + 0*256 + 1, presumably the writing R
-- version 3.0.1) and 131840 (2.3.0 in the same encoding, presumably the
-- minimum R version able to read the file).
toRDA x = GZip.compress (runPut rdaBytes)
  where
    rdaBytes = do
        putLazyByteString (B8.pack "RDX2\nX\n")
        mapM_ (put :: Int32 -> Put) [2, 196609, 131840]
        toRDS (RDA (recordValues x))

{- $rdaFmt

A typical type would be

> Record '[LVPair l1 (LVPair m1 x), LVPair l2 (LVPair m2 x)]

The outer labels (@l1@, @l2@) are used on the Haskell side. The inner labels
(@m1@, @m2@) are the names of the variables on the R side.
-}

-- | Decode the gzip-compressed bytes of an R @.rda@ file into a record.
-- The magic header is checked. The three header words after it (written
-- as 2, 196609, 131840 by 'toRDA') are read but ignored.
--
-- why can't the __ type be inferred?
fromRDA :: forall __ r.  FromRDA __ r => B.ByteString -> Record r
fromRDA x = runGet rdaParser (GZip.decompress x)
  where
    rdaParser = do
        let hdr =  "RDX2\nX\n"
        hdr' <- BS.toString <$> getByteString (BS.length (BS.fromString hdr))
        unless (hdr == hdr') $ fail "wrong header"
        [{- 2 , 196609, 131840 -} _, _, _ :: Int32 ] <- replicateM 3 get
        RDA vars <- fromRDS
        return (Record (hMap1 FLVPair (vars :: HList __)))

-- | Constraints for 'fromRDA'. @a@ is the list of inner (R-side) LVPairs
-- read from the file; @r@ holds the record fields after each is wrapped by
-- 'FLVPair'.
type FromRDA a r = (HMap1 FLVPair a r, FromRDS (RDA a))

-- * tests

-- Template Haskell splice defining the labels used by the samples below:
-- x, abc and lab2.
makeLabels6 (words "x abc lab2")

-- | Sample record for an RDA file: R variable @x@ (a numeric vector of
-- length 4) and @abc@ (a numeric vector of length 1). It was missing the
-- terminating 'emptyRecord' after the last @.*.@.
--
-- are these records necessarily this noisy?
sampV1 = x .=. newLVPair x [1,2,3,4 :: Double] .*.
        abc .=. newLVPair abc (V.fromList [4 :: Double]) .*.
        emptyRecord

-- | write sampV1 to /tmp/foo2
testPut = B.writeFile "/tmp/foo2" $ toRDA sampV1

-- | writing sampV1, reading it back and writing it again gives the same bytes
roundtrip = toRDA ((fromRDA $ toRDA sampV1) `asTypeOf` sampV1) == toRDA sampV1

-- | the R variables of sampV1 as a record, written as an R list with names
sampleList = Record (recordValues sampV1)

testPut2 = B.writeFile "/tmp/foo3" $ toRDA (abc .=. (abc .=. sampleList) .*. emptyRecord)

-- NOTE(review): the x entry applies three labels (x .=. x .=. (x .=. ...))
-- while the abc entry applies two; confirm the extra nesting is intended.
sampV2 = x .=. x .=. (x .=. sampleList) .*.
            abc .=. (abc .=. sampleList) .*. emptyRecord

-- | Sample RDA record with a numeric vector (@x@), a character vector
-- (@abc@) and a 3x3 array (@lab2@). It was missing the terminating
-- 'emptyRecord' after the last @.*.@.
sampV3 = x .=. newLVPair x [1,2,3,4 :: Double] .*.
        abc .=. newLVPair abc "hi" .*.
        lab2 .=. newLVPair lab2 sampArr .*.
        emptyRecord

-- | 3x3 array holding 1 .. 9, indexed from (0,0)
sampArr :: A.Array (Int,Int) Double
sampArr = A.listArray ((0,0),(2,2)) [1 .. 9]

-- | heterogeneous map, written as @list(x=2, y=<sampArr>)@ via 'AnyRDS'
sampMap = M.fromList [("x", AnyRDS [2 :: Double]), ("y", AnyRDS sampArr) ]

-- | write sampV3 and let R load it (needs an @R@ executable on the PATH)
testPut3 = do
    B.writeFile "/tmp/foo3" $ toRDA sampV3
    readProcess "R" ["--no-save"] "load('/tmp/foo3')"

-- | write sampMap inside a record and let R load it
testPut4 = do
    B.writeFile "/tmp/foo3" $ toRDA (x .=. (x .=. sampMap) .*. emptyRecord)
    readProcess "R" ["--no-save"] "load('/tmp/foo3')"