module Data.NeuralNetwork.Backend.BLASHS (
module Data.NeuralNetwork.Backend.BLASHS.Utils,
ByBLASHS(..),
ErrCode(..),
cost'
) where
import Data.NeuralNetwork hiding (relu, relu', cost')
import Data.NeuralNetwork.Backend.BLASHS.Layers
import Data.NeuralNetwork.Backend.BLASHS.Utils
import Data.NeuralNetwork.Backend.BLASHS.SIMD
import Control.Monad.Except
import Data.Constraint (Dict(..))
-- | Monad in which specifications are compiled by this backend: 'IO' that
-- can abort with an 'ErrCode'.
type Err = ExceptT ErrCode IO

-- | Reasons a specification can fail to compile.
--
-- Derives 'Show' and 'Eq' so that callers unwrapping the 'Err' result
-- (e.g. with @runExceptT@) can report or compare the failure.
data ErrCode
  = ErrMismatch
    -- ^ a layer was given an input size (1-D vs 2-D) that it cannot accept
  deriving (Eq, Show)

-- | Token selecting this backend; see its 'Backend' instance.
data ByBLASHS = ByBLASHS
-- | Compiles a specification of the shape @input :++ body@.
--
-- The input spec contributes only its size (via 'hsize'). The body is
-- translated layer by layer with 'trans' inside 'Err', so a body layer that
-- cannot accept the size it is given aborts compilation.
instance ( HeadSize z
         , TranslateBody s
         , Component (RunLayer (SpecToTag s))
         , Run (RunLayer (SpecToTag s)) ~ IO
         ) => Backend ByBLASHS (z :++ s) where
  type Env ByBLASHS = Err
  type ConvertFromSpec (z :++ s) = RunLayer (SpecToTag s)
  compile _ (inputSpec :++ bodySpec) = trans (hsize inputSpec) bodySpec
  witness _ _ = Dict
-- | Plain 'IO' actions are run inside the backend environment 'Err' by
-- lifting them.
instance RunInEnv IO Err where
  run action = liftIO action
-- | Size of the data passed between layers, tracked while a specification is
-- translated.
data LayerSize
  = D1 Int
    -- ^ a flat vector of the given length
  | D2 Int Int Int
    -- ^ @D2 k m n@: @k@ matrices of @m × n@ (flattened, this is @k*m*n@)
-- | Input-layer specifications whose output size can be read off directly.
class HeadSize spec where
  -- | Size of the data the input layer hands to the first body layer.
  hsize :: spec -> LayerSize
-- | A 1-D input of length @n@ is a flat vector of length @n@.
instance HeadSize SpecIn1D where
  hsize spec = case spec of
    In1D len -> D1 len
-- | A 2-D @m × n@ input is a stack containing a single @m × n@ matrix.
instance HeadSize SpecIn2D where
  hsize (In2D rows cols) = D2 1 rows cols
-- | Body-layer specifications that can compute the size of their output from
-- the size of their input.
class BodySize spec where
  -- | @bsize input spec@ is the output size of @spec@ when fed @input@.
  bsize :: LayerSize -> spec -> LayerSize
-- | Flattening @k@ matrices of @m × n@ yields a vector of length @k*m*n@.
--
-- Partial: only 2-D sizes are handled. This layer's 'trans' rejects 1-D
-- input with 'ErrMismatch'.
instance BodySize SpecReshape2DAs1D where
  bsize (D2 count rows cols) _ = D1 flat
    where
      flat = count * rows * cols
-- | A fully connected layer @FullConnect n@ outputs a vector of length @n@,
-- whatever size it receives.
instance BodySize SpecFullConnect where
  bsize _ spec = case spec of
    FullConnect width -> D1 width
