Optimizer: 基于 OptimizerTask 实现 LogicalPlan 到 PhysicalPlan

OptimizerTask 的继承关系如下图(还有个 RewriteTreeTask,由于和其他 Task 不一起使用,因此没画出来)。每个子 Task 都需要去实现 execute 函数来完成真正的操作。

OptimizerTask-2

OptimizerTask 函数执行时采用后续遍历的顺序来执行任务,如下图。当前优化一个 Group 时,从 OptimizeGroupTask 函数开始,只有创建 OptimizeGroupTask 对象时才会创建一个新的 TaskContext,用 TaskContext 来记录当前 Group 的优化过程中产生最低成本的上界 upperBoundCost。

在执行过程中,每个 Task 会尽可能裁剪遍历分支,缩小搜索空间,降低优化器执行时间。

OptimizerTask

这里主要讲解下 OptimizerTask 的搜索过程,每个 Task 的具体实现有空再说。

SeriallyTaskScheduler

SeriallyTaskScheduler 按照 first in last out 顺序执行 OptimizerTask:通过 SeriallyTaskScheduler.PushTask 和 OptimizerTask.PushTask 函数将待执行的 Task 入栈,在 SeriallyTaskScheduler.executeTasks 函数中将栈中元素 pop 出来执行。

executeTasks 函数限制了一次完整遍历的时间不能超过阈值 new_planner_optimize_timeout(默认值 3000ms),完整代码如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
public class SeriallyTaskScheduler implements TaskScheduler {
private final Stack<OptimizerTask> tasks;

private SeriallyTaskScheduler() {
tasks = new Stack<>();
}

@Override
public void pushTask(OptimizerTask task) {
tasks.push(task);
}

@Override
public void executeTasks(TaskContext context) {
OptimizerContext optimizerContext = context.getOptimizerContext();
long timeout =
optimizerContext.getSessionVariable().getOptimizerExecuteTimeout();
Stopwatch watch = optimizerContext.getTraceInfo().getStopwatch();

// 将 stack 中的所有元素提取出来,全部执行完毕
while (!tasks.empty()) {
if (watch.elapsed(TimeUnit.MILLISECONDS) > timeout) {
//...
break;
}
OptimizerTask task = tasks.pop();
optimizerContext.setTaskContext(context);
task.execute();
}
}
}

Memo

Memo

init

基于得到的逻辑表达式构建 Group Tree,初始化完成一个 OptExpression 对应一个 Group。核心是 Memo.copyIn 函数,后续 AppleRuleTask 中也会调用 Memo.copyIn 函数将应用规则得到的新 OptExpression 添加到 Group 的逻辑等价表达式集合中。

1
2
3
4
5
6
7
public GroupExpression init(OptExpression originExpression) {
// 将 originExpressionTree --> GroupTree
GroupExpression rootGroupExpression
= copyIn(null, originExpression).second;
rootGroup = rootGroupExpression.getGroup();
return rootGroupExpression;
}

Optimizer.memoOptimize

Optimizer.memoOptimize 函数可大致分为两个部分:

  1. 添加一些规则,这些规则最终会被应用在 ApplyRuleTask 中
  2. 执行 OptimizeGroupTask

memoOptimize 函数结束,优化器的主体工作就结束了。

代码如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
void memoOptimize(ConnectContext connectContext, Memo memo, TaskContext rootTaskContext) {
OptExpression tree = memo.getRootGroup().extractLogicalTree();
// Join reorder
SessionVariable sessionVariable = connectContext.getSessionVariable();
int innerCrossJoinNode = Utils.countJoinNodeSize(tree, JoinOperator.innerCrossJoinSet());
if (!sessionVariable.isDisableJoinReorder()
&& innerCrossJoinNode < sessionVariable.getCboMaxReorderNode()) {
if (innerCrossJoinNode > sessionVariable.getCboMaxReorderNodeUseExhaustive()) {

CTEUtils.collectForceCteStatistics(memo, context);
new ReorderJoinRule().transform(tree, context);
context.getRuleSet().addJoinCommutativityWithOutInnerRule();
} else {
if (Utils.countJoinNodeSize(tree, JoinOperator.semiAntiJoinSet())
< sessionVariable.getCboMaxReorderNodeUseExhaustive()) {
context.getRuleSet().getTransformRules().add(
new SemiReorderRule());
}
context.getRuleSet().addJoinTransformationRules();
}
}

//add join implementRule
context.getRuleSet().addAutoJoinImplementationRule();

if (isEnableMultiTableRewrite(connectContext, tree)) {
if (sessionVariable.isEnableMaterializedViewViewDeltaRewrite()
&& rootTaskContext.getOptimizerContext().getCandidateMvs()
.stream().anyMatch(context -> context.hasMultiTables())) {
context.getRuleSet().addSingleTableMvRewriteRule();
}
context.getRuleSet().addMultiTableMvRewriteRule();
}

// 函数探索入口
context.getTaskScheduler().pushTask(
new OptimizeGroupTask(rootTaskContext, memo.getRootGroup()));
context.getTaskScheduler().executeTasks(rootTaskContext);
}

