{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Torch.Indef.Dynamic.Print
( showTensor
, describeTensor
) where
import Control.Applicative
import Control.Exception.Safe
import Control.Monad
import Data.List (intercalate)
import Data.List.NonEmpty (NonEmpty(..))
import Data.Maybe
import Data.Typeable
import GHC.Int
import GHC.Word
import Text.Printf
import qualified Data.List.NonEmpty as NE
import Torch.Indef.Types
-- | Render a tensor of rank 0 through 4 as a human-readable string.
--
-- Elements are fetched through the getter whose arity matches the length of
-- the dimension list; the other getters are never called.
--
-- * rank 0 (an empty dimension list) renders as the empty string;
-- * rank 1 renders as a single bracketed row;
-- * rank 2 renders one bracketed row per line;
-- * rank 3 and rank 4 render one rank-2 slice per leading index (or pair of
--   leading indices), each slice preceded by a header such as @(0,.,.):@.
--
-- A dimension whose size is zero (or negative, for a signed @ix@) yields no
-- indices. This is checked before subtracting one, so an unsigned @ix@ does
-- not wrap around to its maximum value.
--
-- Calls 'error' when more than four dimensions are given.
showTensor
  :: forall a ix . (Typeable a, Ord a, Num a, Show a, Integral ix, Show ix)
  => (ix -> a)                    -- ^ element getter used for rank 1
  -> (ix -> ix -> a)              -- ^ element getter used for rank 2
  -> (ix -> ix -> ix -> a)        -- ^ element getter used for rank 3
  -> (ix -> ix -> ix -> ix -> a)  -- ^ element getter used for rank 4
  -> [ix]                         -- ^ size of each dimension
  -> String
showTensor get'1d get'2d get'3d get'4d ds =
  case ds of
    []           -> ""
    [x]          -> brackets (concatMap (valWithSpace . get'1d) (mkIx x))
    [x, y]       -> go "" get'2d x y
    [z, x, y]    ->
      concatMap (\z' -> gt2IxHeader [z'] ++ go " " (get'3d z') x y) (mkIx z)
    [w, q, x, y] ->
      concatMap
        (\(w', q') -> gt2IxHeader [w', q'] ++ go " " (get'4d w' q') x y)
        (mkXY w q)
    _            -> error "Can't print this yet"
  where
    -- Render an x-by-y matrix, one bracketed row per line, each row prefixed
    -- with @fill@. A matrix without columns renders as the empty string
    -- (not as a stack of empty rows).
    go :: String -> (ix -> ix -> a) -> ix -> ix -> String
    go fill getter x y
      | null cols = ""
      | otherwise = intercalate "\n" (fmap renderRow (mkIx x))
      where
        cols = mkIx y
        renderRow r = fill ++ brackets (concatMap (valWithSpace . getter r) cols)

    -- Valid indices for a dimension of size @n@: @0 .. n - 1@, or none when
    -- @n <= 0@. The guard runs before @n - 1@ so unsigned types cannot wrap.
    mkIx :: ix -> [ix]
    mkIx n
      | n <= 0    = []
      | otherwise = [0 .. n - 1]

    -- All (row, column) index pairs of an x-by-y grid, in row-major order.
    mkXY :: ix -> ix -> [(ix, ix)]
    mkXY x y = [ (r, c) | r <- mkIx x, c <- mkIx y ]

    brackets :: String -> String
    brackets s = "[" ++ s ++ "]"

    -- Render one element with its leading padding. 'Double' values are
    -- printed with 8 decimal places, 'Float' values with 4, and every other
    -- type through its 'Show' instance.
    valWithSpace :: (Typeable a, Ord a, Num a, Show a) => a -> String
    valWithSpace v = signspacing ++ value
      where
        value :: String
        value = fromMaybe (show v) $
              (printf "%.8f" <$> (cast v :: Maybe Double))
          <|> (printf "%.4f" <$> (cast v :: Maybe Float))

        -- NOTE(review): both branches produce the same padding, so negative
        -- and non-negative values are not column-aligned. Kept as-is to
        -- preserve the existing output; confirm whether that is intended.
        signspacing :: String
        signspacing = case compare (signum v) 0 of
          LT -> " "
          _  -> " "

    -- Header line printed before each rank-2 slice of a higher-rank tensor.
    gt2IxHeader :: Show ix => [ix] -> String
    gt2IxHeader is = "\n(" ++ intercalate "," (fmap show is) ++ ",.,.):\n"
-- | One-line summary of a tensor: the element type named by the proxy,
-- followed by the dimension sizes joined with @x@, all wrapped in square
-- brackets.
describeTensor
  :: forall t dims
  .  (Typeable t, Show dims)
  => [dims]   -- ^ size of each dimension
  -> Proxy t  -- ^ proxy naming the element type
  -> String
describeTensor ds t = concat
  [ "["
  , elementType
  , " tensor with shape: "
  , shape
  , "]"
  ]
  where
    elementType = show (typeRep t)
    shape       = intercalate "x" (map show ds)
-- | Element values read out of a tensor, grouped by rank.
data TenSlices
  -- | No elements.
  = TenNone
  -- | The elements of a rank-1 tensor, in index order.
  | TenVector (NonEmpty HsReal)
  -- | One or more matrices. Each matrix is a non-empty sequence indexed by
  -- its first dimension, and every entry lists the elements along the second
  -- dimension. (The constructor's spelling is kept because other code in
  -- this module refers to it by this name.)
  | TenMatricies (NonEmpty (NonEmpty [HsReal]))
-- | Read every element of a rank-1 or rank-2 tensor into a 'TenSlices'.
--
-- * An empty dimension list gives 'TenNone'.
-- * @[x]@ gives a 'TenVector' holding the elements at indices @0 .. x - 1@.
-- * @[x, y]@ gives a single matrix wrapped in 'TenMatricies', with one entry
--   per index of the first dimension, each listing the @y@ elements along
--   the second dimension.
-- * A zero-length leading dimension gives 'TenNone'; no getter is called,
--   since index 0 does not exist in an empty dimension.
-- * Any other rank throws via 'throwString'.
--
-- The 'Dynamic' argument is not inspected; all reads go through the getters.
tensorSlices
  :: Dynamic                        -- ^ unused; elements come from the getters
  -> (Int64 -> IO HsReal)           -- ^ element getter used for rank 1
  -> (Int64 -> Int64 -> IO HsReal)  -- ^ element getter used for rank 2
  -> [Word64]                       -- ^ size of each dimension
  -> IO TenSlices
tensorSlices _ get'1d get'2d dims =
  case dims of
    []     -> pure TenNone
    [x]    -> maybe (pure TenNone) (fmap TenVector . traverse get'1d) (indices x)
    [x, y] ->
      maybe
        (pure TenNone)
        (fmap (TenMatricies . (:| [])) . traverse (row y))
        (indices x)
    _      -> throwString "Can't slice this yet"
  where
    -- Read the @y@ elements found at first-dimension index @ix@.
    row :: Word64 -> Int64 -> IO [HsReal]
    row y ix = traverse (get'2d ix) (offsets y)

    -- Indices @0 .. n - 1@; empty when @n@ is 0, because the subtraction is
    -- done in 'Int64' and so yields the empty range @[0 .. -1]@.
    offsets :: Word64 -> [Int64]
    offsets n = [0 .. fromIntegral n - 1]

    -- Same indices as 'offsets', or 'Nothing' when the dimension is empty.
    indices :: Word64 -> Maybe (NonEmpty Int64)
    indices = NE.nonEmpty . offsets