-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Indef.Dynamic.Print
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Helper functions to render n-rank tensors
-------------------------------------------------------------------------------
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Torch.Indef.Dynamic.Print
  ( showTensor
  , describeTensor
  ) where

import Control.Applicative
import Control.Exception.Safe
import Control.Monad
import Data.List (intercalate)
import Data.List.NonEmpty (NonEmpty(..))
import Data.Maybe
import Data.Typeable
import GHC.Int
import GHC.Word
import Text.Printf

import qualified Data.List.NonEmpty as NE

import Torch.Indef.Types

-- | Generic way of showing the internal data of a tensor in a tabular format.
-- This makes no assumptions about the type of representation to show and can be
-- used for 'Storage', 'Dynamic', and 'Tensor' types.
showTensor
  :: forall a ix . (Typeable a, Ord a, Num a, Show a, Integral ix, Show ix)
  => (ix -> a)
  -> (ix -> ix -> a)
  -> (ix -> ix -> ix -> a)
  -> (ix -> ix -> ix -> ix -> a)
  -> [ix]
  -> String
showTensor get'1d get'2d get'3d get'4d ds =
  case ds of
    []  -> ""
    [x] -> brackets . intercalate "" $ fmap (valWithSpace . get'1d) (mkIx x)
    [x,y] -> go "" get'2d x y

    [z,x,y]   -> concat . flip fmap (mkIx z) $ \z' -> gt2IxHeader [z'] ++ (go "  " (get'3d z') x y)
    -- [x,y,z]   -> mat3dGo x y z
    [w,q,x,y] -> concat . flip fmap (mkXY w q) $ \(w', q') -> gt2IxHeader [w', q'] ++ (go "  " (get'4d w' q') x y)
    -- [x,y,z,q] -> mat4dGo x y z q

    _ -> error "Can't print this yet"
 where
  go :: String -> (ix -> ix -> a) -> ix -> ix -> String
  go fill getter x y = mat2dGo fill y "" $ fmap (valWithSpace . uncurry getter) (mkXY x y)

  mat2dGo :: String -> ix -> String -> [String] -> String
  mat2dGo    _ _ acc []  = acc
  mat2dGo fill y acc rcs = mat2dGo fill y acc' rest
    where
      (row, rest) = splitAt (fromIntegral y) rcs
      fullrow = fill ++ brackets (intercalate "" row)
      acc' = if null acc then fullrow else acc ++ "\n" ++ fullrow

  mat3dGo :: ix -> ix -> ix -> String
  mat3dGo x y z = concat $ flip fmap (mkIx x) $ \x' ->
    let mat = go "  " (get'3d x') y z
    in gt2IxHeader [x'] ++ mat

  mat4dGo :: ix -> ix -> ix -> ix -> String
  mat4dGo w q x y = concat $ flip fmap (mkXY w q) $ \(w', q') ->
    let mat = go "  " (get'4d w' q') x y
    in gt2IxHeader [w', q'] ++ mat


  mkIx :: ix -> [ix]
  mkIx x = [0..x - 1]

  mkXY :: ix -> ix -> [(ix, ix)]
  mkXY x y = [ (r, c) | r <- mkIx x, c <- mkIx y ]

  brackets :: String -> String
  brackets s = "[" ++ s ++ "]"

  valWithSpace :: (Typeable a, Ord a, Num a, Show a) => a -> String
  valWithSpace v = spacing ++ value ++ ""
   where
     truncTo :: (RealFrac x, Fractional x) => Int -> x -> x
     truncTo n f = fromInteger (round $ f * (10^n)) / (10.0^^n)

     value :: String
     value = fromMaybe (show v) $
           (printf "%.8f" <$> (cast v :: Maybe Double))
       <|> (printf "%.4f" <$> (cast v :: Maybe Float))

     spacing = magspacing ++ signspacing
     magspacing = ""
     -- magspacing = case compare (v `mod` 10) 4 of
     --   LT -> replicate (v `mod` 10)
     signspacing = case compare (signum v) 0 of
        LT -> " "
        _  -> "  "

  gt2IxHeader :: Show ix => [ix] -> String
  gt2IxHeader is = "\n(" ++ intercalate "," (fmap show is) ++ ",.,.):\n"

-- | show the shape of a tensor
describeTensor
  :: forall t dims
  . (Typeable t, Show dims)
  => [dims]
  -> Proxy t
  -> String
describeTensor ds t =
  "[" ++ show (typeRep t) ++ " tensor with shape: " ++ intercalate "x" (fmap show ds) ++ "]"

data TenSlices
  = TenNone
  | TenVector (NonEmpty HsReal)
  | TenMatricies (NonEmpty (NonEmpty [HsReal]))

-- | Helper function to show the matrix slices from a tensor.
tensorSlices
  :: Dynamic
  -> (Int64 -> IO HsReal)
  -> (Int64 -> Int64 -> IO HsReal)
  -- -> (Int64 -> Int64 -> Int64 -> IO HsReal)
  -- -> (Int64 -> Int64 -> Int64 -> Int64 -> IO HsReal)
  -> [Word64]
  -> IO TenSlices
tensorSlices t get'1d get'2d -- get'3d get'4d
  = \case
    []  -> pure TenNone
    [x] -> TenVector <$> go1d get'1d x
    [x,y] -> (TenMatricies . (:|[])) <$> go2d get'2d x y
    _ -> throwString "Can't slice this yet"
 where
  go1d :: (Int64 -> IO HsReal) -> Word64 -> IO (NonEmpty HsReal)
  go1d getter x
    = forM (mkIx x) getter

  go2d :: (Int64 -> Int64 -> IO HsReal) -> Word64 -> Word64 -> IO (NonEmpty [HsReal])
  go2d getter x y =
    forM (mkIx x) $ \ix ->
      forM (mkVIx y) $ \iy ->
        getter ix iy

  go3d :: (Int64 -> Int64 -> Int64 -> IO HsReal) -> Word64 -> Word64 -> Word64 -> IO (NonEmpty (NonEmpty [HsReal]))
  go3d getter x y z =
    forM (mkIx x) $ \ix ->
      forM (mkIx y) $ \iy ->
        -- forM [0..z - 1] $ \iz ->
          traverse (getter ix iy) (mkVIx z)

  -- mat2dGo :: Int64 -> String -> [HsReal] -> String
  -- mat2dGo _ acc []  = acc
  -- mat2dGo y acc rcs = mat2dGo y acc' rest
  --   where
  --     (row, rest) = splitAt (fromIntegral y) rcs
  --     acc' = if null acc then row else acc ++ "\n" ++ row

  -- mat3dGo :: Int64 -> Int64 -> Int64 -> IO String
  -- mat3dGo x y z = fmap (intercalate "") $ forM (mkIx x) $ \x' -> do
  --   mat <- go "  " (get'3d x') y z
  --   pure $ gt2IxHeader [x'] ++ mat

  -- mat4dGo :: Int64 -> Int64 -> Int64 -> Int64 -> IO String
  -- mat4dGo w q x y = fmap (intercalate "") $ forM (mkXY w q) $ \(w', q') -> do
  --   mat <- go "  " (get'4d w' q') x y
  --   pure $ gt2IxHeader [w', q'] ++ mat

  mkIx :: Word64 -> NonEmpty Int64
  mkIx 0 = 0 :| []
  mkIx x = 0 :| [1..fromIntegral x - 1]

  mkVIx :: Word64 -> [Int64]
  mkVIx 0 = []
  mkVIx x = [0..fromIntegral x - 1]