-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Index
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Redundant version of @Torch.Indef.{Dynamic/Static}.Tensor@ for Index tensors.
--
-- FIXME: in the future, there could be a smaller subset of Torch which could
-- be compiled to keep the code DRY. Alternatively, if backpack one day
-- supports recursive indefinites, we could use this feature to possibly remove
-- this package and 'Torch.Indef.Mask'.
-------------------------------------------------------------------------------
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE MonoLocalBinds #-}
module Torch.Indef.Index
  ( singleton
  , esingleton
  , newIx
  , newIxDyn
  , zeroIxNd
  , index
  , index1d
  , indexNd
  , indexDyn
  , mkCPUIx
  , withCPUIxStorage
  , withIxStorage
  , withDynamicState
  , mkCPUIxStorage
  , ixShape
  , ixCPUStorage
  , showIx
  ) where

import Foreign
import Foreign.Ptr
import Data.Proxy
import Data.List
import Data.Typeable
import Data.Maybe
import Data.Singletons.Prelude.List
import Control.Monad
import Control.Applicative
import System.IO.Unsafe
import Numeric.Dimensions

import Torch.Sig.State as Sig
import Torch.Sig.Types.Global as Sig
import Torch.Indef.Types hiding (withDynamicState, mkDynamic, mkDynamicIO)
import Torch.Indef.Internal
import qualified Torch.Types.TH as TH
import qualified Torch.Types.TH.Long as THLong
import qualified Torch.Sig.Index.Tensor as IxSig
import qualified Torch.Sig.Index.TensorFree as IxSig
import qualified Torch.FFI.TH.Long.Storage as LongStorage
import qualified Torch.FFI.TH.Long.Storage as LongStorage
import qualified Torch.FFI.TH.Long.Tensor as LongTensor
import qualified Foreign as FM
import qualified Foreign.Marshal.Array as FM

-- | Build a one-element static index tensor holding the given integral value.
singleton :: Integral i => i -> IndexTensor '[1]
singleton i = longAsStatic (indexDyn [fromIntegral i])

-- | Like 'singleton', but stores the 'fromEnum' position of an enumerable value.
esingleton :: Enum i => i -> IndexTensor '[1]
esingleton e = singleton (fromEnum e)

-- | Build a new static index tensor whose term-level shape is read off the
-- type-level dimensions @d@.
--
-- FIXME: This can be abstracted away with backpack, but I'm not sure if it can do it atm.
newIx :: forall d . Dimensions d => IndexTensor d
newIx =
  let shape = listDims (dims :: Dims d)
  in longAsStatic (newIxDyn shape)

-- | Build a static index tensor backed by a one-dimensional, zero-length dynamic
-- tensor, whatever the type-level dimensions @d@ claim. The term-level shape
-- therefore does not match @d@.
--
-- FIXME: this is a bad function and should be replaced with a 'Torch.Indef.Static.Tensor.new' function.
zeroIxNd :: Dimensions d => IndexTensor d
zeroIxNd = longAsStatic (newIxDyn [0 :: Integer])

-- | Build a new dynamically-typed index tensor whose shape is given by the list
-- of dimension sizes:
--
--   * @[]@ allocates with @c_new@,
--   * @[x]@, @[x, y]@ and @[x, y, z]@ allocate a 1-, 2- or 3-dimensional tensor,
--   * anything longer raises an 'error' (4+ dimensions are not supported).
--
-- NOTE(review): the allocation happens through 'unsafePerformIO' behind a pure
-- interface, so GHC is free to share the result between call sites that pass
-- equal arguments. Do not mutate the returned tensor in place unless you are
-- sure it cannot be shared.
newIxDyn :: Integral i => [i] -> IndexDynamic
newIxDyn is = unsafePerformIO $ withForeignPtr Sig.torchstate $ \s ->
  case is of
    [] -> IxSig.c_new s >>= mkDynamic
    [x] -> IxSig.c_newWithSize1d s (fromIntegral x) >>= mkDynamic
    [x, y] -> IxSig.c_newWithSize2d s (fromIntegral x) (fromIntegral y) >>= mkDynamic
    [x, y, z] -> IxSig.c_newWithSize3d s (fromIntegral x) (fromIntegral y) (fromIntegral z) >>= mkDynamic
    -- only 0..3 dimensions are handled above, so 4 or more ends up here
    _ -> error "> 3-dimensional indexes not currently supported"
