{-# LANGUAGE CPP                 #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | One-hot encodings of 'Bounded' 'Enum' values as plain lists.
module Torch.Data.OneHot where

import qualified Data.Vector as V

#ifdef CUDA
import Torch.Cuda.Double
import qualified Torch.Cuda.Long as Long
#else
import Torch.Double
import qualified Torch.Long as Long
#endif

-- Disabled tensor-valued variants, kept for reference.
-- NOTE(review): the result shape is hard-coded to '[10] instead of the
-- intended '[FromEnum (MaxBound c)] noted beside it -- confirm before
-- re-enabling.
--
-- onehotL
--   :: forall c sz
--   . ( Ord c, Bounded c, Enum c
--     , sz ~ FromEnum (MaxBound c), KnownDim sz, KnownNat sz)
--   => c
--   -> LongTensor '[10] -- '[FromEnum (MaxBound c)]
-- onehotL c
--   = Long.unsafeVector
--   $ onehot c
--
-- onehotT
--   :: forall c sz
--   . ( Ord c, Bounded c, Enum c
--     , sz ~ FromEnum (MaxBound c), KnownDim sz, KnownNat sz)
--   => c
--   -> Tensor '[10] -- '[FromEnum (MaxBound c)]
-- onehotT c
--   = unsafeVector
--   $ fmap fromIntegral
--   $ onehot c

-- | One-hot encode @c@ over every value of its type, from 'minBound' to
-- 'maxBound' inclusive. The slot belonging to @c@ holds 1 and every other
-- slot holds 0.
--
-- Slots are numbered relative to @fromEnum minBound@. Types whose
-- enumeration does not start at 0 (e.g. @Int8@, spanning -128..127)
-- therefore get exactly one slot per value, with @c@ at the right position.
-- For types whose range does not fit in an 'Int' (e.g. 'Int' itself) the
-- width computation overflows, so such types are not meaningful here.
--
-- >>> onehotNum GT :: [Int]
-- [0,0,1]
onehotNum :: forall n c . (Num n, Bounded c, Enum c) => c -> [n]
onehotNum c = V.toList (V.generate width slot)
  where
    lo, width, hot :: Int
    lo    = fromEnum (minBound :: c)
    width = fromEnum (maxBound :: c) - lo + 1
    hot   = fromEnum c - lo
    slot j
      | j == hot  = 1
      | otherwise = 0

-- | One-hot encode @c@ as a list of integral values; see 'onehotNum'.
-- The 'Ord' constraint is unused but kept so existing signatures and type
-- applications continue to work.
--
-- >>> onehot True :: [Int]
-- [0,1]
onehot :: forall i c . (Integral i, Ord c, Bounded c, Enum c) => c -> [i]
onehot = onehotNum

-- | One-hot encode @c@ as a list of fractional values (1 for @c@'s slot,
-- 0 elsewhere); see 'onehotNum'. The 'Ord' constraint is unused but kept
-- for backward compatibility.
--
-- >>> onehotf LT :: [Double]
-- [1.0,0.0,0.0]
onehotf :: forall i c . (Fractional i, Ord c, Bounded c, Enum c) => c -> [i]
onehotf = onehotNum