{-# LANGUAGE MultiParamTypeClasses #-}
-----------------------------------------------------------------------------
-- |
-- Module     : BLAS.Tensor.Immutable
-- Copyright  : Copyright (c) , Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--
-- Operations on immutable tensors: element access, bulk replacement,
-- element-wise mapping, and the range-checked indexing operator '(!)'.
--

module BLAS.Tensor.Immutable (
    ITensor(..),
    (!),
    ) where

import BLAS.Tensor.Base
import BLAS.Elem

import Data.Ix

infixl 9 !
infixl 7 *>
infixl 5 `shift`

-- | A class for immutable tensors.
class (BaseTensor x i) => ITensor x i where
    -- | Get the number of elements stored in the tensor.
    size :: x n e -> Int

    -- | Get a new tensor by replacing the elements at the given indices.
    (//) :: (BLAS1 e) => x n e -> [(i,e)] -> x n e

    -- | Get the value at the given index, without doing any bounds-checking.
    unsafeAt :: (Elem e) => x n e -> i -> e

    -- | Same as '(//)' but doesn't do any bounds-checking.
    unsafeReplace :: (BLAS1 e) => x n e -> [(i,e)] -> x n e

    -- | Get the indices of the elements stored in the tensor.
    indices :: x n e -> [i]

    -- | Get the elements stored in the tensor.
    --
    -- The default implementation takes the second component of every
    -- pair produced by 'assocs'.
    elems :: (Elem e) => x n e -> [e]
    elems = map snd . assocs

    -- | Get the list of @(index, element)@ pairs stored in the tensor.
    assocs :: (Elem e) => x n e -> [(i,e)]

    -- accum :: (e -> e' -> e) -> x e -> [(i,e')] -> x e

    -- | Apply a function elementwise to a tensor.
    tmap :: (BLAS1 e) => (e -> e) -> x n e -> x n e

    -- ixmap :: i -> (i -> i) -> x e -> x e
    -- unsafeIxMap

    -- | Scale every element by the given value.
    --
    -- The default implementation multiplies each element by @k@ via 'tmap'.
    (*>) :: (BLAS1 e) => e -> x n e -> x n e
    (*>) k = tmap (k*)

    -- | Add a constant to every element.
    --
    -- The default implementation adds @k@ to each element via 'tmap'.
    shift :: (BLAS1 e) => e -> x n e -> x n e
    shift k = tmap (k+)


-- | Get the value at the given index.  Range-checks the argument.
--
-- The index is checked with 'inRange' against the tensor's 'bounds'.
-- An in-range index is delegated to 'unsafeAt'.  An out-of-range index
-- calls 'error' with a message naming both the index and the tensor's
-- 'shape'.
(!) :: (ITensor x i, Elem e) => x n e -> i -> e
(!) x i
    | inRange (bounds x) i = unsafeAt x i
    | otherwise            = error $
        "tried to get element at a index `" ++ show i ++ "'"
        ++ " in an object with shape `" ++ show (shape x) ++ "'"