{-# language DeriveGeneric, FlexibleContexts, ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Generics.Encode.OneHot
-- Description :  Generic 1-hot encoding of enumeration types
-- Copyright   :  (c) Marco Zocca (2019-2020)
-- License     :  MIT
-- Maintainer  :  ocramz fripost org
-- Stability   :  experimental
-- Portability :  GHC
--
-- Generic 1-hot encoding of enumeration types
--
-----------------------------------------------------------------------------
module Data.Generics.Encode.OneHot (OneHot, onehotDim, onehotIx, oneHotV
                                   -- ** Internal
                                   , mkOH) where

import qualified GHC.Generics as G
import Data.Hashable (Hashable(..))
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM

import Generics.SOP (DatatypeInfo, ConstructorInfo(..), constructorInfo, ConstructorName, hindex, hmap, SOP(..), I(..), K(..), hcollapse, SListI)
-- import Generics.SOP.NP (cpure_NP)
-- import Generics.SOP.Constraint (SListIN)
-- import Generics.SOP.GGP (GCode, GDatatypeInfo, GFrom, gdatatypeInfo, gfrom)

-- $setup
-- >>> :set -XDeriveDataTypeable
-- >>> :set -XDeriveGeneric
-- >>> import Generics.SOP (Generic(..), All, Code, Proxy(..))
-- >>> import Generics.SOP.NP
-- >>> import qualified GHC.Generics as G
-- >>> import Generics.SOP.GGP (gdatatypeInfo, gfrom)
-- >>> data C = C1 | C2 | C3 deriving (Eq, Show, G.Generic)

-- | Construct a 'OneHot' encoding from generic datatype and value information.
--
-- The dimension of the result is the number of entries returned by
-- 'constructorList' for the datatype, and the hot index is @hindex@ of the
-- generic representation of the value.
--
-- The expected output below is the rendering produced by the hand-written
-- 'Show' instance of 'OneHot' (@OH_\<dim\>_\<ix\>@), not record syntax.
--
-- >>> mkOH (gdatatypeInfo (Proxy :: Proxy C)) (gfrom C2)
-- OH_3_1
mkOH :: SListI xs =>
        DatatypeInfo xs -- ^ Datatype metadata
     -> SOP I xs        -- ^ Generic representation of the value to encode
     -> OneHot Int
mkOH di sop = oneHot where
  oneHot = OH sdim six
  six = hindex sop
  sdim = length $ constructorList di

-- | 1-hot encoded vector.
--
-- This representation is used to encode categorical variables as points in a vector space.
data OneHot i = OH
  { ohDim :: i -- ^ Dimensionality of the ambient space
  , ohIx :: i  -- ^ Position of the single nonzero entry
  } deriving (Eq, Ord, G.Generic)

-- No methods are defined here, so the class defaults are used.
instance Hashable i => Hashable (OneHot i)

-- Renders as @OH_\<dim\>_\<ix\>@ rather than the derived record syntax.
instance Show i => Show (OneHot i) where
  show (OH dim ix) = "OH_" ++ show dim ++ "_" ++ show ix

-- | Embedding dimension of the 1-hot encoded vector
onehotDim :: OneHot i -> i
onehotDim (OH dim _) = dim

-- | Active ('hot') index of the 1-hot encoded vector
onehotIx :: OneHot i -> i
onehotIx (OH _ ix) = ix

-- Names of the constructors described by a 'DatatypeInfo', in the order
-- produced by 'constructorInfo'.
--
-- NOTE(review): the lambda only matches the 'Constructor' case of
-- 'ConstructorInfo'; any other case would fail the pattern match when the
-- corresponding name is demanded. Confirm this is only used on plain
-- enumeration types.
constructorList :: SListI xs => DatatypeInfo xs -> [ConstructorName]
constructorList di =
  hcollapse (hmap (\(Constructor nm) -> K nm) (constructorInfo di))

-- | Create a one-hot vector
--
-- Allocates a vector of length 'ohDim' filled with 0 and writes 1 at
-- position 'ohIx'.
oneHotV :: Num a =>
           OneHot Int
        -> V.Vector a
oneHotV (OH dim ix) = V.create $ do
  buf <- VM.replicate dim 0
  VM.write buf ix 1
  pure buf

-- data C = C1 | C2 | C3 deriving (Eq, Show, G.Generic)