题解 | #矩阵乘法计算量估算#
矩阵乘法计算量估算
https://www.nowcoder.com/practice/15e41630514445719a942e004edc0a5b
栈的解法
思路看注释就够了。
这里只说一下新建的类,
矩阵的行列---MatrixInfo,包括行、列,
计算量---Amount,包括矩阵、计算量,用于矩阵计算的返回值,
栈的元素类型---Node:
由于只用1个栈,感觉1个栈用来判断括号更清晰一些。栈中要存矩阵、'('两种类型,所以建了一个同一类型Node,矩阵和'('只会有1个不为null。
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
/**
* HJ70 矩阵乘法计算量估算
*/
public class Main {
/**
* 借助栈
* 过程
* i从左到右遍历字符串
* 1. 如果a[i]==字母
* (1)如果栈顶不为矩阵,入栈
* (2)如果栈顶为矩阵,栈顶出栈,和a[i]计算,累加数量,结果矩阵入栈
* 2. 如果a[i]=='('
* 入栈
* 由于第1条,所以如果有一连串矩阵,随着i的移动,栈中'('以上只会有1个矩阵,栈类似这样:
* | 矩阵 |
* | ( |
* | 矩阵 |
* | ( |
* | ..... |
* | ..... |
* |_______|
* 所以对于')'的规则是:
* 3. 如果a[i]==')'
* 栈顶矩阵1出栈,'('出栈,
* (1)如果此时栈顶是矩阵2,那么矩阵2和矩阵1计算,累加数量,结果矩阵入栈
* (2)如果此时栈顶是'(',那么矩阵1入栈
* 栈中只会存矩阵、'('两种类型,不会存')'
* @param matrixInfos
* @param orderStr
* @return
*/
private static int calculateAmount(List<MatrixInfo> matrixInfos, String orderStr) {
// 字母和矩阵对应关系
Map<Character, MatrixInfo> matrixInfoMap = new HashMap<>(matrixInfos.size());
String letterStr = orderStr.replaceAll("[()]", "");
char[] chars = letterStr.toCharArray();
// 按照字母顺序排序
Arrays.sort(chars);
for (int i = 0; i < chars.length; i++) {
matrixInfoMap.put(chars[i], matrixInfos.get(i));
}
// 栈
Deque<Node> deque = new ArrayDeque<>();
char[] orderChars = orderStr.toCharArray();
// 累计计算量
int total = 0;
for (char c : orderChars) {
// 如果是字母,判断栈顶是否为矩阵,如果是,运算,运算结果放入栈顶
if (c >= 'A' && c <= 'Z') {
Node pNode = deque.peekLast();
// 如果栈为空,或者栈顶为'(',入栈
if (pNode == null || (pNode.bracket != null && pNode.bracket == '(')) {
MatrixInfo matrixInfo = matrixInfoMap.get(c);
Node node = new Node(matrixInfo);
deque.addLast(node);
}
// 栈不为空,且栈顶为矩阵,栈顶出栈,和a[i]计算,入栈
else {
Node node = deque.pollLast();
MatrixInfo cInfo = matrixInfoMap.get(c);
Amount r = calculateAmount(node.matrixInfo, cInfo);
// 结果累加
total = total + r.calAmount;
// 创建新的节点,入栈
Node tNode = new Node(r.matrixInfo);
deque.addLast(tNode);
}
}
// 如果是'(',入栈
else if (c == '(') {
deque.addLast(new Node('('));
}
// 如果是')'
else {
// 栈顶矩阵1出栈,'('出栈
Node matrixNode = deque.pollLast();
Node bracketNode = deque.pollLast();
// (1)如果此时栈顶是矩阵2,那么矩阵2出栈,和矩阵1计算,结果矩阵入栈
if (!deque.isEmpty() && deque.peekLast().matrixInfo != null) {
Node node = deque.pollLast();
Amount amount = calculateAmount(node.matrixInfo, matrixNode.matrixInfo);
// 累加结果
total = total + amount.calAmount;
// 创建新节点,入栈
deque.addLast(new Node(amount.matrixInfo));
}
// (2)如果此时栈为空(实际不会有这种情况)或者栈顶是'(',那么矩阵1入栈
else {
deque.addLast(matrixNode);
}
}
}
// i走完
return total;
}
/**
* 2个矩阵的计算量,a*b
* @param a
* @param b
* @return
*/
private static Amount calculateAmount(MatrixInfo a, MatrixInfo b) {
// 结果矩阵行列数
MatrixInfo r = new MatrixInfo(a.x, b.y);
// 计算量。假设a是x行y列,b是y行z列,结果是x行z列
// 结果是x行z列,那么一共x*z个元素
// 计算每个元素需要的乘法数量:a的一行*b的一列 --- y个数和y个数的乘积和 --- y次乘法
// 所以乘法数量=y*(x*z)
int amount = a.y * (a.x * b.y);
return new Amount(r, amount);
}
/**
* 计算量
*/
static class Amount {
MatrixInfo matrixInfo;
// 计算量
int calAmount = 0;
public Amount(MatrixInfo matrixInfo) {
this.matrixInfo = matrixInfo;
}
public Amount(MatrixInfo matrixInfo, int calAmount) {
this.matrixInfo = matrixInfo;
this.calAmount = calAmount;
}
}
/**
* 栈中存储的节点
* 如果是矩阵,那么matrixInfo!=null,bracket==null
* 如果是括号,那么bracket!=null,matrixInfo==null
*/
static class Node {
// 矩阵
MatrixInfo matrixInfo;
// 括号
Character bracket;
public Node(MatrixInfo matrixInfo) {
this.matrixInfo = matrixInfo;
}
public Node(Character bracket) {
this.bracket = bracket;
}
}
static class MatrixInfo {
// 矩阵行数
int x;
// 矩阵列数
int y;
public MatrixInfo(int x, int y) {
this.x = x;
this.y = y;
}
}
public static void main(String[] args) {
// List<MatrixInfo> matrixInfos = new ArrayList<>();
// matrixInfos.add(new MatrixInfo(50, 10));
// matrixInfos.add(new MatrixInfo(10, 20));
// matrixInfos.add(new MatrixInfo(20, 5));
// String orderStr = "(A(BC))";
// int r = calculateAmount(matrixInfos, orderStr);
// System.out.println(r); // 预期3500
// List<MatrixInfo> matrixInfos = new ArrayList<>();
// matrixInfos.add(new MatrixInfo(8, 6));
// matrixInfos.add(new MatrixInfo(6, 14));
// String orderStr = "(AB)";
// int r = calculateAmount(matrixInfos, orderStr);
// System.out.println(r); // 预期672
try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(System.in))) {
String ns = bufferedReader.readLine();
int n = Integer.parseInt(ns);
List<MatrixInfo> list = new ArrayList<>();
for (int i = 0; i < n; i++) {
String matrixS = bufferedReader.readLine();
String[] a = matrixS.split(" ");
list.add(new MatrixInfo(Integer.parseInt(a[0]), Integer.parseInt(a[1])));
}
String orderStr = bufferedReader.readLine();
int result = calculateAmount(list, orderStr);
System.out.println(result);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
查看17道真题和解析