
Anthropic's Original Performance Takehome


A couple of days ago I saw on Anthropic's blog that they had open-sourced a performance-optimization takehome they previously used in hiring. It looked interesting, so here is my attempt at it.

About the Problem

At its core, the problem is a parallel tree traversal.

Data Structures

Tree:
        node0
       /     \
    node1     node2
   /   \     /   \
  3     4   5     6
 / \   / \ / \   / \
7  8  9 10 11 12 13 14

The tree is a perfect binary tree. Each node holds a random value, and nodes are numbered in level order:

  • left child index \(= 2 \times \text{parent index} + 1\)
  • right child index \(= 2 \times \text{parent index} + 2\)
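A two-assert sanity check of the indexing (a standalone sketch, not part of the takehome code):

def children(parent: int) -> tuple[int, int]:
    """Level-order children of a node in an implicit binary tree."""
    return 2 * parent + 1, 2 * parent + 2

assert children(0) == (1, 2)    # root -> node1, node2
assert children(6) == (13, 14)  # matches the last row of the picture above

The tree itself is stored implicitly as a flat array of values: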
@dataclass
class Tree:
    """
    An implicit perfect balanced binary tree with values on the nodes.
    """

    height: int
    values: list[int]

The input consists of many work items, each made up of a tree-node index that tracks its current position and a current value.

@dataclass
class Input:
    """
    A batch of inputs, indices to nodes (starting as 0) and initial input
    values. We then iterate these for a specified number of rounds.
    """

    indices: list[int]
    values: list[int]
    rounds: int

Execution Flow

Each work item independently performs the following steps, repeated for a given number of rounds:

  1. Read the work item's current value
  2. XOR it with the current node's value and compute a new value with myhash
  3. Choose the left or right branch based on the new value's parity; if that goes past the bottom of the tree, wrap back to the root
  4. Update the work item's state: its position and its value
def reference_kernel(t: Tree, inp: Input):
    """
    Reference implementation of the kernel.

    A parallel tree traversal where at each node we set
    cur_inp_val = myhash(cur_inp_val ^ node_val)
    and then choose the left branch if cur_inp_val is even.
    If we reach the bottom of the tree we wrap around to the top.
    """
    for h in range(inp.rounds):
        for i in range(len(inp.indices)):
            idx = inp.indices[i]
            val = inp.values[i]
            val = myhash(val ^ t.values[idx])
            idx = 2 * idx + (1 if val % 2 == 0 else 2)
            idx = 0 if idx >= len(t.values) else idx
            inp.values[i] = val
            inp.indices[i] = idx

myhash is also worth a look: it pushes the value through six stages of bit operations.

def myhash(a: int) -> int:
    """A simple 32-bit hash function"""
    fns = {
        "+": lambda x, y: x + y,
        "^": lambda x, y: x ^ y,
        "<<": lambda x, y: x << y,
        ">>": lambda x, y: x >> y,
    }

    def r(x):
        return x % (2**32)

    for op1, val1, op2, op3, val3 in HASH_STAGES:
        a = r(fns[op2](r(fns[op1](a, val1)), r(fns[op3](a, val3))))

    return a
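HASH_STAGES itself lives in problem.py and isn't reproduced here; each entry is a 5-tuple (op1, val1, op2, op3, val3), applied as a = op2(op1(a, val1), op3(a, val3)) with 32-bit wraparound. A stage in the familiar add/xor-shift style would look something like this (illustrative constants, not the real ones):

# Illustrative only -- the actual constants are defined in problem.py.
EXAMPLE_STAGES = [
    ("+", 0x7ED55D16, "+", "<<", 12),  # a = (a + C) + (a << 12)
    ("^", 0xC761C23C, "^", ">>", 19),  # a = (a ^ C) ^ (a >> 19)
]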

Our job is to modify the build_kernel function in perf_takehome.py; it is currently a scalar implementation that handles one work item at a time.

About the Simulator

Before diving in, we need a bit of background on the simulator.

VLIW Architecture

VLIW (Very Long Instruction Word): one instruction can bundle multiple operations across multiple engines. For example:

1
2
3
4
5
6
7
# Example of one instruction
{
    "alu": [("*", 10, 5, 6), ("+", 11, 7, 8)],    # 2 scalar ALU ops
    "valu": [("^", 20, 21, 22)],                   # 1 vector ALU op
    "load": [("vload", 30, 31)],                   # 1 vector load
    "flow": [("select", 40, 41, 42, 43)]          # 1 control-flow op
}

This whole instruction executes in a single cycle: all engines work in parallel, every operation's reads are issued before the cycle starts, and writes land at the end of the cycle.

The slot limits for each engine are:

SLOT_LIMITS = {
    "alu": 12,   # up to 12 scalar ALU ops per cycle
    "valu": 6,   # up to 6 vector ALU ops per cycle (8 elements each)
    "load": 2,   # up to 2 load ops per cycle
    "store": 2,  # up to 2 store ops per cycle
    "flow": 1,   # up to 1 control-flow op per cycle
}

SIMD Vector Operations

The vector length is VLEN=8. Compared with scalar code:

# Scalar: 8 cycles to process 8 elements
for i in range(8):
    {"alu": [("*", dest+i, a+i, b+i)]}  # one cycle each

# Vector: 1 cycle to process 8 elements
{"valu": [("*", dest, a, b)]}  # dest..dest+7 = a..a+7 * b..b+7

The most important vector instructions:

# 1. Vector broadcast (scalar -> vector)
("vbroadcast", dest, scalar_src)
# dest[0..7] = scalar_src

# 2. Vector ALU (every scalar ALU op has a vector version)
("*", v_dest, v_a, v_b)  # v_dest[i] = v_a[i] * v_b[i], for i in 0..7
("+", v_dest, v_a, v_b)
("^", v_dest, v_a, v_b)
# ... and so on

# 3. Fused multiply-add (a special fused instruction)
("multiply_add", v_dest, v_a, v_b, v_c)
# v_dest[i] = v_a[i] * v_b[i] + v_c[i]

# 4. Vector select
("vselect", v_dest, v_cond, v_a, v_b)
# v_dest[i] = v_a[i] if v_cond[i] != 0 else v_b[i]

# 5. Vector load/store
("vload", v_dest, scalar_addr)   # load 8 contiguous elements
("vstore", scalar_addr, v_src)   # store 8 contiguous elements

Memory and Scratch Space

The scratch space holds 1536 32-bit words.
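The kernels below carve it up with self.alloc_scratch(name, length). The real allocator is part of the harness; a minimal sketch of the idea, assuming a simple bump-pointer design:

SCRATCH_SIZE = 1536  # 32-bit words

class ScratchAllocatorSketch:
    """Bump-pointer allocator over scratch space (a sketch, not the harness's code)."""

    def __init__(self) -> None:
        self.scratch: dict[str, int] = {}  # name -> base address
        self.next_free = 0

    def alloc_scratch(self, name: str, length: int = 1) -> int:
        assert self.next_free + length <= SCRATCH_SIZE, "out of scratch space"
        addr = self.next_free
        self.scratch[name] = addr
        self.next_free += length
        return addr

With only 1536 words, the unrolling and register caching done later have to budget this space carefully.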

Delayed-Write Semantics

Within a cycle, every engine's writes first go into write buffers; they are committed to scratch space and memory only at the end of the cycle.

# In step():
self.scratch_write = {}  # write buffers
self.mem_write = {}

# 1. All engines execute, writing into the buffers
for name, slots in instr.items():
    ENGINE_FNS[name](core, *slot)  # writes go to scratch_write / mem_write

# 2. Buffered writes are committed only at the end of the cycle
for addr, val in self.scratch_write.items():
    core.scratch[addr] = val
for addr, val in self.mem_write.items():
    self.mem[addr] = val

This also means that within a single instruction, reads see the old values: you must not read and write the same address, but reading and writing different addresses is safe.
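A small hypothetical bundle makes this concrete: in this delayed-write model, if one ALU op writes an address that another op in the same bundle reads, the read still sees the pre-cycle value.

# Hypothetical bundle; addresses 20-23 are arbitrary scratch locations.
# Before the cycle: scratch[20] == 5, scratch[21] == 7, scratch[22] == 0.
instr = {
    "alu": [
        ("+", 22, 20, 21),  # queues scratch[22] = 12 in the write buffer
        ("+", 23, 22, 21),  # reads the OLD scratch[22] (0), so queues scratch[23] = 7
    ],
}
# After the cycle: scratch[22] == 12, scratch[23] == 7 (not 19).
# Sequential semantics require splitting such a pair across two cycles,
# which is exactly what the VLIW scheduler later in this post enforces.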

Control-Flow Instructions

  • Jump instructions
# 1. Unconditional jumps
("jump", addr)               # pc = addr
("jump_indirect", addr_var)  # pc = scratch[addr_var]

# 2. Conditional jumps
("cond_jump", cond, addr)           # if scratch[cond] != 0: pc = addr
("cond_jump_rel", cond, offset)     # if scratch[cond] != 0: pc += offset
  • Conditional select (branchless)
# Scalar select
("select", dest, cond, a, b)
# dest = a if scratch[cond] != 0 else b

# Vector select
("vselect", v_dest, v_cond, v_a, v_b)
# v_dest[i] = v_a[i] if v_cond[i] != 0 else v_b[i]

The build_kernel Function

Finally, let's look at the original implementation of build_kernel, which emits the sequence of instructions that carries out the logic of reference_kernel2.

def build_kernel(
        self, forest_height: int, n_nodes: int, batch_size: int, rounds: int
    ):
    """
    Like reference_kernel2 but building actual instructions.
    Scalar implementation using only scalar ALU and load/store.
    """

    # Allocate three scratch temporaries
    tmp1 = self.alloc_scratch("tmp1")
    tmp2 = self.alloc_scratch("tmp2")
    tmp3 = self.alloc_scratch("tmp3")

    # Scratch space addresses
    # Allocate scratch space for these seven variables
    init_vars = [
        "rounds",
        "n_nodes",
        "batch_size",
        "forest_height",
        "forest_values_p",
        "inp_indices_p",
        "inp_values_p",
    ]
    for v in init_vars:
        self.alloc_scratch(v, 1)

    # Load the first seven memory words (the metadata) into scratch space
    for i, v in enumerate(init_vars):
        self.add("load", ("const", tmp1, i))
        self.add("load", ("load", self.scratch[v], tmp1))

    # Preload constants
    zero_const = self.scratch_const(0)
    one_const = self.scratch_const(1)
    two_const = self.scratch_const(2)

    # Pause instructions are matched up with yield statements in the reference
    # kernel to let you debug at intermediate steps. The testing harness in this
    # file requires these match up to the reference kernel's yields, but the
    # submission harness ignores them.
    self.add("flow", ("pause",))
    # Any debug engine instruction is ignored by the submission simulator
    self.add("debug", ("comment", "Starting loop"))

    body = []  # array of slots

    # Scalar scratch registers
    # Allocate per-iteration temporaries
    tmp_idx = self.alloc_scratch("tmp_idx")
    tmp_val = self.alloc_scratch("tmp_val")
    tmp_node_val = self.alloc_scratch("tmp_node_val")
    tmp_addr = self.alloc_scratch("tmp_addr")

    for round in range(rounds):
        for i in range(batch_size):
            i_const = self.scratch_const(i)
            # idx = mem[inp_indices_p + i]
            body.append(("alu", ("+", tmp_addr, self.scratch["inp_indices_p"], i_const)))
            body.append(("load", ("load", tmp_idx, tmp_addr)))
            body.append(("debug", ("compare", tmp_idx, (round, i, "idx"))))
            # val = mem[inp_values_p + i]
            body.append(("alu", ("+", tmp_addr, self.scratch["inp_values_p"], i_const)))
            body.append(("load", ("load", tmp_val, tmp_addr)))
            body.append(("debug", ("compare", tmp_val, (round, i, "val"))))
            # node_val = mem[forest_values_p + idx]
            body.append(("alu", ("+", tmp_addr, self.scratch["forest_values_p"], tmp_idx)))
            body.append(("load", ("load", tmp_node_val, tmp_addr)))
            body.append(("debug", ("compare", tmp_node_val, (round, i, "node_val"))))
            # val = myhash(val ^ node_val)
            body.append(("alu", ("^", tmp_val, tmp_val, tmp_node_val)))
            body.extend(self.build_hash(tmp_val, tmp1, tmp2, round, i))
            body.append(("debug", ("compare", tmp_val, (round, i, "hashed_val"))))
            # idx = 2*idx + (1 if val % 2 == 0 else 2)
            body.append(("alu", ("%", tmp1, tmp_val, two_const)))
            body.append(("alu", ("==", tmp1, tmp1, zero_const)))
            body.append(("flow", ("select", tmp3, tmp1, one_const, two_const)))
            body.append(("alu", ("*", tmp_idx, tmp_idx, two_const)))
            body.append(("alu", ("+", tmp_idx, tmp_idx, tmp3)))
            body.append(("debug", ("compare", tmp_idx, (round, i, "next_idx"))))
            # idx = 0 if idx >= n_nodes else idx
            body.append(("alu", ("<", tmp1, tmp_idx, self.scratch["n_nodes"])))
            body.append(("flow", ("select", tmp_idx, tmp1, tmp_idx, zero_const)))
            body.append(("debug", ("compare", tmp_idx, (round, i, "wrapped_idx"))))
            # mem[inp_indices_p + i] = idx
            body.append(("alu", ("+", tmp_addr, self.scratch["inp_indices_p"], i_const)))
            body.append(("store", ("store", tmp_addr, tmp_idx)))
            # mem[inp_values_p + i] = val
            body.append(("alu", ("+", tmp_addr, self.scratch["inp_values_p"], i_const)))
            body.append(("store", ("store", tmp_addr, tmp_val)))

    body_instrs = self.build(body)
    self.instrs.extend(body_instrs)
    # Required to match with the yield in reference_kernel2
    self.instrs.append({"flow": [("pause",)]})

The seven metadata words at the head of memory are laid out in build_mem_image:

def build_mem_image(t: Tree, inp: Input) -> list[int]:
    """
    Build a flat memory image of the problem.
    """
    header = 7  # the seven header words
    extra_room = len(t.values) + len(inp.indices) * 2 + VLEN * 2 + 32
    mem = [0] * (
        header + len(t.values) + len(inp.indices) + len(inp.values) + extra_room
    )
    forest_values_p = header
    inp_indices_p = forest_values_p + len(t.values)
    inp_values_p = inp_indices_p + len(inp.values)
    extra_room = inp_values_p + len(inp.values)

    # The header stores metadata
    mem[0] = inp.rounds        # number of rounds
    mem[1] = len(t.values)     # number of tree nodes
    mem[2] = len(inp.indices)  # batch size
    mem[3] = t.height          # tree height
    mem[4] = forest_values_p   # start of the tree values
    mem[5] = inp_indices_p     # start of the indices array
    mem[6] = inp_values_p      # start of the values array
    mem[7] = extra_room        # start of the extra room

    mem[header:inp_indices_p] = t.values           # tree node values
    mem[inp_indices_p:inp_values_p] = inp.indices  # work-item indices
    mem[inp_values_p:] = inp.values                # work-item values
    return mem

A diagram of the memory layout:

Address
┌───────────────┐
│ mem[0]        │  rounds
├───────────────┤
│ mem[1]        │  number_of_tree_nodes = len(t.values)
├───────────────┤
│ mem[2]        │  batch_size = len(inp.indices)
├───────────────┤
│ mem[3]        │  tree_height
├───────────────┤
│ mem[4]        │  forest_values_ptr
│               │  (= header)
├───────────────┤
│ mem[5]        │  input_indices_ptr
│               │  (= forest_values_ptr + len(t.values))
├───────────────┤
│ mem[6]        │  input_values_ptr
│               │  (= input_indices_ptr + len(inp.indices))
├───────────────┤
│ mem[7]        │  extra_room_ptr
│               │  (= input_values_ptr + len(inp.values))
└───────────────┘
        │  Header (7 integers)
┌──────────────────────────────────────────────┐
│ Tree Values (t.values)                       │
│ mem[forest_values_ptr ... input_indices_ptr) │
└──────────────────────────────────────────────┘

┌──────────────────────────────────────────────┐
│ Input Indices (inp.indices)                  │
│ mem[input_indices_ptr ... input_values_ptr)  │
└──────────────────────────────────────────────┘

┌──────────────────────────────────────────────┐
│ Input Values (inp.values)                    │
│ mem[input_values_ptr ... extra_room_ptr)     │
└──────────────────────────────────────────────┘

┌──────────────────────────────────────────────┐
│ Extra Room / Scratch Space                   │
│ mem[extra_room_ptr ... end_of_mem)           │
│                                              │
│ - temporary buffers                          │
│ - vector registers spill area                │
│ - intermediate computation data              │
└──────────────────────────────────────────────┘

Optimization

With all of that covered, we can start optimizing. Below is the result of the test I ran before any changes, as the baseline.

PS F:\original_performance_takehome> python perf_takehome.py Tests.test_kernel_cycles
forest_height=10, rounds=16, batch_size=256
CYCLES:  147734
Speedup over baseline:  1.0
.
----------------------------------------------------------------------
Ran 1 test in 1.213s

OK

Vectorization

First comes vectorization, which brings the biggest win. The new build_kernel and the vectorized hash function are below; the only real care needed while writing them is reusing scratch space.

def build_hash_vectorized(self, vec_val, vec_tmp1, vec_tmp2, round, i, hash_vec_consts):
    """
    Vectorized hash function.
    vec_val: input/output vector
    vec_tmp1, vec_tmp2: vector temporaries
    hash_vec_consts: dict of pre-allocated vector hash constants
    """
    slots = []

    for hi in range(len(HASH_STAGES)):
        # Fetch the pre-allocated constants
        const_info = hash_vec_consts[hi]
        vec_const1 = const_info['vec_const1']
        vec_const3 = const_info['vec_const3']
        op1 = const_info['op1']
        op2 = const_info['op2']
        op3 = const_info['op3']

        # Vector ops for this stage
        slots.append(("valu", (op1, vec_tmp1, vec_val, vec_const1)))
        slots.append(("valu", (op3, vec_tmp2, vec_val, vec_const3)))
        slots.append(("valu", (op2, vec_val, vec_tmp1, vec_tmp2)))

        # Debug verification
        for vi in range(8):
            slots.append(("debug", ("compare", vec_val + vi, (round, i + vi, "hash_stage", hi))))

    return slots

def build_kernel(
        self, forest_height: int, n_nodes: int, batch_size: int, rounds: int
    ):
    """
    Like reference_kernel2 but building actual instructions.
    Vectorized implementation using the vector ALU and vload/vstore.
    """
    tmp1 = self.alloc_scratch("tmp1")
    tmp2 = self.alloc_scratch("tmp2")
    tmp3 = self.alloc_scratch("tmp3")
    # Scratch space addresses
    init_vars = [
        "rounds",
        "n_nodes",
        "batch_size",
        "forest_height",
        "forest_values_p",
        "inp_indices_p",
        "inp_values_p",
    ]
    for v in init_vars:
        self.alloc_scratch(v, 1)
    for i, v in enumerate(init_vars):
        self.add("load", ("const", tmp1, i))
        self.add("load", ("load", self.scratch[v], tmp1))

    zero_const = self.scratch_const(0)
    one_const = self.scratch_const(1)
    two_const = self.scratch_const(2)

    # Pause instructions are matched up with yield statements in the reference
    # kernel to let you debug at intermediate steps. The testing harness in this
    # file requires these match up to the reference kernel's yields, but the
    # submission harness ignores them.
    self.add("flow", ("pause",))
    # Any debug engine instruction is ignored by the submission simulator
    self.add("debug", ("comment", "Starting loop"))

    body = []  # array of slots

    # Scalar temporary for address computation
    tmp_addr = self.alloc_scratch("tmp_addr_scalar")

    # Allocate vector variables before the loop (8 scratch words each)
    vec_idx = self.alloc_scratch("vec_idx", length=8)
    vec_val = self.alloc_scratch("vec_val", length=8)
    vec_node_val = self.alloc_scratch("vec_node_val", length=8)
    vec_addr = self.alloc_scratch("vec_addr", length=8)

    # Vector temporaries for the hash (8 scratch words each)
    vec_tmp1 = self.alloc_scratch("vec_tmp1", length=8)
    vec_tmp2 = self.alloc_scratch("vec_tmp2", length=8)
    vec_tmp3 = self.alloc_scratch("vec_tmp3", length=8)

    # Vector variable for node-value loads
    vec_forest_p = self.alloc_scratch("vec_forest_p", length=8)
    body.append(("valu", ("vbroadcast", vec_forest_p, self.scratch["forest_values_p"])))

    # Vector constants for the index update
    vec_two = self.alloc_scratch("vec_two", length=8)
    vec_zero = self.alloc_scratch("vec_zero", length=8)
    vec_one = self.alloc_scratch("vec_one", length=8)
    # Broadcast the constants
    body.append(("valu", ("vbroadcast", vec_two, two_const)))
    body.append(("valu", ("vbroadcast", vec_zero, zero_const)))
    body.append(("valu", ("vbroadcast", vec_one, one_const)))

    # Vector variable for the bounds check
    vec_n_nodes = self.alloc_scratch("vec_n_nodes", length=8)
    body.append(("valu", ("vbroadcast", vec_n_nodes, self.scratch["n_nodes"])))

    # Pre-allocate the vector constants the hash needs
    hash_vec_consts = {}

    for hi, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
        # Allocate vector space for the stage's two constants
        vec_const1 = self.alloc_scratch(f"hash_c1_{hi}", length=8)
        vec_const3 = self.alloc_scratch(f"hash_c3_{hi}", length=8)

        # Scalar constants
        const1_scalar = self.scratch_const(val1)
        const3_scalar = self.scratch_const(val3)

        # Broadcast them into vectors
        body.append(("valu", ("vbroadcast", vec_const1, const1_scalar)))
        body.append(("valu", ("vbroadcast", vec_const3, const3_scalar)))

        # Record for later lookup
        hash_vec_consts[hi] = {
            'vec_const1': vec_const1,
            'vec_const3': vec_const3,
            'op1': op1,
            'op2': op2,
            'op3': op3,
        }

    for round in range(rounds):
        for i in range(0, batch_size, VLEN):
            i_const = self.scratch_const(i)

            # 1. Load indices
            # Compute the start address
            body.append(("alu", ("+", tmp_addr, self.scratch["inp_indices_p"], i_const)))
            # Vector load
            # vec_idx[0..7] = mem[inp_indices_p+i : inp_indices_p+i+8]
            body.append(("load", ("vload", vec_idx, tmp_addr)))
            for vi in range(8):
                body.append(("debug", ("compare", vec_idx + vi, (round, i + vi, "idx"))))

            # 2. Load values
            body.append(("alu", ("+", tmp_addr, self.scratch["inp_values_p"], i_const)))
            # vec_val[0..7] = mem[inp_values_p+i : inp_values_p+i+8]
            body.append(("load", ("vload", vec_val, tmp_addr)))
            for vi in range(8):
                body.append(("debug", ("compare", vec_val + vi, (round, i + vi, "val"))))

            # 3. Load node values (a non-contiguous gather)
            # vec_addr[i] = forest_values_p + vec_idx[i]
            body.append(("valu", ("+", vec_addr, vec_forest_p, vec_idx)))
            for offset in range(VLEN):
                body.append(("load", ("load_offset", vec_node_val, vec_addr, offset))) # 这里可以VLIW并行加载
            for vi in range(8):
                body.append(("debug", ("compare", vec_node_val + vi, (round, i + vi, "node_val"))))

            # 4. Vector XOR: vec_val[i] = vec_val[i] ^ vec_node_val[i]
            body.append(("valu", ("^", vec_val, vec_val, vec_node_val)))

            # 5. Hash
            body.extend(self.build_hash_vectorized(vec_val, vec_tmp1, vec_tmp2, round, i, hash_vec_consts))
            for vi in range(8):
                body.append(("debug", ("compare", vec_val + vi, (round, i + vi, "hashed_val"))))

            # 6. Index update
            # vec_tmp1 = vec_val % 2
            body.append(("valu", ("%", vec_tmp1, vec_val, vec_two)))
            # vec_tmp1 = (vec_tmp1 == 0)
            body.append(("valu", ("==", vec_tmp1, vec_tmp1, vec_zero)))
            # vec_tmp3 = 1 if even else 2
            body.append(("flow", ("vselect", vec_tmp3, vec_tmp1, vec_one, vec_two)))
            # vec_idx = vec_idx * 2
            body.append(("valu", ("*", vec_idx, vec_idx, vec_two)))
            # vec_idx = vec_idx + vec_tmp3
            body.append(("valu", ("+", vec_idx, vec_idx, vec_tmp3)))
            for vi in range(8):
                body.append(("debug", ("compare", vec_idx + vi, (round, i + vi, "next_idx"))))

            # 7. Bounds check
            # vec_tmp1 = (vec_idx < n_nodes)
            body.append(("valu", ("<", vec_tmp1, vec_idx, vec_n_nodes)))
            # vec_idx = vec_idx if in_range else 0
            body.append(("flow", ("vselect", vec_idx, vec_tmp1, vec_idx, vec_zero)))
            for vi in range(8):
                body.append(("debug", ("compare", vec_idx + vi, (round, i + vi, "wrapped_idx"))))

            # 8. Store results
            # Store the indices
            # Compute the start address
            body.append(("alu", ("+", tmp_addr, self.scratch["inp_indices_p"], i_const)))
            # Vector store
            body.append(("store", ("vstore", tmp_addr, vec_idx)))
            # Store the values
            body.append(("alu", ("+", tmp_addr, self.scratch["inp_values_p"], i_const)))
            body.append(("store", ("vstore", tmp_addr, vec_val)))

    body_instrs = self.build(body)
    self.instrs.extend(body_instrs)
    # Required to match with the yield in reference_kernel2
    self.instrs.append({"flow": [("pause",)]})

Result:

PS F:\original_performance_takehome> python perf_takehome.py Tests.test_kernel_cycles
forest_height=10, rounds=16, batch_size=256
CYCLES:  22094
Speedup over baseline:  6.686611749796325
.
----------------------------------------------------------------------
Ran 1 test in 0.448s

OK

VLIW Optimization

The goal here is to fill every instruction slot. The implementation is a greedy scheduler: it keeps packing operations into the current instruction bundle, and only opens the next bundle when it runs into a resource limit or a data dependency.

The function that extracts read/write dependencies:

def get_rw_sets(engine, slot):
    """
    Extract the read/write footprint of one operation.
    Returns (read_set, write_set);
    the sets contain concrete scratch addresses.
    """
    op = slot[0]
    reads = set()
    writes = set()

    # Helper: add a contiguous range of addresses
    def add_range(s, base, length=8):
        for i in range(length):
            s.add(base + i)

    if engine == 'alu':
        # (op, dest, src1, src2)
        # The vast majority of ALU ops are (op, dest, src1, src2)
        if len(slot) >= 4:
            writes.add(slot[1])
            reads.add(slot[2])
            reads.add(slot[3])
        elif len(slot) == 3:  # single-operand ops like not, though problem.py seems to use only two-operand ops
             writes.add(slot[1])
             reads.add(slot[2])

    elif engine == 'valu':
        # (op, dest, src1, src2) or (vbroadcast, dest, src)
        dest = slot[1]
        add_range(writes, dest, 8)

        if op == 'vbroadcast':
            reads.add(slot[2]) # src is a scalar
        elif op == 'multiply_add':
            # (multiply_add, dest, a, b, c) -> dest = a*b + c
            # A special fused op defined in problem.py;
            # dest, a, b and c are all vectors
            add_range(reads, slot[2], 8)
            add_range(reads, slot[3], 8)
            add_range(reads, slot[4], 8)
        else:
            # Ordinary two-operand op (op, dest, a, b)
            if len(slot) > 2: add_range(reads, slot[2], 8)
            if len(slot) > 3: add_range(reads, slot[3], 8)

    elif engine == 'load':
        # load: (load, dest, addr)
        # vload: (vload, dest, addr)
        # load_offset: (load_offset, dest, addr, offset)
        # const: (const, dest, val)
        if op == 'load':
            writes.add(slot[1])
            reads.add(slot[2])
        elif op == 'vload':
            add_range(writes, slot[1], 8)
            reads.add(slot[2]) # addr is a scalar
        elif op == 'load_offset':
            dest_base, addr_base, offset = slot[1], slot[2], slot[3]
            writes.add(dest_base + offset)
            reads.add(addr_base + offset)
        elif op == 'const':
            writes.add(slot[1])

    elif engine == 'store':
        # store: (store, addr, src)
        # vstore: (vstore, addr, src)
        if op == 'store':
            reads.add(slot[1])
            reads.add(slot[2])
        elif op == 'vstore':
            reads.add(slot[1]) # addr is a scalar
            add_range(reads, slot[2], 8) # src is a vector

    elif engine == 'flow':
        # vselect: (vselect, dest, cond, a, b)
        if op == 'vselect':
            add_range(writes, slot[1], 8)
            add_range(reads, slot[2], 8)
            add_range(reads, slot[3], 8)
            add_range(reads, slot[4], 8)
        elif op == 'select':
            writes.add(slot[1])
            reads.add(slot[2])
            reads.add(slot[3])
            reads.add(slot[4])
        # Other flow ops aren't needed yet

    # Debug ops
    elif engine == 'debug':
        # debug: (compare, loc, key)
        # debug: (vcompare, loc, keys)
        op = slot[0]
        if op == 'compare':
            reads.add(slot[1])
        elif op == 'vcompare':
            add_range(reads, slot[1], 8)

    return reads, writes

The core scheduler:

from collections import defaultdict  # the scheduler packs bundles into a defaultdict

def smart_build(self, ops_list):
    """
    Greedy VLIW scheduler: packs an operation list into as few
    instruction bundles as possible.
    """
    instrs = []

    # The bundle currently being packed
    current_bundle = defaultdict(list)
    # Set of registers already written by ops in this bundle (for RAW checks).
    # If the next op reads any of these, it must execute in the next cycle.
    current_writes = set()

    for engine, slot in ops_list:
        # 1. Get this op's dependencies
        reads, writes = get_rw_sets(engine, slot)

        # 2. Check for conflicts
        conflict = False

        # A. Resource-limit conflict
        if len(current_bundle[engine]) >= SLOT_LIMITS[engine]:
            conflict = True

        # B. Data-dependency conflict (RAW: read after write).
        # If this op reads something written by an earlier op in the bundle,
        # the hardware would hand it the old value, which is not the
        # sequential semantics we want, so split into the next cycle.
        if not conflict:
            if not reads.isdisjoint(current_writes):
                conflict = True

        # C. Output-dependency conflict (WAW: write after write).
        # Rare, but two ops writing the same address in one cycle would race.
        if not conflict:
            if not writes.isdisjoint(current_writes):
                conflict = True

        # 3. On conflict, flush
        if conflict:
            # Commit the current bundle
            instrs.append(dict(current_bundle))
            # Start a new one
            current_bundle = defaultdict(list)
            current_writes = set()

        # 4. Add the op to the current bundle
        current_bundle[engine].append(slot)
        current_writes.update(writes)

    # Commit the final bundle
    if current_bundle:
        instrs.append(dict(current_bundle))

    return instrs
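Hooking the scheduler in is then a one-line change in build_kernel, replacing the harness's original packing step:

# was: body_instrs = self.build(body)
body_instrs = self.smart_build(body)
self.instrs.extend(body_instrs)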

Result:

PS F:\original_performance_takehome> python perf_takehome.py Tests.test_kernel_cycles
forest_height=10, rounds=16, batch_size=256
CYCLES:  15424
Speedup over baseline:  9.578189834024895
.
----------------------------------------------------------------------
Ran 1 test in 0.316s

OK

Software Pipelining and Instruction Interleaving

With the SIMD and VLIW optimizations in place, we process 8 elements at a time, and the scheduler can overlap some operations. Let's run the numbers:

  1. Load bottleneck:
    • Each batch (8 elements) needs 10 load operations (2 contiguous vloads + 8 scattered load_offsets).
    • The load slot limit is 2.
    • Time: \(10 / 2 = 5\ \text{cycles}\)
  2. ALU bottleneck:
    • The hash plus the index update come to roughly 25 vector operations.
    • The VALU slot limit is 6.
    • Time: \(25 / 6 \approx 4.17\ \text{cycles}\)

So in theory each batch needs as few as 5 cycles, yet the measured figure is \(15424 \div (16 \times 32) \approx 30\) cycles per batch. The gap comes from the fact that instruction issue is still essentially serial per batch, i.e.:

issue loads -> wait for memory (bubble) -> issue ALU ops -> wait for results (bubble) -> issue stores
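To sanity-check the arithmetic above (a throwaway script; the 15424 comes from the previous run):

VLEN = 8
load_slots, valu_slots = 2, 6            # from SLOT_LIMITS
loads_per_batch = 2 + VLEN               # 2 vloads + 8 load_offset gathers
valu_per_batch = 25                      # ~hash + index update

print(loads_per_batch / load_slots)      # 5.0 cycles: the load-bound floor
print(valu_per_batch / valu_slots)       # ~4.17 cycles

rounds, batch_size = 16, 256
batch_rounds = rounds * batch_size // VLEN   # 16 * 32 = 512
print(15424 / batch_rounds)              # ~30.1 cycles per batch, measured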

The greedy scheduler only ever holds one batch's worth of instructions, so it has no future instructions to fill the bubbles with. The next optimization is therefore to work on several batches at once.

We unroll the loop four times so that 4 batches (32 elements) are in flight together, then interleave the instructions of the four batches so that loads and compute overlap.

def build_kernel(
        self, forest_height: int, n_nodes: int, batch_size: int, rounds: int
    ):

    # ... earlier setup omitted ...

    # --- Loop-unrolling setup ---
    UNROLL = 4  # 4 x 8 = 32 elements per step

    # Allocate the register groups
    vec_idxs = [self.alloc_scratch(f"vec_idx_{u}", 8) for u in range(UNROLL)]
    vec_vals = [self.alloc_scratch(f"vec_val_{u}", 8) for u in range(UNROLL)]
    vec_node_vals = [self.alloc_scratch(f"vec_node_val_{u}", 8) for u in range(UNROLL)]
    vec_addrs = [self.alloc_scratch(f"vec_addr_{u}", 8) for u in range(UNROLL)]

    vec_tmp1s = [self.alloc_scratch(f"vec_tmp1_{u}", 8) for u in range(UNROLL)]
    vec_tmp2s = [self.alloc_scratch(f"vec_tmp2_{u}", 8) for u in range(UNROLL)]
    vec_tmp3s = [self.alloc_scratch(f"vec_tmp3_{u}", 8) for u in range(UNROLL)]

    # Separate scalar address registers per group (avoids WAW conflicts)
    tmp_addrs = [self.alloc_scratch(f"tmp_addr_{u}") for u in range(UNROLL)]

    for round in range(rounds):
        for i in range(0, batch_size, VLEN * UNROLL):

            # 1. Load Indices & Values (Phase 1)
            # Interleaved issue: put all the loads next to each other
            for u in range(UNROLL):
                curr_i = i + u * VLEN
                i_const = self.scratch_const(curr_i)
                body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_indices_p"], i_const)))
                body.append(("load", ("vload", vec_idxs[u], tmp_addrs[u])))
                # Debug checks right after the load
                for vi in range(8):
                    body.append(("debug", ("compare", vec_idxs[u] + vi, (round, curr_i + vi, "idx"))))

            for u in range(UNROLL):
                curr_i = i + u * VLEN
                i_const = self.scratch_const(curr_i)
                body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_values_p"], i_const)))
                body.append(("load", ("vload", vec_vals[u], tmp_addrs[u])))
                for vi in range(8):
                    body.append(("debug", ("compare", vec_vals[u] + vi, (round, curr_i + vi, "val"))))

            # 2. Gather Address Calculation
            for u in range(UNROLL):
                body.append(("valu", ("+", vec_addrs[u], vec_forest_p, vec_idxs[u])))

            # 3. Gather Load (Phase 2)
            # Heavily interleaved: loop by offset so the load units stay saturated
            for offset in range(VLEN):
                for u in range(UNROLL):
                    body.append(("load", ("load_offset", vec_node_vals[u], vec_addrs[u], offset)))

            # Debug the gather results
            for u in range(UNROLL):
                curr_i = i + u * VLEN
                for vi in range(8):
                    body.append(("debug", ("compare", vec_node_vals[u] + vi, (round, curr_i + vi, "node_val"))))

            # 4. Hash Computation (Compute Bottleneck)
            # Initial XOR
            for u in range(UNROLL):
                body.append(("valu", ("^", vec_vals[u], vec_vals[u], vec_node_vals[u])))

            # Key optimization: interleave the hash stages across groups, so
            # while batch 0 runs op1/op3, batches 1-3 do too, filling all 6 VALU slots
            for hi in range(len(HASH_STAGES)):
                const_info = hash_vec_consts[hi]

                for u in range(UNROLL):
                    # Op1 & Op3 (Parallel)
                    body.append(("valu", (const_info['op1'], vec_tmp1s[u], vec_vals[u], const_info['vec_const1'])))
                    body.append(("valu", (const_info['op3'], vec_tmp2s[u], vec_vals[u], const_info['vec_const3'])))

                for u in range(UNROLL):
                    # Op2 (Dependent)
                    body.append(("valu", (const_info['op2'], vec_vals[u], vec_tmp1s[u], vec_tmp2s[u])))

                # Debug Hash Stage
                for u in range(UNROLL):
                    curr_i = i + u * VLEN
                    for vi in range(8):
                        body.append(("debug", ("compare", vec_vals[u] + vi, (round, curr_i + vi, "hash_stage", hi))))

            # Final Hash Debug
            for u in range(UNROLL):
                curr_i = i + u * VLEN
                for vi in range(8):
                    body.append(("debug", ("compare", vec_vals[u] + vi, (round, curr_i + vi, "hashed_val"))))

            # 5. Logic Update
            # Interleave step by step to keep the ALUs busy
            for u in range(UNROLL):
                body.append(("valu", ("%", vec_tmp1s[u], vec_vals[u], vec_two)))

            for u in range(UNROLL):
                body.append(("valu", ("==", vec_tmp1s[u], vec_tmp1s[u], vec_zero)))

            for u in range(UNROLL):
                body.append(("flow", ("vselect", vec_tmp3s[u], vec_tmp1s[u], vec_one, vec_two)))

            for u in range(UNROLL):
                body.append(("valu", ("*", vec_idxs[u], vec_idxs[u], vec_two)))

            for u in range(UNROLL):
                body.append(("valu", ("+", vec_idxs[u], vec_idxs[u], vec_tmp3s[u])))

            for u in range(UNROLL):
                curr_i = i + u * VLEN
                for vi in range(8):
                    body.append(("debug", ("compare", vec_idxs[u] + vi, (round, curr_i + vi, "next_idx"))))

            # 6. Boundary Check
            for u in range(UNROLL):
                body.append(("valu", ("<", vec_tmp1s[u], vec_idxs[u], vec_n_nodes)))

            for u in range(UNROLL):
                body.append(("flow", ("vselect", vec_idxs[u], vec_tmp1s[u], vec_idxs[u], vec_zero)))

            for u in range(UNROLL):
                curr_i = i + u * VLEN
                for vi in range(8):
                    body.append(("debug", ("compare", vec_idxs[u] + vi, (round, curr_i + vi, "wrapped_idx"))))

            # 7. Store
            for u in range(UNROLL):
                curr_i = i + u * VLEN
                i_const = self.scratch_const(curr_i)

                body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_indices_p"], i_const)))
                body.append(("store", ("vstore", tmp_addrs[u], vec_idxs[u])))

                body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_values_p"], i_const)))
                body.append(("store", ("vstore", tmp_addrs[u], vec_vals[u])))

    # ... rest as before ...

Result:

PS F:\original_performance_takehome> python perf_takehome.py Tests.test_kernel_cycles
forest_height=10, rounds=16, batch_size=256
CYCLES:  9280
Speedup over baseline:  15.919612068965517
.
----------------------------------------------------------------------
Ran 1 test in 0.316s

OK

Instruction Fusion and Loop Unrolling

Some stages of the original hash have the shape val = (val + C1) + (val << S). That takes at least three instructions, with a data dependency on the final add. But notice that:

\(x + C + (x \ll S)\) is equivalent to the linear form \(x \cdot (1 + 2^S) + C\)

And problem.py happens to include an instruction multiply_add(dest, a, b, c) that computes \(a \cdot b + c\) in one slot.
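A quick script to convince ourselves the two forms agree under 32-bit wraparound (mirroring how myhash wraps each op with r()):

import random

M = 2**32

def stage_ref(x, c, s):
    # (x + c) and (x << s) each wrapped to 32 bits, then added and wrapped again
    return ((x + c) % M + (x << s) % M) % M

def stage_mac(x, c, s):
    # the multiply_add form: x * (1 + 2**s) + c
    return (x * (1 + (1 << s)) + c) % M

for _ in range(10_000):
    x, c, s = random.randrange(M), random.randrange(M), random.randrange(32)
    assert stage_ref(x, c, s) == stage_mac(x, c, s)
print("equivalent")

The optimized hash builder then looks like this: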

def build_hash_vectorized_optimized(self, vec_val, vec_tmp1, vec_tmp2, round, i, hash_vec_consts):
        """
        Optimized hash:
        - recognize stages 0, 2 and 4 and replace them with multiply_add
        - other stages keep the parallel issue pattern
        """
        slots = []

        for hi in range(len(HASH_STAGES)):
            const_info = hash_vec_consts[hi]
            is_mac = const_info['is_mac']

            if is_mac:
                # Fused op: val = val * mul + add
                # multiply_add(dest, a, b, c) -> dest = a*b + c
                slots.append(("valu", ("multiply_add", vec_val, vec_val, const_info['vec_mul'], const_info['vec_add'])))
            else:
                # Ordinary stage: issue op1/op3 in parallel
                op1, op2, op3 = const_info['op1'], const_info['op2'], const_info['op3']
                slots.append(("valu", (op1, vec_tmp1, vec_val, const_info['vec_const1'])))
                slots.append(("valu", (op3, vec_tmp2, vec_val, const_info['vec_const3'])))
                slots.append(("valu", (op2, vec_val, vec_tmp1, vec_tmp2)))

            # Debug
            for vi in range(8):
                slots.append(("debug", ("compare", vec_val + vi, (round, i + vi, "hash_stage", hi))))

        return slots

I also widened the load phase from issuing 4 groups of load requests at a time to 16, roughly as sketched below.
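A sketch of that widening (assuming the register groups are widened to 16, as in the next section; LOAD_GROUPS is a hypothetical name):

# Hypothetical widening: issue the load_offset requests for 16 register
# groups back-to-back so the 2 load slots never sit idle.
LOAD_GROUPS = 16
for offset in range(VLEN):
    for u in range(LOAD_GROUPS):
        body.append(("load", ("load_offset", vec_node_vals[u], vec_addrs[u], offset)))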

Result:

PS F:\original_performance_takehome> python perf_takehome.py Tests.test_kernel_cycles
forest_height=10, rounds=16, batch_size=256
CYCLES:  7711
Speedup over baseline:  19.1588639605758
.
----------------------------------------------------------------------
Ran 1 test in 0.267s

OK

Loop Interchange

The previous optimizations used unrolling, but structurally the kernel was still Loop(rounds) -> Loop(batch), loading and storing every round. So let's flip it to Loop(batch) -> Loop(rounds): for each group of 128 elements, read indices and values into registers once, run all 16 rounds entirely in registers, then write the results back once.

def build_kernel(
        self, forest_height: int, n_nodes: int, batch_size: int, rounds: int
    ):
    # ... earlier setup omitted ...

    # === Optimization: batch loop outside, round loop inside ===
    UNROLL = 16 # Process 128 elements (40% of register file)

    vec_idxs = [self.alloc_scratch(f"vec_idx_{u}", 8) for u in range(UNROLL)]
    vec_vals = [self.alloc_scratch(f"vec_val_{u}", 8) for u in range(UNROLL)]
    vec_node_vals = [self.alloc_scratch(f"vec_node_val_{u}", 8) for u in range(UNROLL)]
    vec_addrs = [self.alloc_scratch(f"vec_addr_{u}", 8) for u in range(UNROLL)]
    vec_tmp1s = [self.alloc_scratch(f"vec_tmp1_{u}", 8) for u in range(UNROLL)]
    vec_tmp2s = [self.alloc_scratch(f"vec_tmp2_{u}", 8) for u in range(UNROLL)]
    vec_tmp3s = [self.alloc_scratch(f"vec_tmp3_{u}", 8) for u in range(UNROLL)]
    tmp_addrs = [self.alloc_scratch(f"tmp_addr_{u}") for u in range(UNROLL)]

    # 1. Iterate over the batch (step = 128 elements)
    for i in range(0, batch_size, VLEN * UNROLL):

        # --- Load Phase (Only Once) ---
        for u in range(UNROLL):
            curr_i = i + u * VLEN
            i_const = self.scratch_const(curr_i)
            body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_indices_p"], i_const)))
            body.append(("load", ("vload", vec_idxs[u], tmp_addrs[u])))

        for u in range(UNROLL):
            curr_i = i + u * VLEN
            i_const = self.scratch_const(curr_i)
            body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_values_p"], i_const)))
            body.append(("load", ("vload", vec_vals[u], tmp_addrs[u])))

        # 2. Loop Rounds (Inside Registers)
        for round in range(rounds):
            # a. Calc Addr
            for u in range(UNROLL):
                body.append(("valu", ("+", vec_addrs[u], vec_forest_p, vec_idxs[u])))

            # b. Gather Load
            for offset in range(VLEN):
                for u in range(UNROLL):
                    body.append(("load", ("load_offset", vec_node_vals[u], vec_addrs[u], offset)))

            # c. Hash (MAC Optimized)
            for u in range(UNROLL):
                body.append(("valu", ("^", vec_vals[u], vec_vals[u], vec_node_vals[u])))

            for hi in range(len(HASH_STAGES)):
                const_info = hash_vec_consts[hi]
                if const_info['is_mac']:
                    for u in range(UNROLL):
                        body.append(("valu", ("multiply_add", vec_vals[u], vec_vals[u], const_info['vec_mul'], const_info['vec_add'])))
                else:
                    for u in range(UNROLL):
                        body.append(("valu", (const_info['op1'], vec_tmp1s[u], vec_vals[u], const_info['vec_const1'])))
                        body.append(("valu", (const_info['op3'], vec_tmp2s[u], vec_vals[u], const_info['vec_const3'])))
                    for u in range(UNROLL):
                        body.append(("valu", (const_info['op2'], vec_vals[u], vec_tmp1s[u], vec_tmp2s[u])))

            # d. Update Logic
            for u in range(UNROLL):
                body.append(("valu", ("%", vec_tmp1s[u], vec_vals[u], vec_two)))
            for u in range(UNROLL):
                body.append(("valu", ("==", vec_tmp1s[u], vec_tmp1s[u], vec_zero)))
            for u in range(UNROLL):
                body.append(("flow", ("vselect", vec_tmp3s[u], vec_tmp1s[u], vec_one, vec_two)))
            for u in range(UNROLL):
                body.append(("valu", ("multiply_add", vec_idxs[u], vec_idxs[u], vec_two, vec_tmp3s[u])))

            # e. Boundary Check
            for u in range(UNROLL):
                body.append(("valu", ("<", vec_tmp1s[u], vec_idxs[u], vec_n_nodes)))
            for u in range(UNROLL):
                body.append(("flow", ("vselect", vec_idxs[u], vec_tmp1s[u], vec_idxs[u], vec_zero)))

        # --- Store Phase (Only Once) ---
        for u in range(UNROLL):
            curr_i = i + u * VLEN
            i_const = self.scratch_const(curr_i)
            body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_indices_p"], i_const)))
            body.append(("store", ("vstore", tmp_addrs[u], vec_idxs[u])))
            body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_values_p"], i_const)))
            body.append(("store", ("vstore", tmp_addrs[u], vec_vals[u])))

    # ... rest omitted ...

Result:

PS F:\original_performance_takehome> python perf_takehome.py Tests.test_kernel_cycles
forest_height=10, rounds=16, batch_size=256
CYCLES:  4638
Speedup over baseline:  31.852953859422165
.
----------------------------------------------------------------------
Ran 1 test in 0.131s

OK

Register Caching

At UNROLL=16, the load units (2 per cycle) become the hard bottleneck. The idea here is to cache the top few levels of the tree in registers, and only fall back to load_offset gathers for the deeper levels.

def build_kernel(
        self, forest_height: int, n_nodes: int, batch_size: int, rounds: int
    ):
    # ... earlier setup omitted (includes the vec_zero/one/two/three/five broadcasts used below) ...

    # === Cache the top 3 levels of the tree (levels 0, 1, 2) ===
    vec_cached_tree = self.alloc_scratch("vec_cached_tree", length=8)
    body.append(("load", ("vload", vec_cached_tree, self.scratch["forest_values_p"])))

    vec_c0 = self.alloc_scratch("vc0", 8)
    vec_c1 = self.alloc_scratch("vc1", 8)
    vec_c2 = self.alloc_scratch("vc2", 8)
    vec_c3 = self.alloc_scratch("vc3", 8)
    vec_c4 = self.alloc_scratch("vc4", 8)
    vec_c5 = self.alloc_scratch("vc5", 8)
    vec_c6 = self.alloc_scratch("vc6", 8)

    body.append(("valu", ("vbroadcast", vec_c0, vec_cached_tree + 0)))
    body.append(("valu", ("vbroadcast", vec_c1, vec_cached_tree + 1)))
    body.append(("valu", ("vbroadcast", vec_c2, vec_cached_tree + 2)))
    body.append(("valu", ("vbroadcast", vec_c3, vec_cached_tree + 3)))
    body.append(("valu", ("vbroadcast", vec_c4, vec_cached_tree + 4)))
    body.append(("valu", ("vbroadcast", vec_c5, vec_cached_tree + 5)))
    body.append(("valu", ("vbroadcast", vec_c6, vec_cached_tree + 6)))

    # === Precompute hash constants (MAC optimization) ===
    hash_vec_consts = {}
    for hi, (op1, val1, op2, op3, val3) in enumerate(HASH_STAGES):
        is_mac = (op1 == "+" and op2 == "+" and op3 == "<<")
        info = {'is_mac': is_mac, 'op1': op1, 'op2': op2, 'op3': op3}
        if is_mac:
            mul_val = 1 + (1 << val3)
            vec_mul = self.alloc_scratch(f"hash_mul_{hi}", length=8)
            vec_add = self.alloc_scratch(f"hash_add_{hi}", length=8)
            body.append(("valu", ("vbroadcast", vec_mul, self.scratch_const(mul_val))))
            body.append(("valu", ("vbroadcast", vec_add, self.scratch_const(val1))))
            info['vec_mul'] = vec_mul
            info['vec_add'] = vec_add
        else:
            vec_const1 = self.alloc_scratch(f"hash_c1_{hi}", length=8)
            vec_const3 = self.alloc_scratch(f"hash_c3_{hi}", length=8)
            body.append(("valu", ("vbroadcast", vec_const1, self.scratch_const(val1))))
            body.append(("valu", ("vbroadcast", vec_const3, self.scratch_const(val3))))
            info['vec_const1'] = vec_const1
            info['vec_const3'] = vec_const3
        hash_vec_consts[hi] = info

    # === Final loop setup: batch loop outside, round loop inside ===
    UNROLL = 16
    vec_idxs = [self.alloc_scratch(f"vec_idx_{u}", 8) for u in range(UNROLL)]
    vec_vals = [self.alloc_scratch(f"vec_val_{u}", 8) for u in range(UNROLL)]
    vec_node_vals = [self.alloc_scratch(f"vec_node_val_{u}", 8) for u in range(UNROLL)]
    vec_addrs = [self.alloc_scratch(f"vec_addr_{u}", 8) for u in range(UNROLL)]

    vec_tmp1s = [self.alloc_scratch(f"vec_tmp1_{u}", 8) for u in range(UNROLL)]
    vec_tmp2s = [self.alloc_scratch(f"vec_tmp2_{u}", 8) for u in range(UNROLL)]
    vec_tmp3s = [self.alloc_scratch(f"vec_tmp3_{u}", 8) for u in range(UNROLL)]

    tmp_addrs = [self.alloc_scratch(f"tmp_addr_{u}") for u in range(UNROLL)]

    # --- Batch Loop ---
    for i in range(0, batch_size, VLEN * UNROLL):

        # 1. Load Batch (Once)
        for u in range(UNROLL):
            curr_i = i + u * VLEN
            i_const = self.scratch_const(curr_i)
            body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_indices_p"], i_const)))
            body.append(("load", ("vload", vec_idxs[u], tmp_addrs[u])))
        for u in range(UNROLL):
            curr_i = i + u * VLEN
            i_const = self.scratch_const(curr_i)
            body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_values_p"], i_const)))
            body.append(("load", ("vload", vec_vals[u], tmp_addrs[u])))

        # 2. Round Loop (In Registers)
        for round in range(rounds):
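            # Every work item starts at the root and descends one level per
            # round, wrapping back to the root past the leaves, so all lanes
            # share the same tree depth in any given round.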
            depth = round % (forest_height + 1)

            # --- Step A: Get Node Value ---
            if depth == 0: # Level 0
                for u in range(UNROLL):
                    body.append(("valu", ("+", vec_node_vals[u], vec_c0, vec_zero)))
            elif depth == 1: # Level 1
                for u in range(UNROLL):
                    body.append(("valu", ("==", vec_tmp1s[u], vec_idxs[u], vec_one)))
                    body.append(("flow", ("vselect", vec_node_vals[u], vec_tmp1s[u], vec_c1, vec_c2)))
            elif depth == 2: # Level 2
                for u in range(UNROLL):
                    # Cascade select for 4 nodes
                    body.append(("valu", ("==", vec_tmp1s[u], vec_idxs[u], vec_three)))
                    body.append(("flow", ("vselect", vec_tmp2s[u], vec_tmp1s[u], vec_c3, vec_c4)))
                    body.append(("valu", ("==", vec_tmp1s[u], vec_idxs[u], vec_five)))
                    body.append(("flow", ("vselect", vec_tmp3s[u], vec_tmp1s[u], vec_c5, vec_c6)))
                    body.append(("valu", ("<", vec_tmp1s[u], vec_idxs[u], vec_five)))
                    body.append(("flow", ("vselect", vec_node_vals[u], vec_tmp1s[u], vec_tmp2s[u], vec_tmp3s[u])))
            else: # Regular Gather
                for u in range(UNROLL):
                    body.append(("valu", ("+", vec_addrs[u], vec_forest_p, vec_idxs[u])))
                for offset in range(VLEN):
                    for u in range(UNROLL):
                        body.append(("load", ("load_offset", vec_node_vals[u], vec_addrs[u], offset)))

            # --- Step B: Hash ---
            for u in range(UNROLL):
                body.append(("valu", ("^", vec_vals[u], vec_vals[u], vec_node_vals[u])))

            for hi in range(len(HASH_STAGES)):
                const_info = hash_vec_consts[hi]
                if const_info['is_mac']:
                    for u in range(UNROLL):
                        body.append(("valu", ("multiply_add", vec_vals[u], vec_vals[u], const_info['vec_mul'], const_info['vec_add'])))
                else:
                    for u in range(UNROLL):
                        body.append(("valu", (const_info['op1'], vec_tmp1s[u], vec_vals[u], const_info['vec_const1'])))
                        body.append(("valu", (const_info['op3'], vec_tmp2s[u], vec_vals[u], const_info['vec_const3'])))
                    for u in range(UNROLL):
                        body.append(("valu", (const_info['op2'], vec_vals[u], vec_tmp1s[u], vec_tmp2s[u])))

            # --- Step C: Update ---
            for u in range(UNROLL):
                body.append(("valu", ("%", vec_tmp1s[u], vec_vals[u], vec_two)))
            for u in range(UNROLL):
                body.append(("valu", ("==", vec_tmp1s[u], vec_tmp1s[u], vec_zero)))
            for u in range(UNROLL):
                body.append(("flow", ("vselect", vec_tmp3s[u], vec_tmp1s[u], vec_one, vec_two)))
            for u in range(UNROLL):
                body.append(("valu", ("multiply_add", vec_idxs[u], vec_idxs[u], vec_two, vec_tmp3s[u])))

            # --- Step D: Boundary ---
            for u in range(UNROLL):
                body.append(("valu", ("<", vec_tmp1s[u], vec_idxs[u], vec_n_nodes)))
            for u in range(UNROLL):
                body.append(("flow", ("vselect", vec_idxs[u], vec_tmp1s[u], vec_idxs[u], vec_zero)))

        # 3. Store Batch (Once)
        for u in range(UNROLL):
            curr_i = i + u * VLEN
            i_const = self.scratch_const(curr_i)
            body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_indices_p"], i_const)))
            body.append(("store", ("vstore", tmp_addrs[u], vec_idxs[u])))
            body.append(("alu", ("+", tmp_addrs[u], self.scratch["inp_values_p"], i_const)))
            body.append(("store", ("vstore", tmp_addrs[u], vec_vals[u])))

    # ... rest omitted ...

Result:

PS F:\original_performance_takehome> python perf_takehome.py Tests.test_kernel_cycles
forest_height=10, rounds=16, batch_size=256
CYCLES:  4131
Speedup over baseline:  35.762285160977974
.
----------------------------------------------------------------------
Ran 1 test in 0.131s

OK
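
To recap the whole progression on forest_height=10, rounds=16, batch_size=256:

  • Baseline (scalar): 147734 cycles (1.0x)
  • Vectorization: 22094 cycles (6.69x)
  • VLIW greedy scheduling: 15424 cycles (9.58x)
  • Software pipelining and interleaving: 9280 cycles (15.92x)
  • Instruction fusion: 7711 cycles (19.16x)
  • Loop interchange: 4638 cycles (31.85x)
  • Register caching: 4131 cycles (35.76x)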