Algorithms - Strassen's algorithm for matrix multiplication
2021-01-25 04:14
Reference: 1. Introduction to Algorithms
Original article: https://www.cnblogs.com/zzyzz/p/12862998.html
Problem:
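Compute the matrix product C = A * B, where A, B and C are all N x N square matrices and, to simplify the problem, N is a power of 2. Partition each matrix into four N/2 x N/2 submatrices: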
A = [[A11, A12], [A21, A22]]
B = [[B11, B12], [B21, B22]]
C = [[C11, C12], [C21, C22]]
Then, by the rules of matrix multiplication:
C11 = A11 * B11 + A12 * B21
C12 = A11 * B12 + A12 * B22
C21 = A21 * B11 + A22 * B21
C22 = A21 * B12 + A22 * B22
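The four block formulas can be checked numerically. A sketch, assuming a 4 x 4 example so that every block is 2 x 2; `naive`, `madd` and `split` are hypothetical helper names introduced only for this check.

```python
def naive(X, Y):
    # ordinary triple-loop product of two square matrices
    n = len(X)
    return [[sum(X[i][k] * Y[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def madd(X, Y):
    # element-wise sum of two matrices of the same shape
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def split(M):
    # partition an n x n matrix (n even) into blocks M11, M12, M21, M22
    h = len(M) // 2
    return ([r[:h] for r in M[:h]], [r[h:] for r in M[:h]],
            [r[:h] for r in M[h:]], [r[h:] for r in M[h:]])

A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
B = [[2, 0, 1, 3], [1, 4, 0, 2], [3, 1, 2, 0], [0, 2, 1, 1]]
A11, A12, A21, A22 = split(A)
B11, B12, B21, B22 = split(B)
C11, C12, C21, C22 = split(naive(A, B))

# each block of C equals the sum of two block products
assert C11 == madd(naive(A11, B11), naive(A12, B21))
assert C12 == madd(naive(A11, B12), naive(A12, B22))
assert C21 == madd(naive(A21, B11), naive(A22, B21))
assert C22 == madd(naive(A21, B12), naive(A22, B22))
print("block formulas hold")
```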
The standard method for multiplying N x N square matrices:
def squre_matrix_multiply(A, B):
    n = len(A)
    # let c be a new n x n matrix filled with zeros
    c = [[0 for y in range(n)] for x in range(n)]
    for i in range(n):
        for j in range(n):
            # c[i][j] is the dot product of row i of A and column j of B
            for k in range(n):
                c[i][j] = c[i][j] + A[i][k] * B[k][j]
    return c

if __name__ == '__main__':
    A = [[2,1],[3,6]]
    B = [[3,4],[2,2]]
    print(squre_matrix_multiply(A, B))
Output:
[[8, 10], [21, 24]]
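For reference, the triple loop evaluates the definition below once for each of the n² entries of C, with n multiplications per entry, so the standard method performs n³ scalar multiplications and runs in Θ(n³) time.

```latex
c_{ij} = \sum_{k=1}^{n} a_{ik}\, b_{kj}, \quad 1 \le i, j \le n
\qquad\Longrightarrow\qquad
\underbrace{n^2}_{\text{entries}} \cdot \underbrace{n}_{\text{products each}} = n^3
```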
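Solving with divide and conquer:
Idea: partition each N x N matrix into four N/2 x N/2 submatrices. Each block of C is then the sum of two products of submatrices, as in the formulas above, so one N x N multiplication becomes 8 recursive multiplications of N/2 x N/2 matrices plus 4 matrix additions.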
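Counting the cost of this scheme: each call makes 8 recursive multiplications of N/2 x N/2 blocks plus Θ(N²) work to add and assemble blocks, which gives the recurrence below. By the master theorem it solves to Θ(N³), so divide and conquer alone is no faster than the triple loop.

```latex
T(N) = 8\,T(N/2) + \Theta(N^2) \;\Longrightarrow\; T(N) = \Theta\!\left(N^{\log_2 8}\right) = \Theta(N^3)
```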
def matrix_add(X, Y):
    # element-wise sum of two matrices of the same shape
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def split(M):
    # partition an n x n matrix (n even) into four n/2 x n/2 blocks
    h = len(M) // 2
    return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])

def squre_matrix_multiply_recursive(A, B):
    n = len(A)
    # base case: 1 x 1 matrices
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    # partition A and B into N/2 x N/2 blocks
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    # 8 recursive multiplications, combined with the block formulas above
    C11 = matrix_add(squre_matrix_multiply_recursive(A11, B11), squre_matrix_multiply_recursive(A12, B21))
    C12 = matrix_add(squre_matrix_multiply_recursive(A11, B12), squre_matrix_multiply_recursive(A12, B22))
    C21 = matrix_add(squre_matrix_multiply_recursive(A21, B11), squre_matrix_multiply_recursive(A22, B21))
    C22 = matrix_add(squre_matrix_multiply_recursive(A21, B12), squre_matrix_multiply_recursive(A22, B22))
    # stitch the four blocks of C back into one n x n matrix
    top = [r1 + r2 for r1, r2 in zip(C11, C12)]
    bottom = [r1 + r2 for r1, r2 in zip(C21, C22)]
    return top + bottom

if __name__ == '__main__':
    A = [[2,1],[3,6]]
    B = [[3,4],[2,2]]
    print(squre_matrix_multiply_recursive(A, B))
Output:
[[8, 10], [21, 24]]
Strassen's algorithm:
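Strassen's algorithm makes only 7 recursive multiplications of N/2 x N/2 matrices (the divide-and-conquer method above makes 8), at the cost of extra matrix additions and subtractions.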
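With 7 instead of 8 recursive products, and still only Θ(N²) work for the matrix additions and subtractions, the recurrence becomes the one below. By the master theorem it solves to Θ(N^(lg 7)), roughly Θ(N^2.81), which is asymptotically faster than Θ(N³).

```latex
T(N) = 7\,T(N/2) + \Theta(N^2) \;\Longrightarrow\; T(N) = \Theta\!\left(N^{\log_2 7}\right) \approx \Theta\!\left(N^{2.81}\right)
```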
1. Create 10 N/2 x N/2 matrices S1, S2, …, S10 from sums and differences of the blocks:
S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12
2. Use S1 … S10 to compute the seven products P1, P2, …, P7 (each one is a recursive N/2 x N/2 multiplication):
P1 = A11 * S1 = A11 * B12 - A11 * B22
P2 = S2 * B22 = A11 * B22 + A12 * B22
P3 = S3 * B11 = A21 * B11 + A22 * B11
P4 = A22 * S4 = A22 * B21 - A22 * B11
P5 = S5 * S6 = A11 * B11 + A11 * B22 + A22 * B11 + A22 * B22
P6 = S7 * S8 = A12 * B21 + A12 * B22 - A22 * B21 - A22 * B22
P7 = S9 * S10 = A11 * B11 + A11 * B12 - A21 * B11 - A21 * B12
3. Combine P1 … P7 to obtain the blocks of C:
C11 = P4 + P5 + P6 - P2
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7
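The three steps can be checked numerically. A sketch, assuming random 2 x 2 integer blocks (so block products do not commute and the order of factors matters) and hypothetical helpers `mmul`, `madd`, `mneg`; it takes S10 = B11 + B12 and C22 = P5 + P1 - P3 - P7, matching `s10` and `c[1][1]` in the `strassn` code below. The fixed seed only makes the run repeatable.

```python
import random

def mmul(X, Y):
    # plain triple-loop product of two square matrices
    n = len(X)
    return [[sum(X[i][k] * Y[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def madd(*Ms):
    # element-wise sum of any number of same-shape matrices
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*Ms)]

def mneg(X):
    # element-wise negation
    return [[-x for x in row] for row in X]

rng = random.Random(42)

def rand_block(size=2):
    return [[rng.randint(-9, 9) for _ in range(size)] for _ in range(size)]

A11, A12, A21, A22 = (rand_block() for _ in range(4))
B11, B12, B21, B22 = (rand_block() for _ in range(4))

# step 1: S1..S10
S1, S2, S3 = madd(B12, mneg(B22)), madd(A11, A12), madd(A21, A22)
S4, S5, S6 = madd(B21, mneg(B11)), madd(A11, A22), madd(B11, B22)
S7, S8 = madd(A12, mneg(A22)), madd(B21, B22)
S9, S10 = madd(A11, mneg(A21)), madd(B11, B12)

# step 2: the seven block products
P1, P2, P3 = mmul(A11, S1), mmul(S2, B22), mmul(S3, B11)
P4, P5 = mmul(A22, S4), mmul(S5, S6)
P6, P7 = mmul(S7, S8), mmul(S9, S10)

# the expansions listed in step 2 hold
assert P1 == madd(mmul(A11, B12), mneg(mmul(A11, B22)))
assert P2 == madd(mmul(A11, B22), mmul(A12, B22))
assert P3 == madd(mmul(A21, B11), mmul(A22, B11))
assert P4 == madd(mmul(A22, B21), mneg(mmul(A22, B11)))
assert P5 == madd(mmul(A11, B11), mmul(A11, B22), mmul(A22, B11), mmul(A22, B22))
assert P6 == madd(mmul(A12, B21), mmul(A12, B22), mneg(mmul(A22, B21)), mneg(mmul(A22, B22)))
assert P7 == madd(mmul(A11, B11), mmul(A11, B12), mneg(mmul(A21, B11)), mneg(mmul(A21, B12)))

# step 3 reproduces the ordinary block formulas for C
C11 = madd(P4, P5, P6, mneg(P2))
C12 = madd(P1, P2)
C21 = madd(P3, P4)
C22 = madd(P5, P1, mneg(P3), mneg(P7))
assert C11 == madd(mmul(A11, B11), mmul(A12, B21))
assert C12 == madd(mmul(A11, B12), mmul(A12, B22))
assert C21 == madd(mmul(A21, B11), mmul(A22, B21))
assert C22 == madd(mmul(A21, B12), mmul(A22, B22))
print("all Strassen identities hold")
```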
def strassn(A, B):
    n = len(A)
    # let c be a new n x n matrix
    c = [[0 for x in range(n)] for y in range(n)]
    if n == 1:
        c[0][0] = A[0][0] * B[0][0]
    else:
        # only suitable for a 2 x 2 matrix: each "block" here is a single entry
        # step 1: the ten sums and differences S1..S10
        s1 = B[0][1] - B[1][1]
        s2 = A[0][0] + A[0][1]
        s3 = A[1][0] + A[1][1]
        s4 = B[1][0] - B[0][0]
        s5 = A[0][0] + A[1][1]
        s6 = B[0][0] + B[1][1]
        s7 = A[0][1] - A[1][1]
        s8 = B[1][0] + B[1][1]
        s9 = A[0][0] - A[1][0]
        s10 = B[0][0] + B[0][1]
        # step 2: the seven products P1..P7
        p1 = A[0][0] * s1
        p2 = s2 * B[1][1]
        p3 = s3 * B[0][0]
        p4 = A[1][1] * s4
        p5 = s5 * s6
        p6 = s7 * s8
        p7 = s9 * s10
        # step 3: combine P1..P7 into C
        c[0][0] = p5 + p4 - p2 + p6
        c[0][1] = p1 + p2
        c[1][0] = p3 + p4
        c[1][1] = p5 + p1 - p3 - p7
    return c

if __name__ == '__main__':
    A = [[2,1],[3,6]]
    B = [[3,4],[2,2]]
    print(strassn(A, B))
Output:
[[8, 10], [21, 24]]
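The `strassn` function above hard-codes the 2 x 2 case. A minimal recursive sketch for any N that is a power of 2, assuming list-of-lists input: each of the seven products recurses on N/2 x N/2 blocks. `strassen`, `split`, `join`, `madd` and `msub` are hypothetical names, not from the original post.

```python
def madd(X, Y):
    # element-wise X + Y
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def msub(X, Y):
    # element-wise X - Y
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def split(M):
    # partition an n x n matrix (n even) into M11, M12, M21, M22
    h = len(M) // 2
    return ([r[:h] for r in M[:h]], [r[h:] for r in M[:h]],
            [r[:h] for r in M[h:]], [r[h:] for r in M[h:]])

def join(C11, C12, C21, C22):
    # stitch four blocks back into one matrix
    return ([r1 + r2 for r1, r2 in zip(C11, C12)] +
            [r1 + r2 for r1, r2 in zip(C21, C22)])

def strassen(A, B):
    # assumes A and B are n x n with n a power of 2
    if len(A) == 1:
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    # steps 1 and 2: seven recursive products built from S1..S10
    P1 = strassen(A11, msub(B12, B22))
    P2 = strassen(madd(A11, A12), B22)
    P3 = strassen(madd(A21, A22), B11)
    P4 = strassen(A22, msub(B21, B11))
    P5 = strassen(madd(A11, A22), madd(B11, B22))
    P6 = strassen(msub(A12, A22), madd(B21, B22))
    P7 = strassen(msub(A11, A21), madd(B11, B12))
    # step 3: combine into the blocks of C
    C11 = madd(msub(madd(P5, P4), P2), P6)
    C12 = madd(P1, P2)
    C21 = madd(P3, P4)
    C22 = msub(msub(madd(P5, P1), P3), P7)
    return join(C11, C12, C21, C22)

if __name__ == '__main__':
    print(strassen([[2, 1], [3, 6]], [[3, 4], [2, 2]]))
```

Recursing all the way down to 1 x 1 keeps the sketch short; a common refinement (not shown) is to stop at a larger block size and use the triple loop there, since the extra additions make Strassen slower than the standard method for small matrices.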
Article link: http://soscw.com/index.php/essay/46628.html