OptimizeGroupTask

OptimizeGroupTask 是 Task 的执行入口:当开始优化一个 Group 时总是从 OptimizeGroupTask 开始。OptimizeGroupTask.execute 函数在执行前会先通过 Group.hasBestExpression 函数过滤已经优化过的 Group,防止重复搜索 or 死循环?

  • OptimizeExpressionTask: 用于优化所有逻辑等价的 Operator Tree
  • EnforceAndCostTask 用于寻找每个 Group 的 cost 最低的 PhysicalGroupExpression。

逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public void execute() {
// Group has optimized given the context
if (group.hasBestExpression(context.getRequiredProperty())) {
return;
}

// 逆序遍历,顺序执行
for (int i = group.getLogicalExpressions().size() - 1; i >= 0; i--) {
pushTask(new OptimizeExpressionTask(
context, group.getLogicalExpressions().get(i)));
}

// 逆序遍历,顺序执行
for (int i = group.getPhysicalExpressions().size() - 1; i >= 0; i--) {
pushTask((new EnforceAndCostTask(
context, group.getPhysicalExpressions().get(i))));
}
}

OptimizeExpressionTask

整体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@Override
public void execute() {
List<Rule> rules = getValidRules();

for (Rule rule : rules) {
pushTask(new ApplyRuleTask(
context, groupExpression, rule, isExplore));
}

pushTask(new DeriveStatsTask(context, groupExpression));
for (int i = groupExpression.arity() - 1; i >= 0; i--) {
pushTask(new ExploreGroupTask(
context, groupExpression.getInputs().get(i)));
}
}

任务执行的顺序是先入后出,因此执行 OptimizeExpressionTask 时,说明已经生成了一个初始化状态的最短路径。执行流程如下:

1.ExploreGroupTask

OptimizeExpressionTask.execute 函数中依次用 groupExpression 的所有的输入 (sub-groups) 来构造 ExploreGroupTask 对象

而 ExploreGroupTask.execute 基于 childGroup 中所有逻辑等价的 Group.logicalExpression 又构造 OptimizeExpressionTask 对象。如此不断递归直到递归基(叶结点),即就是没有输入的 GroupExpression 对象,进入 DeriveStatsTask 开始为 GroupExpression 获取统计信息。

1
2
3
4
5
6
7
8
9
10
public void ExploreGroupTask.execute() {
if (group.isExplored()) {
return;
}

for (GroupExpression logical : group.getLogicalExpressions()) {
pushTask(new OptimizeExpressionTask(context, logical, true));
}
group.setExplored();
}

2.DeriveStatsTask

顾名思义,DeriverStatsTask 是为 groupExpression.group.statistics 获取统计信息的,如果是物化视图则为 groupExpression.group.mvStatistics 获取统计信息。