-- | Output size of a convolution layer @Convolution k f p@.
--
-- The output holds @k@ matrices. Each spatial dimension @d@ of the input
-- becomes @d + 2*p - f + 1@.
-- NOTE(review): this is the stride-1 formula with @f@ read as the filter
-- width and @p@ as the zero padding added to each side — confirm against
-- 'newCLayer'.
--
-- Partial: only 2-D sizes are handled. This layer's 'trans' rejects 1-D
-- input with 'ErrMismatch'.
instance BodySize SpecConvolution where
  bsize (D2 _ m n) (Convolution k f p) = D2 k (outDim m) (outDim n)
    where
      -- Padding adds p on both sides, and an f-wide window then fits at
      -- (d + 2p) - f + 1 positions.
      outDim d = d + 2 * p - f + 1
-- | Max pooling @MaxPooling s@ keeps the matrix count and divides each
-- matrix dimension by @s@ (integer division, rounding down).
--
-- Partial: only 2-D sizes are handled. This layer's 'trans' rejects 1-D
-- input with 'ErrMismatch'.
instance BodySize SpecMaxPooling where
  bsize (D2 count rows cols) (MaxPooling s) = D2 count (shrink rows) (shrink cols)
    where
      shrink d = d `div` s
-- | Body-layer specifications that can be built into a runnable layer once
-- the size of their input is known.
class TranslateBody spec where
  -- | Tag type of the 'RunLayer' produced for @spec@.
  type SpecToTag spec
  -- | Build the runnable layer for @spec@, given its input size. The
  -- instances in this module abort with 'ErrMismatch' when the input has
  -- the wrong dimensionality.
  trans :: LayerSize -> spec -> Err (RunLayer (SpecToTag spec))
-- | A fully connected layer requires 1-D input.
--
-- It is created with 'newFLayer' from the input length and the spec's
-- output length, then stacked with @Activation (relu, relu')@.
instance TranslateBody SpecFullConnect where
  type SpecToTag SpecFullConnect = S F (T SinglVec)
  trans (D1 inLen) (FullConnect outLen) =
    withRelu <$> lift (newFLayer inLen outLen)
    where
      withRelu dense = Stack dense (Activation (relu, relu'))
  trans _ _ = throwError ErrMismatch
-- | A convolution layer requires 2-D input.
--
-- It is created with 'newCLayer' from the input's matrix count and the
-- spec's three parameters, then stacked with @Activation (relu, relu')@.
instance TranslateBody SpecConvolution where
  type SpecToTag SpecConvolution = S C (T MultiMat)
  trans (D2 inCount _ _) (Convolution outCount f p) = do
    conv <- lift (newCLayer inCount outCount f p)
    pure (Stack conv (Activation (relu, relu')))
  trans _ _ = throwError ErrMismatch
-- | Max pooling requires 2-D input and needs no allocation: the layer is just
-- @MaxP@ with the spec's factor.
instance TranslateBody SpecMaxPooling where
  type SpecToTag SpecMaxPooling = P
  trans D1{} _ = throwError ErrMismatch
  trans D2{} (MaxPooling factor) = pure (MaxP factor)
-- | Flattening requires 2-D input and always yields the 'As1D' layer.
instance TranslateBody SpecReshape2DAs1D where
  type SpecToTag SpecReshape2DAs1D = A
  trans D1{} _ = throwError ErrMismatch
  trans D2{} _ = pure As1D
-- | Translate a chain of body layers @a :++ c@.
--
-- The first layer @a@ is built for the incoming size. The rest @c@ is built
-- for the size that @a@ produces, as computed by 'bsize'. The two results are
-- combined with 'Stack'.
--
-- Translation stops at the first failure: if building @a@ aborts, the rest
-- of the chain is not built.
instance (TranslateBody a, TranslateBody c, BodySize a)
  => TranslateBody (a :++ c) where
  -- Same type variables as the instance head.
  type SpecToTag (a :++ c) = S (SpecToTag a) (SpecToTag c)
  trans size (hd :++ rest) =
    Stack <$> trans size hd <*> trans (bsize size hd) rest