-- |
-- Module      : Numeric.InfBackprop.Utils.Vector
-- Copyright   : (C) 2025 Alexey Tochin
-- License     : BSD3 (see the file LICENSE)
-- Maintainer  : Alexey Tochin
--
-- Utility functions for working with vectors.
module Numeric.InfBackprop.Utils.Vector
  ( fromTuple,
    safeHead,
    safeLast,
    trimArrayHead,
    trimArrayTail,
    zipWith,
  )
where

import Control.Monad (MonadPlus, mzero)
import Data.Bool (otherwise)
import Data.Eq (Eq, (==))
import Data.Function (($))
import qualified Data.IndexedListLiterals as DILL
import Data.Maybe (Maybe (Just, Nothing))
import Data.Ord (Ordering (EQ, GT, LT), compare)
import qualified Data.Vector.Generic as DVG
import GHC.Base (pure, (.))

-- | Converts a tuple into a vector of any 'DVG.Vector' instance,
-- e.g. 'Data.Vector.Vector'.
--
-- The tuple is first turned into a list via 'DILL.toList' and the vector is
-- then built from that list.
--
-- ==== __Examples__
--
-- >>> import GHC.Int (Int)
-- >>> import qualified Data.Vector as DV
--
-- >>> fromTuple (1 :: Int, 2 :: Int, 3 :: Int) :: DV.Vector Int
-- [1,2,3]
fromTuple ::
  (DVG.Vector v a) =>
  (DILL.IndexedListLiterals input length a) =>
  input ->
  v a
fromTuple = DVG.fromList . DILL.toList

-- | Returns the first element of a vector safely.
-- If the vector is empty, it returns 'mzero' (which is 'Nothing' for 'Maybe');
-- otherwise the first element is wrapped with 'pure'.
--
-- ==== __Examples__
--
-- >>> import GHC.Int (Int)
-- >>> import Data.Vector (fromList)
--
-- >>> safeHead (fromList [1, 2, 3]) :: Maybe Int
-- Just 1
--
-- >>> safeHead (fromList []) :: Maybe Int
-- Nothing
safeHead :: (DVG.Vector v a, MonadPlus m) => v a -> m a
safeHead vec
  -- The emptiness check makes the unchecked 'DVG.unsafeHead' below safe.
  | DVG.null vec = mzero
  | otherwise = pure $ DVG.unsafeHead vec

-- | Returns the last element of a vector safely.
-- If the vector is empty, it returns 'mzero' (which is 'Nothing' for 'Maybe').
--
-- ==== __Examples__
--
-- >>> import GHC.Int (Int)
-- >>> import Data.Vector (fromList, empty)
--
-- >>> safeLast (fromList [1, 2, 3]) :: Maybe Int
-- Just 3
--
-- >>> safeLast empty :: Maybe Int
-- Nothing
safeLast :: (DVG.Vector v a, MonadPlus m) => v a -> m a
safeLast vec
  -- The emptiness check makes the unchecked 'DVG.unsafeLast' below safe.
  | DVG.null vec = mzero
  | otherwise = pure $ DVG.unsafeLast vec

-- | Removes elements from the beginning of the vector as long as they are
-- equal to the given value, i.e. until the first element differs from it.
-- If every element equals the given value, the result is empty.
--
-- ==== __Examples__
--
-- >>> import Data.Vector (fromList, empty)
--
-- >>> trimArrayHead 1 (fromList [1, 1, 1, 2, 3])
-- [2,3]
--
-- >>> trimArrayHead 1 (fromList [1, 1])
-- []
--
-- >>> trimArrayHead 1 empty
-- []
trimArrayHead :: (DVG.Vector v a, Eq a) => a -> v a -> v a
trimArrayHead x = DVG.dropWhile (== x)

-- | Removes elements from the end of the vector as long as they are equal to
-- the given value, i.e. until the last element differs from it.
-- If every element equals the given value, the result is empty.
--
-- ==== __Examples__
--
-- >>> import Data.Vector (fromList, empty)
--
-- >>> trimArrayTail 3 (fromList [1, 2, 3, 3, 3])
-- [1,2]
--
-- >>> trimArrayTail 3 (fromList [3, 1])
-- [3,1]
--
-- >>> trimArrayTail 3 empty
-- []
trimArrayTail :: (DVG.Vector v a, Eq a) => a -> v a -> v a
trimArrayTail x = go
  where
    -- Peel off trailing copies of @x@ one at a time. 'DVG.init' cannot fail
    -- here: 'safeLast' just returned 'Just', so the vector is non-empty.
    go vec = case safeLast vec of
      Just lastVal | lastVal == x -> go (DVG.init vec)
      _ -> vec

-- | Combines two vectors of possibly different lengths.
--
-- Elements at positions present in both vectors are combined with the first
-- function. The result has the length of the longer input: surplus elements
-- of the first vector are mapped with the second function, and surplus
-- elements of the second vector are mapped with the third function.
--
-- ==== __Examples__
--
-- >>> import Prelude (id, negate, (-), Int)
-- >>> import qualified Data.Vector as DV
--
-- The following example subtracts two vectors of different lengths. This is
-- equivalent to padding the shorter vector with zeros: surplus elements of
-- the first vector are kept unchanged (@id@) and surplus elements of the
-- second vector are negated (@negate@).
--
-- >>> :{
-- zipWith
--   (-)    -- Subtract corresponding elements of the two vectors
--   id     -- Keep the surplus elements of the first vector unchanged
--   negate -- Negate the surplus elements of the second vector
--   (DV.fromList [10, 20, 30]) -- First vector
--   (DV.fromList [1, 2])       -- Second vector
-- :}
-- [9,18,30]
--
-- The same computation with explicitly typed vectors:
--
-- >>> let v0 = DV.fromList [10, 20, 30] :: DV.Vector Int
-- >>> let v1 = DV.fromList [1, 2] :: DV.Vector Int
-- >>> zipWith (-) id negate v0 v1
-- [9,18,30]
zipWith ::
  (DVG.Vector v a, DVG.Vector v b, DVG.Vector v c) =>
  -- | Combines elements present in both vectors.
  (a -> b -> c) ->
  -- | Maps the surplus elements of the first vector (when it is longer).
  (a -> c) ->
  -- | Maps the surplus elements of the second vector (when it is longer).
  (b -> c) ->
  -- | First vector.
  v a ->
  -- | Second vector.
  v b ->
  v c
zipWith f g h a0 a1 = case compare l0 l1 of
  EQ -> common
  GT -> common DVG.++ DVG.map g (DVG.drop l1 a0)
  LT -> common DVG.++ DVG.map h (DVG.drop l0 a1)
  where
    l0 = DVG.length a0
    l1 = DVG.length a1
    -- Pairwise part; 'DVG.zipWith' stops at the shorter input, so the
    -- surplus of the longer one starts at index @min l0 l1@.
    common = DVG.zipWith f a0 a1