核心功能由 StatisticsCalculator 实现。 关于 StatisticsCalculator,可以参考官方博客 StarRocks 统计信息和 Cost 估算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@Override
public void execute() {
// 裁剪
if (groupExpression.isStatsDerived() || groupExpression.isUnused()) {
return;
}

ExpressionContext expressionContext = new ExpressionContext(groupExpression);
StatisticsCalculator statisticsCalculator = new StatisticsCalculator(
expressionContext,
context.getOptimizerContext().getColumnRefFactory(),
context.getOptimizerContext());
statisticsCalculator.estimatorStats(); // 获取统计信息

Statistics currentStatistics
= groupExpression.getGroup().getStatistics();
Statistics statistics = expressionContext.getStatistics();

// 更新统计
if (currentStatistics == null ||
(statistics.getOutputRowCount() < currentStatistics.getOutputRowCount()
// currentStatistics != null 可能是 merge 的
&& !isMaterializedView())) {
groupExpression.getGroup().setStatistics(statistics);
}
if (currentStatistics != null && !currentStatistics.equals(statistics)) {
if (isMaterializedView()) {
LogicalOlapScanOperator scan = groupExpression.getOp().cast();
MaterializedView mv = (MaterializedView) scan.getTable();
groupExpression.getGroup().setMvStatistics(mv.getId(), statistics);
}
}

// 标记,防止重复统计
groupExpression.setStatsDerived();
}

3.ApplyRuleTask

按照规则生成新的逻辑表达式、物理表达式

OptimizeExpressionTask.getValidRules

首先要获取有有效的规则:logicalRules、physicalRules 是 Optimizer.memoOptimize 函数中在执行 OptimizeGroupTask 之前添加的。filterInValidRules 函数再将 groupExpression 和 logicalRules、physicalRules 进行匹配。 能匹配得上的即 validRules,最终将 validRules 基于 Rule::promise 值进行排序。

这里有个概念需要注意下:

  • Transform Rule: 基于规则,生成等价的逻辑计划,扩充搜索空间
  • Implement Rule: 将逻辑节点 LogicalOperator 转换为物理节点 PhysicalOperator

因此,难点是 Transform Rule,其中 Rule 细节后续有机会再深入分析咯。

Rule

代码如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
private List<Rule> getValidRules() {
List<Rule> validRules =
Lists.newArrayListWithCapacity(RuleType.NUM_RULES.id());
List<Rule> logicalRules =
context.getOptimizerContext().getRuleSet().getTransformRules();
// 根据匹配规则,获取能应用到 groupExpression 的有效规则
filterInValidRules(groupExpression, logicalRules, validRules);

if (!isExplore) {
List<Rule> physicalRules =
context.getOptimizerContext().getRuleSet().getImplementRules();
filterInValidRules(groupExpression, physicalRules, validRules);
}

// 排序: logical rule 前, physical rule 在后
validRules.sort(Comparator.comparingInt(Rule::promise));
return validRules;
}

OptimizerTask.filterInValidRules

每个 Rule 都有一个 Pattern,只有 Rule 和 Pattern 匹配时才算一个有效规则。
解释如代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
void filterInValidRules(GroupExpression groupExpression,
List<Rule> candidateRules,
List<Rule> validRules) {
OptimizerConfig optimizerConfig =
context.getOptimizerContext().getOptimizerConfig();
for (Rule rule : candidateRules) {
//CHECK-1: groupExpression 已经应用过这个规则
if (groupExpression.hasRuleExplored(rule)) {
continue;
}

// CHECK-2: operator + child 是否能匹配上
if (!rule.getPattern().matchWithoutChild(groupExpression)) {
continue;
}
// CHECK-3: 这个规则没有 disable
if (optimizerConfig.isRuleDisable(rule.type())) {
continue;
}
// 有效
validRules.add(rule);
}
}

Pattern.matchWithoutChild

Pattern 主要用于描述 Rule 的匹配规则,用户可以按照想要匹配的 Operator 类型创建 Pattern。除了常规的 Operator 类型以外,StarRocks 还提供两种特殊的 Operator:

  • PATTERN_LEAF:用于匹配任意单个节点,几乎每个 Rule 里都会以 PATTERN_LEFF 作为叶子节点;
  • PATTERN_MULTI_LEAF:用于匹配 N个(>=0)任意节点,在 UNION、INTERSECT、EXCEPT 这类多输入节点相关的Rule中比较常见。

