{-# language DeriveGeneric, FlexibleContexts, ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Generics.Encode.OneHot
-- Description :  Generic 1-hot encoding of enumeration types 
-- Copyright   :  (c) Marco Zocca (2019-2020)
-- License     :  MIT
-- Maintainer  :  ocramz fripost org
-- Stability   :  experimental
-- Portability :  GHC
--
-- Generic 1-hot encoding of enumeration types
--
-----------------------------------------------------------------------------
module Data.Generics.Encode.OneHot (OneHot, onehotDim, onehotIx, oneHotV
                                   -- ** Internal
                                   , mkOH) where

import qualified GHC.Generics as G
import Data.Hashable (Hashable(..))
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import Generics.SOP (DatatypeInfo, ConstructorInfo(..), constructorInfo, constructorName, ConstructorName, hindex, hmap, SOP(..), I(..), K(..), hcollapse, SListI)
-- import Generics.SOP.NP (cpure_NP)
-- import Generics.SOP.Constraint (SListIN)
-- import Generics.SOP.GGP (GCode, GDatatypeInfo, GFrom, gdatatypeInfo, gfrom)

-- $setup
-- >>> :set -XDeriveDataTypeable
-- >>> :set -XDeriveGeneric
-- >>> import Generics.SOP (Generic(..), All, Code, Proxy(..))
-- >>> import Generics.SOP.NP
-- >>> import qualified GHC.Generics as G
-- >>> import Generics.SOP.GGP (gdatatypeInfo, gfrom)
-- >>> data C = C1 | C2 | C3 deriving (Eq, Show, G.Generic)

-- | Construct a 'OneHot' encoding from generic datatype and value information.
--
-- The dimension is the number of constructors listed in the 'DatatypeInfo';
-- the hot index is the zero-based position ('hindex') of the constructor that
-- built the value.
--
-- >>> mkOH (gdatatypeInfo (Proxy :: Proxy C)) (gfrom C2)
-- OH_3_1
mkOH :: SListI xs => DatatypeInfo xs -> SOP I xs -> OneHot Int
mkOH di sop = OH sdim six
  where
    -- Position of the active constructor within the sum.
    six :: Int
    six = hindex sop
    -- Total number of constructors of the datatype.
    sdim :: Int
    sdim = length (constructorList di)

-- | 1-hot encoded vector.
--
-- This representation is used to encode categorical variables as points in a vector space.
-- Only the dimension and the position of the single non-zero entry are stored.
data OneHot i = OH {
    ohDim :: i -- ^ Dimensionality of the ambient space
  , ohIx :: i  -- ^ Index of the '1' entry
  } deriving (Eq, Ord, G.Generic)
-- No methods are defined here; presumably this relies on the 'G.Generic'-based
-- default implementation of 'Hashable' (TODO confirm against the hashable package).
instance Hashable i => Hashable (OneHot i)
-- | Compact rendering: the literal prefix @OH_@, the dimension, an underscore,
-- then the hot index (e.g. @OH_3_1@).
instance Show i => Show (OneHot i) where
  show (OH od oi) = concat ["OH_", show od, "_", show oi]

-- | Embedding dimension of the 1-hot encoded vector
-- (exported accessor for the 'ohDim' field).
onehotDim :: OneHot i -> i
onehotDim = ohDim
-- | Active ('hot') index of the 1-hot encoded vector
-- (exported accessor for the 'ohIx' field).
onehotIx :: OneHot i -> i
onehotIx = ohIx

-- | Names of all constructors of a datatype, in the order given by
-- 'constructorInfo'.
--
-- 'constructorName' is used instead of matching on the 'Constructor' pattern
-- alone, so that infix and record constructors are handled as well (the
-- single-pattern lambda would be a partial match).
constructorList :: SListI xs => DatatypeInfo xs -> [ConstructorName]
constructorList di =
  hcollapse (hmap (K . constructorName) (constructorInfo di))


-- | Create a one-hot vector.
--
-- The result has length @n@ (the 'OneHot' dimension), holding @1@ at the hot
-- index @i@ and @0@ everywhere else. The index is expected to satisfy
-- @0 <= i < n@, as is the case for encodings produced by 'mkOH'.
oneHotV :: Num a =>
           OneHot Int
        -> V.Vector a
oneHotV (OH n i) = V.create $ do
  vm <- VM.replicate n 0  -- start from an all-zero mutable vector
  VM.write vm i 1         -- set the single hot entry
  return vm


-- data C = C1 | C2 | C3 deriving (Eq, Show, G.Generic)