-- keep the unsafePerformIO allocation from being inlined/duplicated at call sites
{-# NOINLINE newIxDyn #-}

-- | Make a dynamic, 1d index tensor from a list, one tensor element per list entry.
--
-- The tensor is allocated directly in 'IO' rather than through the pure
-- 'newIxDyn' because it is written to in place below. A pure allocation may be
-- shared by GHC between call sites with the same argument (e.g. two lists of
-- equal length), and writing into a shared tensor would corrupt the other user.
--
-- FIXME construct this with TH, not with the setting, which might be doing a second linear pass
indexDyn :: [Integer] -> IndexDynamic
indexDyn l = unsafePerformIO $ do
  res <- withForeignPtr Sig.torchstate $ \s ->
    IxSig.c_newWithSize1d s (fromIntegral (length l)) >>= mkDynamic
  mapM_ (upd res) (zip [0 ..] l)
  pure res
  where
    -- Write value @v@ at position @idx@. Both are converted with 'fromIntegral';
    -- presumably the C element type is fixed-width, so out-of-range values
    -- wrap silently -- confirm against the signature.
    upd :: IndexDynamic -> (Int, Integer) -> IO ()
    upd t (idx, v) = withDynamicState t $ \s' t' ->
      IxSig.c_set1d s' t' (fromIntegral idx) (fromIntegral v)
-- each call must allocate its own tensor; do not let GHC inline/duplicate it
{-# NOINLINE indexDyn #-}

-- | Purely make a 1d static index tensor from a list of integers.
--
-- Returns 'Nothing' when the list length differs from the type-level size @n@.
--
-- Should be deprecated in favor of 'index1d'.
index :: forall n . KnownDim n => [Integer] -> Maybe (IndexTensor '[n])
index l
  | genericLength l /= dimVal (dim :: Dim n) = Nothing
  | otherwise = Just (longAsStatic (indexDyn l))

-- | Alias for 'index'. It should subsume 'index' once this package no longer
-- assumes that index tensors are 1d.
index1d :: KnownDim n => [Integer] -> Maybe (IndexTensor '[n])
index1d xs = index xs

-- | n-dimensional version of 'index1d'.
--
-- Returns 'Nothing' unless the list length equals the product of the
-- type-level dimensions @d@.
--
-- FIXME: this relies on 'indexDyn', which only makes 1d tensors, so the
-- term-level shape of the result is always 1d regardless of @d@.
indexNd :: forall d . KnownDim (Product d) => [Integer] -> Maybe (IndexTensor d)
indexNd l
  | genericLength l /= expected = Nothing
  | otherwise = Just (longAsStatic (indexDyn l))
  where
    expected = dimVal (dim :: Dim (Product d))

-- | Convenience method for 'newWithData' specific to longs: build a CPU Long
-- storage holding the given values.
--
-- The element buffer is allocated with 'FM.newArray' (C @malloc@) and must stay
-- valid for as long as the storage lives. The previous 'FM.withArray' version
-- freed the buffer as soon as its continuation returned, which left the storage
-- pointing at freed memory.
--
-- NOTE(review): assumes @THLongStorage_newWithData@ takes ownership of the
-- buffer and releases it with C @free@ (TH's default allocator) when the
-- storage is freed -- confirm against the TH version in use.
ixCPUStorage :: [Integer] -> IO TH.LongStorage
ixCPUStorage pr = withForeignPtr TH.torchstate $ \st -> do
  buf <- FM.newArray (THLong.hs2cReal <$> pr)
  thl <- LongStorage.c_newWithData st buf (fromIntegral (length pr))
  -- attach the storage finalizer and pair it with the global TH state
  mkCPUIxStorage thl

-- | Resize a dynamic index tensor to a single dimension of length @x@ (via
-- @c_resize1d@).
--
-- FIXME: export or remove this function as appropriate.
_resizeDim1d :: IndexDynamic -> Integer -> IO ()
_resizeDim1d t x = withDynamicState t resize
  where
    resize s' t' = IxSig.c_resize1d s' t' (fromIntegral x)

-- | Wrap a raw TH Long tensor pointer as a dynamic CPU tensor. The pointer gets
-- @LongTensor.p_free@ as its finalizer and is paired with the global TH state.
mkCPUIx :: Ptr TH.C'THLongTensor -> IO TH.LongDynamic
mkCPUIx p = do
  fp <- newForeignPtr LongTensor.p_free p
  pure (TH.LongDynamic (TH.torchstate, fp))

-- | Run an action with the raw pointer of a CPU-bound Long storage. The storage
-- is kept alive for the duration of the action.
withCPUIxStorage :: TH.LongStorage -> (Ptr TH.C'THLongStorage -> IO x) -> IO x
withCPUIxStorage ix fn =
  let (_, storageFP) = TH.longStorageState ix
  in withForeignPtr storageFP fn

-- | Run an action with the raw state pointer and raw C tensor pointer behind a
-- dynamic index tensor. Both foreign pointers are kept alive while the action runs.
withDynamicState :: IndexDynamic -> (Ptr Sig.CState -> Ptr Sig.CLongTensor -> IO x) -> IO x
withDynamicState t fn =
  withForeignPtr stateFP $ \sref ->
    withForeignPtr tensorFP $ \tref -> fn sref tref
  where
    (stateFP, tensorFP) = Sig.longDynamicState t

-- | Run an action with the raw C pointer of an index storage. The storage is
-- kept alive for the duration of the action.
withIxStorage :: Sig.IndexStorage -> (Ptr CLongStorage -> IO x) -> IO x
withIxStorage ix fn =
  let (_, storageFP) = Sig.longStorageState ix
  in withForeignPtr storageFP fn

-- | Wrap a raw TH LongStorage pointer as a CPU Long storage. The pointer gets
-- @LongStorage.p_free@ as its finalizer and is paired with the global TH state.
mkCPUIxStorage :: Ptr TH.C'THLongStorage -> IO TH.LongStorage
mkCPUIxStorage p = do
  fp <- newForeignPtr LongStorage.p_free p
  pure (TH.LongStorage (TH.torchstate, fp))

-- | Read the term-level shape of a static index tensor: one entry per
-- dimension, as reported by @c_nDimension@ and @c_size@.
ixShape :: IndexTensor d -> [Word]
ixShape t = unsafePerformIO $
  withDynamicState (longAsDynamic t) $ \s' t' -> do
    nd <- IxSig.c_nDimension s' t'
    forM [0 .. nd - 1] $ \i ->
      fromIntegral <$> IxSig.c_size s' t' (fromIntegral i)

-- | Render an index tensor as a human-readable 'String'. The result is the
-- tensor's values followed by a newline and a bracketed description of the
-- element type and shape.
--
--   * 0-d tensors render no values.
--   * 1-d tensors render as a single bracketed row.
--   * 2-d tensors render one bracketed row per line.
--   * 3-d and 4-d tensors render one 2-d slice at a time, each under a
--     @(i,.,.):@ or @(i,j,.,.):@ header.
--   * Tensors with more than four dimensions render @"Can't print this yet"@
--     in place of their values.
--
-- FIXME: because we are using backpack, we can't declare a 'Show' instance on
-- the IndexTensor both here and in the signatures. If we want this
-- functionality, we must operate on raw code and write the 'Show' instance in
-- hasktorch-types.
showIx :: IndexTensor d -> String
{-# NOINLINE showIx #-}
showIx t = unsafePerformIO $ do
  let ds = fromIntegral <$> ixShape t
  (vs, desc) <- go (ixGet1d t) (ixGet2d t) (ixGet3d t) (ixGet4d t) ds
  pure (vs ++ "\n" ++ desc)
  where
    -- Element readers for 1- to 4-dimensional tensors; each widens the C value to 'Integer'.
    ixGet1d :: IndexTensor d -> Int64 -> IO Integer
    ixGet1d it i = fmap fromIntegral . withDynamicState (longAsDynamic it) $ \s' it' -> IxSig.c_get1d s' it'
      (fromIntegral i)

    ixGet2d :: IndexTensor d -> Int64 -> Int64 -> IO Integer
    ixGet2d it i i1 = fmap fromIntegral . withDynamicState (longAsDynamic it) $ \s' it' -> IxSig.c_get2d s' it'
      (fromIntegral i) (fromIntegral i1)

    ixGet3d :: IndexTensor d -> Int64 -> Int64 -> Int64 -> IO Integer
    ixGet3d it i i1 i2 = fmap fromIntegral . withDynamicState (longAsDynamic it) $ \s' it' -> IxSig.c_get3d s' it'
      (fromIntegral i) (fromIntegral i1) (fromIntegral i2)

    ixGet4d :: IndexTensor d -> Int64 -> Int64 -> Int64 -> Int64 -> IO Integer
    ixGet4d it i i1 i2 i3 = fmap fromIntegral . withDynamicState (longAsDynamic it) $ \s' it' -> IxSig.c_get4d s' it'
      (fromIntegral i) (fromIntegral i1) (fromIntegral i2) (fromIntegral i3)

    -- Render the values for shape @ds@ and pair them with the type/shape description.
    go
      :: forall a . (Typeable a, Ord a, Num a, Show a)
      => (Int64 -> IO a)
      -> (Int64 -> Int64 -> IO a)
      -> (Int64 -> Int64 -> Int64 -> IO a)
      -> (Int64 -> Int64 -> Int64 -> Int64 -> IO a)
      -> [Int64]
      -> IO (String, String)
    go get'1d get'2d get'3d get'4d ds =
      (,desc) <$> case ds of
        []        -> pure ""
        [x]       -> brackets . intercalate "" <$> mapM (fmap valWithSpace . get'1d) (mkIx x)
        [x,y]     -> mat2d "" get'2d x y
        [x,y,z]   -> mat3dGo x y z
        [x,y,z,q] -> mat4dGo x y z q
        _         -> pure "Can't print this yet"
      where
        -- Render an x-by-y matrix read through @getter@, prefixing every row with @fill@.
        mat2d :: String -> (Int64 -> Int64 -> IO a) -> Int64 -> Int64 -> IO String
        mat2d fill getter x y = do
          vs <- mapM (fmap valWithSpace . uncurry getter) (mkXY x y)
          pure (mat2dGo fill y "" vs)

        -- Chunk already-rendered cells into rows of @y@ and join them with newlines.
        mat2dGo :: String -> Int64 -> String -> [String] -> String
        mat2dGo    _ _ acc []  = acc
        mat2dGo fill y acc rcs = mat2dGo fill y acc' rest
          where
            (row, rest) = splitAt (fromIntegral y) rcs
            fullrow = fill ++ brackets (intercalate "" row)
            acc' = if null acc then fullrow else acc ++ "\n" ++ fullrow

        -- Render each 2-d slice of a 3-d tensor under an "(i,.,.):" header.
        mat3dGo :: Int64 -> Int64 -> Int64 -> IO String
        mat3dGo x y z = fmap (intercalate "") $ forM (mkIx x) $ \x' -> do
          mat <- mat2d "  " (get'3d x') y z
          pure $ gt2IxHeader [x'] ++ mat

        -- Render each 2-d slice of a 4-d tensor under an "(i,j,.,.):" header.
        mat4dGo :: Int64 -> Int64 -> Int64 -> Int64 -> IO String
        mat4dGo w q x y = fmap (intercalate "") $ forM (mkXY w q) $ \(w', q') -> do
          mat <- mat2d "  " (get'4d w' q') x y
          pure $ gt2IxHeader [w', q'] ++ mat

        gt2IxHeader :: [Int64] -> String
        gt2IxHeader is = "\n(" ++ intercalate "," (fmap show is) ++",.,.):\n"

        mkIx :: Int64 -> [Int64]
        mkIx x = [0..x - 1]

        mkXY :: Int64 -> Int64 -> [(Int64, Int64)]
        mkXY x y = [ (r, c) | r <- mkIx x, c <- mkIx y ]

        brackets :: String -> String
        brackets s = "[" ++ s ++ "]"

        -- Render one cell: a leading pad that is one space narrower for negative
        -- values (to leave room for the sign), the value (rounded to 6 decimal
        -- places if it is a Double or Float), and a trailing space.
        valWithSpace :: (Typeable a, Ord a, Num a, Show a) => a -> String
        valWithSpace v = spacing ++ value ++ " "
          where
            truncTo :: RealFrac x => Int -> x -> x
            truncTo n f = fromInteger (round $ f * (10^n)) / (10.0^^n)

            value :: String
            value = fromMaybe (show v) $
                  (show . truncTo 6 <$> (cast v :: Maybe Double))
              <|> (show . truncTo 6 <$> (cast v :: Maybe Float))

            spacing = case compare (signum v) 0 of
              LT -> " "
              _  -> "  "

        descType, descShape, desc :: String
        descType = show (typeRep (Proxy :: Proxy a)) ++ " tensor with "
        descShape = "shape: " ++ intercalate "x" (fmap show ds)
        desc = brackets $ descType ++ descShape

-------------------------------------------------------------------------------
-- Helper functions which mimic code from 'Torch.Indef.Types'

-- | Wrap a raw index-tensor pointer as an 'IndexDynamic'. The pointer gets
-- @IxSig.p_free@ as an environment finalizer (the raw state pointer is the
-- environment) and is paired with the global 'Sig.torchstate'.
mkDynamic :: Ptr Sig.CLongTensor -> IO IndexDynamic
mkDynamic t = withForeignPtr Sig.torchstate $ \s -> do
  fp <- newForeignPtrEnv IxSig.p_free s t
  pure (Sig.longDynamic Sig.torchstate fp)