Skip to content

Add matrix multiplication with double[][] and unit tests #6417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fcfd803
MatrixMultiplication.java created and updated.
Nishitha0730 Jul 19, 2025
c9f57d0
Add necessary comment to MatrixMultiplication.java
Nishitha0730 Jul 19, 2025
16c9897
Create MatrixMultiplicationTest.java
Nishitha0730 Jul 19, 2025
1ff92f5
method for 2 by 2 matrix multiplication is created
Nishitha0730 Jul 19, 2025
855d20f
Use assertMatrixEquals(), otherwise there can be error due to floatin…
Nishitha0730 Jul 19, 2025
05064a5
assertMatrixEquals method created and updated
Nishitha0730 Jul 19, 2025
4e4625c
method created for 3by2 matrix multiply with 2by1 matrix
Nishitha0730 Jul 19, 2025
f661455
method created for null matrix multiplication
Nishitha0730 Jul 19, 2025
a9438d2
method for test matrix dimension error
Nishitha0730 Jul 19, 2025
4bf395b
method for test empty matrix input
Nishitha0730 Jul 19, 2025
64a97db
testMultiply3by2and2by1 test case updated
Nishitha0730 Jul 19, 2025
6d16a54
Check for empty matrices part updated
Nishitha0730 Jul 19, 2025
65927d8
Updated Unit test coverage
Nishitha0730 Jul 19, 2025
1a58232
files updated
Nishitha0730 Jul 19, 2025
414200e
clean the code
Nishitha0730 Jul 19, 2025
d150bb2
clean the code
Nishitha0730 Jul 19, 2025
1858c69
Updated files with google-java-format
Nishitha0730 Jul 19, 2025
56baf84
Updated files
Nishitha0730 Jul 19, 2025
ca40a2b
Updated files
Nishitha0730 Jul 19, 2025
d633a3e
Updated files
Nishitha0730 Jul 19, 2025
4257686
Updated files
Nishitha0730 Jul 19, 2025
a9d9c39
Merge branch 'master' into my_algorithm
DenizAltunkapan Jul 19, 2025
dcec054
Add reference links and complexities
Nishitha0730 Jul 19, 2025
410d22f
Merge remote-tracking branch 'origin/my_algorithm' into my_algorithm
Nishitha0730 Jul 19, 2025
c243140
Add test cases for 1by1 matrix and non-rectangular matrix
Nishitha0730 Jul 19, 2025
49d7f98
Add reference links and complexities
Nishitha0730 Jul 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package com.thealgorithms.matrix;

/**
* This class provides a method to perform matrix multiplication.
*
* <p>Matrix multiplication takes two 2D arrays (matrices) as input and
* produces their product, following the mathematical definition of
* matrix multiplication.
*
* <p>For more details:
* https://www.geeksforgeeks.org/java/java-program-to-multiply-two-matrices-of-any-size/
* https://en.wikipedia.org/wiki/Matrix_multiplication
*
* <p>Time Complexity: O(n^3) – where n is the dimension of the matrices
* (assuming square matrices for simplicity).
*
* <p>Space Complexity: O(n^2) – for storing the result matrix.
*
*
* @author Nishitha Wihala Pitigala
*
*/

public final class MatrixMultiplication {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add here another detailed explanation above the class. For example by providing a link and most importantly adding the time- and space complexity :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay:)

private MatrixMultiplication() {
}

/**
* Multiplies two matrices.
*
* @param matrixA the first matrix rowsA x colsA
* @param matrixB the second matrix rowsB x colsB
* @return the product of the two matrices rowsA x colsB
* @throws IllegalArgumentException if the matrices cannot be multiplied
*/
public static double[][] multiply(double[][] matrixA, double[][] matrixB) {
// Check the input matrices are not null
if (matrixA == null || matrixB == null) {
throw new IllegalArgumentException("Input matrices cannot be null");
}

// Check for empty matrices
if (matrixA.length == 0 || matrixB.length == 0 || matrixA[0].length == 0 || matrixB[0].length == 0) {
throw new IllegalArgumentException("Input matrices must not be empty");
}

// Validate the matrix dimensions
if (matrixA[0].length != matrixB.length) {
throw new IllegalArgumentException("Matrices cannot be multiplied: incompatible dimensions.");
}

int rowsA = matrixA.length;
int colsA = matrixA[0].length;
int colsB = matrixB[0].length;

// Initialize the result matrix with zeros
double[][] result = new double[rowsA][colsB];

// Perform matrix multiplication
for (int i = 0; i < rowsA; i++) {
for (int j = 0; j < colsB; j++) {
for (int k = 0; k < colsA; k++) {
result[i][j] += matrixA[i][k] * matrixB[k][j];
}
}
}
return result;
}
}
101 changes: 101 additions & 0 deletions src/test/java/com/thealgorithms/matrix/MatrixMultiplicationTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package com.thealgorithms.matrix;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import org.junit.jupiter.api.Test;

public class MatrixMultiplicationTest {

private static final double EPSILON = 1e-9; // for floating point comparison

@Test
void testMultiply1by1() {
double[][] matrixA = {{1.0}};
double[][] matrixB = {{2.0}};
double[][] expected = {{2.0}};

double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
assertMatrixEquals(expected, result);
}

@Test
void testMultiply2by2() {
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
double[][] matrixB = {{5.0, 6.0}, {7.0, 8.0}};
double[][] expected = {{19.0, 22.0}, {43.0, 50.0}};

double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
assertMatrixEquals(expected, result); // Use custom method due to floating point issues
}

@Test
void testMultiply3by2and2by1() {
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
double[][] matrixB = {{7.0}, {8.0}};
double[][] expected = {{23.0}, {53.0}, {83.0}};

double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
assertMatrixEquals(expected, result);
}

@Test
void testMultiplyNonRectangularMatrices() {
double[][] matrixA = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};
double[][] matrixB = {{7.0, 8.0}, {9.0, 10.0}, {11.0, 12.0}};
double[][] expected = {{58.0, 64.0}, {139.0, 154.0}};

double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
assertMatrixEquals(expected, result);
}

@Test
void testNullMatrixA() {
double[][] b = {{1, 2}, {3, 4}};
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(null, b));
}

@Test
void testNullMatrixB() {
double[][] a = {{1, 2}, {3, 4}};
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, null));
}

@Test
void testMultiplyNull() {
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
double[][] matrixB = null;

Exception exception = assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(matrixA, matrixB));

String expectedMessage = "Input matrices cannot be null";
String actualMessage = exception.getMessage();

assertTrue(actualMessage.contains(expectedMessage));
}

@Test
void testIncompatibleDimensions() {
double[][] a = {{1.0, 2.0}};
double[][] b = {{1.0, 2.0}};
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
}

@Test
void testEmptyMatrices() {
double[][] a = new double[0][0];
double[][] b = new double[0][0];
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
}

private void assertMatrixEquals(double[][] expected, double[][] actual) {
assertEquals(expected.length, actual.length, "Row count mismatch");
for (int i = 0; i < expected.length; i++) {
assertEquals(expected[i].length, actual[i].length, "Column count mismatch at row " + i);
for (int j = 0; j < expected[i].length; j++) {
assertEquals(expected[i][j], actual[i][j], EPSILON, "Mismatch at (" + i + "," + j + ")");
}
}
}
}