matchWithoutChild 解释如代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public boolean matchWithoutChild(GroupExpression expression, int level) {
if (expression == null) {
return false;
}

// CHECK-1: 子节点个数是否能匹配上
if (expression.getInputs().size() < this.children().size()
&& children.stream().noneMatch(
p -> OperatorType.PATTERN_MULTI_LEAF.equals(p.getOpType()))) {
return false;
}

// CHECK-2: Pattern 是否能匹配任意 Operator
if (OperatorType.PATTERN_LEAF.equals(getOpType())
|| OperatorType.PATTERN_MULTI_LEAF.equals(getOpType())) {
return true;
}

OperatorType givenOpType = expression.getOp().getOpType();

// CHECK-3: Pattern 只能匹配 SCAN Operator
if (isPatternScan() && scanTypes.contains(givenOpType)) {
return true;
}

// CHECK-4: Pattern 只能匹配 JOIN Operator
if (isPatternMultiJoin() && isMultiJoin(givenOpType, level)) {
return true;
}

// CHECK-5: 直接判断
return getOpType().equals(givenOpType);
}

ApplyRuleTask.execute

匹配的核心是 Binder,能找到 GroupExpression 中符合 rule 的部分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@Override
public void execute() {
//裁剪: 这个规则已经匹配过 or 这个group 不需要优化
if (groupExpression.hasRuleExplored(rule) ||
groupExpression.isUnused()) {
return;
}

OptimizerContext optimizerContext = context.getOptimizerContext();

// 1. 对 groupExpression 应用规则,生成新的表达式
Pattern pattern = rule.getPattern();
Binder binder = new Binder(pattern, groupExpression);
OptExpression extractExpr = binder.next();
List<OptExpression> newExpressions = Lists.newArrayList();
while (extractExpr != null) {
// 每个具体的规则会重载 Rule.check 函数
// 更加细致的判断是否符合 extractExpr 与 rule 能否匹配
if (!rule.check(extractExpr, optimizerContext)) {
extractExpr = binder.next();
continue;
}

// 生成新的 expression
newExpressions.addAll(
rule.transform(extractExpr, optimizerContext));

// 迭代
extractExpr = binder.next();
}

// 2. 针对新生成的表达式进行递归优化
for (OptExpression expression : newExpressions) {
// Insert new OptExpression to memo
Pair<Boolean, GroupExpression> result
= optimizerContext.getMemo().
copyIn(groupExpression.getGroup(), expression);

// The group has been merged
if (groupExpression.hasEmptyRootGroup()) {
return;
}

GroupExpression newGroupExpression = result.second;
if (newGroupExpression.getOp().isLogical()) {
// For logic newGroupExpression, optimize it
pushTask(new OptimizeExpressionTask(
context, newGroupExpression, isExplore));
} else {
// For physical newGroupExpression, enforce and cost it,
// Optimize its inputs if needed
pushTask(new EnforceAndCostTask(
context, newGroupExpression));
}
}

groupExpression.setRuleExplored(rule);
}

Binder

Binder 的作用是从 GroupExpression 搜索出符合 Pattern 的所有等价表达式,并通过 Binder.next 函数返回所有符合 Pattern 的表达式子树。

如图,左侧的 Pattern 在右侧的匹配结果就是虚线框中的部分。
Rule-Pattern

如图,左侧的 Pattern 匹配的结果,Binder.next 返回的结果顺序:

  • JOIN-OLAP_SCAN_1-OLAP_SCAN_2
  • JOIN-OLAP_SCAN_1-OLAP_SCAN_4
  • JOIN-OLAP_SCAN_3-OLAP_SCAN_3
  • JOIN-OLAP_SCAN_3-OLAP_SCAN_4

Rule-Pattern-2

因此,为了能 Binder.next 函数能返回不重复的所有匹配结果,Binder 需要具有存储状态。

  • groupTraceKey: 表示当前正在访问哪个 Group

  • groupExpressionIndex[groupTraceKey]: 表示访问 Group 的第几个等价表达式,在每次遍历时 groupExpressionIndex 最终的长度都是和 和 Group 的个数一致。

    groupExpressionIndex[0] 是当前 GroupExpression 所属的 Group,groupExpressionIndex[1:-1] 是 ChildernGroup

Binder.next 本质是个多叉树搜索算法,下面顺着代码顺序来看。

next

