{-# LANGUAGE TupleSections #-}
module Algebra.Matrix.Generic.Mutable
  ( MMatrix(..), Index, Size, Column, Row
  , new, unsafeNew, copy, clone, generate, generateM
  , fromRow, fromColumn, getRow, getColumn
  , imapRow, mapRow, fill, read, write
  , scaleRow, unsafeSwapRows, swapRows
  , gaussReduction, unsafeGaussReduction
  ) where
import           Algebra.Matrix.Generic.Base
import           Algebra.Prelude.Core        hiding (Vector, generate)
import           Control.Monad.Primitive     (PrimMonad, PrimState)
import qualified Data.Vector                 as V
import qualified Data.Vector.Generic         as GV

-- | Mutable, row-based, @0@-origin matrix.
--
--   @mat s a@ is a mutable matrix with state token @s@ and entries of type @a@.
--   Whole rows are read out as @'Row' mat a@ and whole columns as
--   @'Column' mat a@; both must be generic (immutable) vectors of @a@.
--
--   Methods without a default implementation in this class:
--   'basicUnsafeNew', 'basicInitialise', 'basicRowCount', 'basicColumnCount',
--   'unsafeGetRow', 'unsafeGetColumn', 'unsafeFromRows', 'unsafeCopy',
--   'unsafeRead', 'unsafeWrite', 'basicSet', 'basicUnsafeIMapRowM' and
--   'unsafeScaleRow'.
class (GV.Vector (Column mat) a, GV.Vector (Row mat) a) => MMatrix mat a where
  -- | @'basicUnsafeNew' n m@ creates a mutable matrix with @n@ rows and @m@ columns,
  --   without initialisation.
  --   This method should not be used directly, use @'unsafeNew'@ instead.
  basicUnsafeNew            :: PrimMonad m => Size -> Size -> m (mat (PrimState m) a)

  -- | Initialises a matrix freshly allocated by 'basicUnsafeNew'.
  --   'new' calls this right after allocation.
  --   NOTE(review): what "initialised" means (e.g. zero-filled) is instance-defined.
  basicInitialise           :: PrimMonad m => mat (PrimState m) a -> m ()

  -- | Number of rows.
  basicRowCount             :: mat s a -> Size

  -- | Number of columns.
  basicColumnCount          :: mat s a -> Size

  -- | @'unsafeGetRow' i mat@ reads the @i@-th row, without boundary check.
  --   NOTE(review): whether the result is a copy or a view is instance-defined.
  unsafeGetRow              :: PrimMonad m => Index -> mat (PrimState m) a -> m (Row    mat a)

  -- | @'unsafeGetColumn' j mat@ reads the @j@-th column, without boundary check.
  --   NOTE(review): whether the result is a copy or a view is instance-defined.
  unsafeGetColumn           :: PrimMonad m => Index -> mat (PrimState m) a -> m (Column mat a)

  -- | @'unsafeFill' n m c@ creates an @n@ x @m@ matrix whose entries are all @c@.
  --   Dimensions are not validated; use 'fill' for the checked version.
  unsafeFill                :: PrimMonad m => Size -> Size -> a -> m (mat (PrimState m) a)
  unsafeFill w h a = do
    m <- basicUnsafeNew w h
    -- Every entry is overwritten below, so no 'basicInitialise' pass is needed.
    forM_ [0..w-1] $ \i -> forM_ [0..h-1] $ \j -> unsafeWrite m i j a
    return m

  -- | Construct a mutable matrix consisting of a single row, perhaps without any copy.
  unsafeFromRow             :: PrimMonad m => Row mat a -> m (mat (PrimState m) a)
  unsafeFromRow = unsafeFromRows . (:[])

  -- | Construct a mutable matrix from a list of rows (the @k@-th list element
  --   becomes the @k@-th row), perhaps without any copy.
  unsafeFromRows            :: PrimMonad m => [Row mat a] -> m (mat (PrimState m) a)

  -- | Construct a mutable matrix consisting a single column, perhaps without any copy.
  unsafeFromColumn          :: PrimMonad m => Column mat a -> m (mat (PrimState m) a)
  unsafeFromColumn          = unsafeFromColumns . (:[])

  -- | Construct a mutable matrix from a list of columns (the @k@-th list
  --   element becomes the @k@-th column).
  --
  --   The row count is taken from the first column; all columns are expected
  --   to have that length. An empty list yields a @0@ x @0@ matrix.
  unsafeFromColumns         :: PrimMonad m => [Column mat a] -> m (mat (PrimState m) a)
  unsafeFromColumns []              = new 0 0
  unsafeFromColumns cols@(col0 : _) =
    unsafeGenerate (GV.length col0) (length cols) $ \i j -> (colVec V.! j) GV.! i
    where
      -- Index columns through a boxed vector: O(1) per lookup, instead of the
      -- O(j) list traversal @cols !! j@ would cost for every single entry.
      colVec = V.fromList cols

  -- | @'unsafeCopy' target source@ copies the content of @source@ to @target@, without boundary check.
  unsafeCopy                :: PrimMonad m => mat (PrimState m) a -> mat (PrimState m) a -> m ()

  -- | @'unsafeRead' m i j@ reads the value at @i@th row in @j@th column in @m@, without boundary check.
  --
  --   __N.B.__ Rows and columns are regarded as /zero-origin/, not @1@!.
  unsafeRead                :: PrimMonad m => mat (PrimState m) a -> Index ->Index -> m a

  -- | @'unsafeWrite' m i j x@ writes @x@ at @i@th row in @j@th column in @m@, without boundary check.
  unsafeWrite               :: PrimMonad m => mat (PrimState m) a -> Index -> Index -> a -> m ()

  -- | @'basicSet' m x@: presumably overwrites every entry of @m@ with @x@.
  --   NOTE(review): no default or caller visible here; confirm against instances.
  basicSet                  :: PrimMonad m => mat (PrimState m) a -> a -> m ()

  -- | @'basicUnsafeIMapRowM' m i f@ replaces, in place, the entry @x@ at column
  --   @k@ of row @i@ by the result of the action @f k x@. No boundary check on @i@.
  basicUnsafeIMapRowM       :: PrimMonad m => mat (PrimState m) a -> Index -> (Index -> a -> m a) -> m ()

  -- | Pure variant of 'basicUnsafeIMapRowM': the entry @x@ at column @k@ of
  --   row @i@ becomes @f k x@. No boundary check on @i@.
  basicUnsafeIMapRow        :: PrimMonad m => mat (PrimState m) a -> Index -> (Index -> a -> a) -> m ()
  basicUnsafeIMapRow m i f  = basicUnsafeIMapRowM m i (\k -> return . f k)

  -- | @'basicUnsafeSwapRows' m i i'@ exchanges rows @i@ and @i'@ of @m@,
  --   entry by entry, without boundary check.
  basicUnsafeSwapRows       :: PrimMonad m => mat (PrimState m) a -> Index -> Index -> m ()
  basicUnsafeSwapRows m i i' = forM_ [0 .. basicColumnCount m - 1] $ \j -> do
    x <- unsafeRead m i  j
    y <- unsafeRead m i' j
    unsafeWrite m i  j y
    unsafeWrite m i' j x

  -- | @'unsafeScaleRow' m i c@ multiplies every entry of row @i@ by @c@,
  --   without boundary check.
  unsafeScaleRow :: (PrimMonad m, Commutative a) => mat (PrimState m) a -> Index -> a -> m ()

  -- | @'unsafeGenerate' n m f@ creates an @n@ x @m@ matrix whose entry at row
  --   @i@, column @j@ is @f i j@. Allocation goes through 'unsafeNew'.
  unsafeGenerate :: (PrimMonad m) => Size -> Size -> (Index -> Index -> a) -> m (mat (PrimState m) a)
  unsafeGenerate w h f = do
    m <- unsafeNew w h
    basicInitialise m
    forM_ [0..w-1] $ \i -> forM_ [0..h-1] $ \j ->
      unsafeWrite m i j (f i j)
    return m

  -- | Monadic variant of 'unsafeGenerate': the entry at row @i@, column @j@ is
  --   the result of running @f i j@. Actions run row by row, left to right.
  unsafeGenerateM :: (PrimMonad m) => Size -> Size -> (Index -> Index -> m a) -> m (mat (PrimState m) a)
  unsafeGenerateM w h f = do
    m <- unsafeNew w h
    basicInitialise m
    forM_ [0..w-1] $ \i -> forM_ [0..h-1] $ \j ->
      unsafeWrite m i j =<< f i j
    return m

  -- | Reads all rows of the matrix, from the first to the last.
  toRows :: PrimMonad m => mat (PrimState m) a -> m [Row mat a]
  toRows mat = forM [0 .. rowCount mat - 1] $ \i -> unsafeGetRow i mat

  -- | Reads all columns of the matrix, from the first to the last.
  toColumns :: PrimMonad m => mat (PrimState m) a -> m [Column mat a]
  toColumns mat = forM [0 .. columnCount mat - 1] $ \i -> unsafeGetColumn i mat

  -- | @'combineRows' i c j mat@ adds @c@ times the @j@th row to the @i@th row:
  --   each entry @x@ at column @k@ of row @i@ becomes @x + c * mat[j][k]@.
  --   Calls 'error' if @i@ or @j@ is not a valid row index.
  combineRows :: (Semiring a, Commutative a, PrimMonad m) => Index -> a -> Index -> mat (PrimState m) a -> m ()
  combineRows i c j m =
    checkBound i rowCount m $ checkBound j rowCount m $
      basicUnsafeIMapRowM m i (\k a -> (a +) . (c *) <$> unsafeRead m j k)


-- | Number of columns of a mutable matrix (delegates to 'basicColumnCount').
columnCount :: MMatrix mat a => mat s a -> Size
columnCount mat = basicColumnCount mat
{-# INLINE columnCount #-}

-- | Number of rows of a mutable matrix (delegates to 'basicRowCount').
rowCount :: MMatrix mat a => mat s a -> Size
rowCount mat = basicRowCount mat
{-# INLINE rowCount #-}

-- | @'new' n m@ creates a mutable matrix with @n@ rows and @m@ columns,
--   allocated by 'basicUnsafeNew' and then initialised by 'basicInitialise'.
--
--   Calls 'error' if either dimension is negative; the message reports both
--   dimensions so the offending one is visible.
new :: (MMatrix mat a, PrimMonad m) => Size -> Size -> m (mat (PrimState m) a)
new n m
  | n >= 0 && m >= 0 = do
      mat <- basicUnsafeNew n m
      basicInitialise mat
      return mat
  | otherwise = error $ "negative length: " ++ show n ++ " x " ++ show m

-- | @'unsafeNew' n m@ creates a mutable matrix with @n@ rows and @m@ columns, without memory initialisation.
--
--   Only the dimensions are validated: calls 'error' if either is negative,
--   reporting both of them.
--
--   See also: @'new'@.
unsafeNew :: (MMatrix mat a, PrimMonad m) => Size -> Size -> m (mat (PrimState m) a)
unsafeNew n m
  | n >= 0 && m >= 0 = basicUnsafeNew n m
  | otherwise        = error $ "negative length: " ++ show n ++ " x " ++ show m

-- | @'checkBound' i size mat x@ returns @x@ when @0 <= i < size mat@, and
--   calls 'error' otherwise. Used to guard the unchecked @unsafe*@ primitives.
checkBound :: (Show a, Num a, Ord a) => a -> (t -> a) -> t -> p -> p
checkBound i size mat x
  | 0 <= i && i < bound = x
  | otherwise           = error $ unwords ["Out of bounds:", show i, "out of", show bound]
  where
    -- Computed once and shared by the guard and the error message.
    bound = size mat

-- | @'getRow' i mat@ returns the @i@th row of @mat@.
--
--   Calls 'error' unless @0 <= i < rowCount mat@.
--
--   __N.B.__ Index is considered as /@0@-origin/, NOT @1@!
getRow :: (MMatrix mat a, PrimMonad m) => Index -> mat (PrimState m) a -> m (Row mat a)
getRow i m = checkBound i rowCount m (unsafeGetRow i m)

-- | @'getColumn' j mat@ returns the @j@th column of @mat@.
--
--   Calls 'error' unless @0 <= j < columnCount mat@.
--
--   __N.B.__ Index is considered as /@0@-origin/, NOT @1@!
getColumn :: (MMatrix mat a, PrimMonad m) => Index -> mat (PrimState m) a -> m (Column mat a)
getColumn j m = checkBound j columnCount m (unsafeGetColumn j m)

-- | @'imapRow' i f m@ mutates the @i@th row of @m@ in place: the entry @x@ at
--   column @k@ becomes @f k x@.
--
--   Calls 'error' unless @i@ is a valid row index.
--
--   See also: @'mapRow'@.
--
--   __N.B.__ Index is considered as /@0@-origin/, NOT @1@!
imapRow :: (PrimMonad m, MMatrix mat a) => Index -> (Index -> a -> a) -> mat (PrimState m) a -> m ()
imapRow i f m = checkBound i rowCount m (basicUnsafeIMapRow m i f)

-- | @'mapRow' i f m@ mutates the @i@th row of @m@ in place by applying @f@ to
--   every entry. This is 'imapRow' with the column index ignored.
--
--   Calls 'error' unless @i@ is a valid row index.
--
--   __N.B.__ Index is considered as /@0@-origin/, NOT @1@!
mapRow :: (PrimMonad m, MMatrix mat a) => Index -> (a -> a) -> mat (PrimState m) a -> m ()
mapRow i f m = imapRow i (\_ -> f) m

-- | @'scaleRow' i c m@ multiplies every entry of the @i@th row of @m@ by @c@.
--   Because @a@ is required to be 'Commutative', the side of multiplication
--   does not matter.
--
--   Calls 'error' unless @i@ is a valid row index.
scaleRow :: (Multiplicative a, Commutative a, MMatrix mat a, PrimMonad m)
         => Index -> a -> mat (PrimState m) a -> m ()
scaleRow i c m = checkBound i rowCount m (unsafeScaleRow m i c)
{-# INLINE scaleRow #-}

-- | @'fill' n m c@ creates a mutable matrix with @n@ rows and @m@ columns,
--   every entry set to @c@.
--
--   Calls 'error' if either dimension is negative.
fill :: (PrimMonad m, MMatrix mat a) => Size -> Size -> a -> m (mat (PrimState m) a)
fill rows cols
  | rows < 0 || cols < 0 = error $ unwords ["fill:", "out of bounds:", show (rows, cols)]
  | otherwise            = unsafeFill rows cols

-- | @'read' i j m@ returns the entry in the @i@th row and @j@th column of @m@.
--
--   Calls 'error' if either index is out of range (the row is checked first).
--
--   __N.B.__ Index is considered as /@0@-origin/, NOT @1@!
read :: (PrimMonad m, MMatrix mat a) => Index -> Index -> mat (PrimState m) a -> m a
read i j m = guarded (unsafeRead m i j)
  where
    guarded = checkBound i rowCount m . checkBound j columnCount m

-- | @'write' i j m x@ stores @x@ in the @i@th row and @j@th column of @m@.
--
--   Calls 'error' if either index is out of range (the row is checked first).
--
--   __N.B.__ Index is considered as /@0@-origin/, NOT @1@!
write :: (PrimMonad m, MMatrix mat a) => Index -> Index -> mat (PrimState m) a -> a -> m ()
write i j m = guarded (unsafeWrite m i j)
  where
    guarded = checkBound i rowCount m . checkBound j columnCount m

-- | @'unsafeSwapRows' i j mat@ exchanges rows @i@ and @j@ of @mat@ in place,
--   without boundary check.
--
--   See also: @'swapRows'@.
unsafeSwapRows :: (PrimMonad m, MMatrix mat a) => Index -> Index -> mat (PrimState m) a -> m ()
unsafeSwapRows r1 r2 mat = basicUnsafeSwapRows mat r1 r2

-- | @'swapRows' i j mat@ exchanges rows @i@ and @j@ of @mat@ in place.
--
--   Calls 'error' unless both are valid row indices (@i@ is checked first).
--
--   See also: @'unsafeSwapRows'@.
swapRows :: (PrimMonad m, MMatrix mat a) => Index -> Index -> mat (PrimState m) a -> m ()
swapRows i j m = checkRow i (checkRow j (unsafeSwapRows i j m))
  where
    checkRow k = checkBound k rowCount m

-- | @'copy' target source@ copies every entry of @source@ into @target@.
--
--   Calls 'error' unless both matrices have the same number of rows and columns.
copy :: (PrimMonad m, MMatrix mat a) => mat (PrimState m) a -> mat (PrimState m) a -> m ()
copy targ src
  | sameShape = unsafeCopy targ src
  | otherwise = error "Two matrices should be of the same size"
  where
    sameShape = (rowCount targ, columnCount targ) == (rowCount src, columnCount src)

-- | Creates a fresh matrix of the same shape as the argument (via 'new') and
--   copies all of its entries into it.
clone :: (PrimMonad m, MMatrix mat a) => mat (PrimState m) a -> m (mat (PrimState m) a)
clone src =
  new (rowCount src) (columnCount src) >>= \dst ->
    unsafeCopy dst src >> return dst

-- | Builds a single-row matrix from a row vector via 'unsafeFromRow'.
--   NOTE(review): that method may reuse the input without copying.
fromRow :: (PrimMonad m, MMatrix mat a) => Row mat a -> m (mat (PrimState m) a)
fromRow row = unsafeFromRow row

-- | Builds a single-column matrix from a column vector via 'unsafeFromColumn'.
--   NOTE(review): that method may reuse the input without copying.
fromColumn :: (PrimMonad m, MMatrix mat a) => Column mat a -> m (mat (PrimState m) a)
fromColumn col = unsafeFromColumn col

-- | @'generate' n m f@ creates an @n@ x @m@ matrix whose entry at row @i@,
--   column @j@ is @f i j@.
--
--   Calls 'error' if either dimension is negative.
generate :: (PrimMonad m, MMatrix mat a)
         => Int -> Int -> (Index -> Index -> a) -> m (mat (PrimState m) a)
generate w h
  | w < 0 || h < 0 =
      error $ concat
        [ "Generating matrix with negative width or height: "
        , show w, "x", show h
        ]
  | otherwise = unsafeGenerate w h

-- | Monadic counterpart of 'generate'.
--
--   @'generateM' w h f@ passes the entry-producing action @f@ to
--   'unsafeGenerateM' once both dimensions are known to be non-negative.
--   Otherwise it raises an 'error' naming the offending dimensions.
generateM :: (PrimMonad m, MMatrix mat a)
          => Int -> Int -> (Index -> Index -> m a) -> m (mat (PrimState m) a)
generateM w h =
  if w < 0 || h < 0
    then error badDims
    else unsafeGenerateM w h
  where
    badDims =
      concat ["Generating matrix with negative width or height: ", show w, "x", show h]

-- | Performs Gaussian (Gauss–Jordan style) row reduction on a copy of the
--   given matrix.
--
--   It returns a /triple/:
--
--   * the row-reduced copy of the input,
--   * the pivoting matrix produced by 'unsafeGaussReduction', and
--   * the determinant computed along the way.
--
--   The argument is first duplicated with 'clone', and the in-place
--   'unsafeGaussReduction' then runs on that copy, so the caller's matrix is
--   left as it was.
gaussReduction :: (Eq a, PrimMonad m, Field a, Normed a, MMatrix mat a)
               => mat (PrimState m) a -> m (mat (PrimState m) a, mat (PrimState m) a, a)
gaussReduction mat = do
  reduced <- clone mat
  (piv, det) <- unsafeGaussReduction reduced
  return (reduced, piv, det)

-- | Performs in-place gaussian reduction to the mutable matrix, and returns
--   the pivoting matrix and determinant.
--
--   The argument is overwritten.  The pivoting matrix starts out as the
--   identity of size @'rowCount' mat@.  Every row swap, scaling and row
--   combination applied to the argument is applied to it as well.
--
--   At each step, the candidate with the largest 'norm' in the current
--   column (from the current row downwards) becomes the pivot (partial
--   pivoting).  Its row is then handled as follows:
--
--   * it is swapped into place and scaled by the pivot's 'recip';
--   * every other row is combined with it using the negated entry of the
--     pivot column.
--
--   If a column has no non-zero candidate, it is skipped and the
--   determinant becomes 'zero'.
--
--   NOTE(review): the determinant is only meaningful for square inputs.
unsafeGaussReduction :: (MMatrix mat a, Normed a, Eq a, Field a, PrimMonad m)
                      => mat (PrimState m) a -> m (mat (PrimState m) a, a)
unsafeGaussReduction mat = {-# SCC "gaussRed" #-} do
  pivot <- generate nrows nrows $ \ i j -> if i == j then one else zero
  det <- go 0 0 pivot one
  return (pivot, det)
  where
    nrows = rowCount mat
    ncols = columnCount mat
    -- @go i j p dAcc@:
    --   * @i@ is the row that receives the next pivot;
    --   * @j@ is the column being examined;
    --   * @p@ is the pivoting matrix;
    --   * @dAcc@ is the determinant accumulated so far.
    go i j p dAcc
      | i >= nrows || j >= ncols = return dAcc
      | otherwise = do
          -- Partial pivoting: choose the row k in [i .. nrows-1] whose entry
          -- in column j has the largest norm.  The list is non-empty here
          -- because the guard above ensures i < nrows.
          (k, newC) <- maximumBy (comparing $ norm . snd) <$>
            mapM (\ l -> (l,) <$> read l j mat) [i .. nrows - 1]
          if isZero newC
            -- No usable pivot in column j: the determinant is zero.  Move on
            -- to the next column while staying on row i.
            then go i (j + 1) p zero
            else do
              -- Uses combineRows to clear column j of row l against the
              -- (already normalised) pivot row i.  Presumably combineRows
              -- adds coe times row i to row l; confirm in the base module.
              let cancel l = do
                    coe <- negate <$> read l j mat
                    combineRows l coe i mat
                    combineRows l coe i p
              swapRows i k mat
              scaleRow i (recip newC) mat
              swapRows i k p
              scaleRow i (recip newC) p
              forM_ [0 .. i - 1] cancel
              forM_ [i + 1 .. nrows - 1] cancel
              -- Swapping two distinct rows negates the determinant.
              let offset = if i == k then id else negate
              -- Force the accumulator so no chain of thunks builds up
              -- across the recursion.
              go (i + 1) (j + 1) p $! offset (dAcc * newC)