next 函数用于从 groupExpression 中找到匹配 Pattern 的 OptExpression。

  1. 每次进入 next 函数都需要把 groupTraceKey 赋值为0,表示重头从第一个 group 进行遍历,

  2. 将上次 next 函数访问的最后一个 group 中的等价表达式的下标 + 1

    每个 Group 中的逻辑等价的表达式的数量是 Group.getLogicalExpressions().size(),会依次访问

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public OptExpression next() {
// For logic scan to physical scan, we only need to match once
if (pattern.children().size() == 0 && groupExpressionIndex.get(0) > 0) {
return null;
}

OptExpression expression;
do {
this.groupTraceKey = 0;

// Match with the next groupExpression of the last group node
int lastNode = this.groupExpressionIndex.size() - 1;
int lastNodeIndex = this.groupExpressionIndex.get(lastNode);
// 增加上次访问到 Group 的逻辑等价表达式的下标,实现依次访问
this.groupExpressionIndex.set(lastNode, lastNodeIndex + 1);

expression = match(pattern, groupExpression, 0);
// while 中的判断条件是为了回溯
} while (expression == null && this.groupExpressionIndex.size() != 1);

return expression;
}

match

Binder.match 函数目的是找出 groupExpression 中匹配 Pattern 的 OptExpression 子树。实际上就是 Pattern-Tree 能否在 GroupExpression-Tree 中能否找到结构和自己一样的子树,因此需要同时遍历 Pattern、GroupExpression 两棵树。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
private OptExpression match(Pattern pattern, GroupExpression groupExpression,
int level) {
// CHECK-1: 先匹配 groupExpression 自己是否符合 Pattern 结构
if (!pattern.matchWithoutChild(groupExpression, level)) {
return null;
}

// recursion match children
List<OptExpression> resultInputs = Lists.newArrayList();

// 记录两棵树遍历的位置
int patternIndex = 0;
int gEI = 0;

List<Pattern> childernPatternSize = pattern.children().size();
List<Group> childernGroup = groupExpression.getInputs();
// 遍历所有的 childernGroup
while ((patternIndex < childernPatternSize) && gEI < childernGroup.size()) {
trace();
Group childGroup = childernGroup.get(gEI);
Pattern childPattern = pattern.childAt(patternIndex);
// 先递归子 group
OptExpression opt =
match(childPattern,
extractGroupExpression(childPattern, childGroup),
level);

if (opt == null) {
return null;
}

resultInputs.add(opt);

// 同时为 true 才不更新 patternIndex
if (!(childPattern.isPatternMultiLeaf() &&
(childernGroup.size() - gEI) > (childernPatternSize - patternIndex))) {
patternIndex++;
}

gEI++;
}

OptExpression result = new OptExpression(groupExpression);
result.getInputs().addAll(resultInputs);
return result;
}

trace

准备下一个要访问的 Group。groupExpressionIndex.add(0) 表示第一次访问下一个Group。

1
2
3
4
5
6
private void trace() {
this.groupTraceKey++;
for (int i = this.groupExpressionIndex.size(); i < this.groupTraceKey + 1; i++) {
this.groupExpressionIndex.add(0);
}
}

extractGroupExpression

groupExpressionIndex 用于存储状态,在 extractGroupExpression 函数中提取 Group 的 groupExpressionIndex[groupTraceKey] 个逻辑等价的表达式。当 ChildernGroup[i] 的所有等价表达式访问完毕,则将 groupTraceKey 从 groupExpressionIndex 中删除,回溯到到前一个 ChildernGroup[i-1]。

本质上就是个后序遍历多叉树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
private GroupExpression extractGroupExpression(Pattern pattern, Group group) {
if (pattern.isPatternLeaf() || pattern.isPatternMultiLeaf()) {
// 这种只有一个表达式
if (groupExpressionIndex.get(groupTraceKey) > 0) {
groupExpressionIndex.remove(groupTraceKey);
return null;
}
return group.getFirstLogicalExpression();
} else {
// 多个逻辑等价的表达式,依次遍历
int valueIndex = groupExpressionIndex.get(groupTraceKey);
if (valueIndex >= group.getLogicalExpressions().size()) {
// 本group遍历结束,删除节点进行回溯
groupExpressionIndex.remove(groupTraceKey);
return null;
}
return group.getLogicalExpressions().get(valueIndex);
}
}

Reference