···305305 gir::Module mod;
306306 gir::Builder gb(mod);
307307 auto rp = gb.get_root_ptr();
308308- auto x = gb.load(gb.add(rp, gb.mul(gb.get_local_invocation_id(), gb.i32(4))));
308308+ auto p = gb.add(rp, gb.mul(gb.get_local_invocation_id(), gb.i32(4)));
309309+ auto x = gb.load(p);
309310 auto sum = gb.add(x, gb.i32(15));
310310- gb.store(x, sum);
311311+ gb.store(p, sum);
311312312313 rdna2_compile(mod, alloc.cpu, alloc.gpu);
313314 }
+230-82
drivers/amdgpu/compiler/compiler.cpp
···66#include <iomanip>
77#include <string>
88#include <fstream>
99+#include <optional>
9101011using namespace gir;
1112···2930};
30313132void lower_simple(Compiler &);
3333+void lower_memory_loads(Compiler &);
3234void analyze_uniformity(Compiler &);
3335void allocate_registers(Compiler &);
3436void codegen(Compiler &);
35373638enum class AmdIntrinsics : uint32_t {
3737- GlobalLoadDword,
3838- GlobalLoadDwordAddTI, // 12-bit imm offset, saddr, addr
3939+ GlobalLoadDwordAddTID_Scale4,
4040+ GlobalStoreDwordAddTID_Scale4,
3941};
40424343+void tmp_dump_shader(uint32_t *data, size_t code_size_bytes) {
4444+ std::stringstream ss;
4545+ ss << "shader_tmp.bin";
4646+ std::string filename = ss.str();
4747+4848+ std::ofstream outfile(filename, std::ios::out | std::ios::binary);
4949+ if (outfile.is_open()) {
5050+ outfile.write(reinterpret_cast<const char*>(data), code_size_bytes);
5151+ outfile.close();
5252+ }
5353+ log("shader written to {}", filename);
5454+}
5555+4156void rdna2_compile(gir::Module &mod, void *write_ptr, uint64_t base_addr) {
4257 Compiler compiler(mod);
43585959+ gir::pass_normalize(mod);
6060+4461 lower_simple(compiler);
6262+ lower_memory_loads(compiler);
4563 analyze_uniformity(compiler);
6464+6565+ gir::pass_eliminate_dead_code(mod);
6666+4667 allocate_registers(compiler);
6868+4769 codegen(compiler);
48704971 auto code = compiler.as.values();
5072 auto code_size_bytes = code.size() * sizeof(uint32_t);
5173 memcpy(write_ptr, code.data(), code_size_bytes);
52745353- // dump the shader code to a file.
5454- // @todo: wip
5555- {
5656- std::stringstream ss;
5757- ss << "shader_tmp.bin";
5858- std::string filename = ss.str();
5959-6060- std::ofstream outfile(filename, std::ios::out | std::ios::binary);
6161- if (outfile.is_open()) {
6262- outfile.write(reinterpret_cast<const char*>(code.data()), code_size_bytes);
6363- outfile.close();
6464- }
6565- log("shader written to {}", filename);
6666- exit(0);
6767- }
7575+ tmp_dump_shader(code.data(), code_size_bytes);
6876}
69777078void lower_simple(Compiler &cc) {
···8088 // @todo: handle local_invocation_id.
8189 // There are many ways to do this, but I believe we need to lower it
8290 // into a pack operation of vgpr0,1,2. But I'm not entirely sure.
9191+ if (inst.op == gir::Op::GetLocalInvocationId) {
9292+ // @todo: stop assuming 1d dispatch.
9393+ inst.meta.phys_reg = 0; // vgpr0
9494+ inst.meta.is_uniform = false;
9595+ }
8396 }
8497}
85989999+struct AddressPattern {
100100+ Value base_ptr;
101101+ bool is_tid_scaled_by_4 = false;
102102+};
103103+104104+std::optional<AddressPattern> match_address_pattern(Compiler& cc, const Inst& addr) {
105105+ AddressPattern pat;
106106+107107+ // Match: ptr
108108+ if (addr.type == Type::Ptr && addr.op != Op::Add) {
109109+ pat.base_ptr = addr.operands[0];
110110+ return pat;
111111+ }
112112+113113+ // Match: ptr + offset
114114+ if (addr.op == Op::Add) {
115115+ auto& lhs = cc.mod.deref(addr.operands[0]);
116116+ auto& rhs = cc.mod.deref(addr.operands[1]);
117117+118118+ // After normalization, ptr should be on left
119119+ if (lhs.type != Type::Ptr) {
120120+ log("not normalized?");
121121+ return std::nullopt; // Shouldn't happen after normalization
122122+ }
123123+124124+ pat.base_ptr = addr.operands[0];
125125+126126+ // Check if offset is tid * 4
127127+ if (rhs.op == Op::Mul) {
128128+ auto& mul_lhs = cc.mod.deref(rhs.operands[0]);
129129+ auto& mul_rhs = cc.mod.deref(rhs.operands[1]);
130130+131131+ if (mul_lhs.op == Op::GetLocalInvocationId &&
132132+ mul_rhs.op == Op::Const &&
133133+ mul_rhs.data.imm_i64 == 4) {
134134+ pat.is_tid_scaled_by_4 = true;
135135+ return pat;
136136+ }
137137+ }
138138+139139+ // Other offset patterns not yet supported
140140+ log("other offset pattern?");
141141+ return std::nullopt;
142142+ }
143143+144144+ log("what? op: {}", (int)addr.op);
145145+ return std::nullopt;
146146+}
147147+86148void lower_memory_loads(Compiler &cc) {
87149 // device memory loads should become global_load_dword or similar.
88150 // these kinds of instructions support a base ptr + imm offset or
89151 // base sgpr ptr + vgpr offset.
901529191- // if we detect such a pattern we can replace with these opcodes.
92153 // global_load_dword: saddr + voff (+ imm offset)
93154 // global_load_dword: vaddr (+ imm offset)
94155 // global_load_dword_addtid: saddr (+ imm offset) + 4 * local_invocation_id
95156 for (uint32_t i = 0; i < cc.mod.insts.size(); ++i) {
96157 auto &inst = cc.mod.insts[i];
9797- if (inst.op == gir::Op::Load) {
9898- auto addr = cc.mod.deref(inst.operands[0]);
9999-100100- if (addr.meta.is_uniform) {
101101- not_implemented("lower_memory_loads: cannot handle Op::Load with uniform address");
158158+ if (inst.op == Op::Load) {
159159+ auto& addr = cc.mod.deref(inst.operands[0]);
160160+ auto pat = match_address_pattern(cc, addr);
161161+ if (!pat) {
162162+ not_implemented("lower_memory_loads: unsupported load address pattern");
102163 }
103164104104- if (addr.op == gir::Op::Add) {
105105- // we have detected an offset!
106106- // @todo: I think we need some form of canonicalization
107107- // here so the check can be more trivial.
165165+ auto& base = cc.mod.deref(pat->base_ptr);
108166109109- auto lhs = cc.mod.deref(addr.operands[0]);
110110- auto rhs = cc.mod.deref(addr.operands[1]);
167167+ if (!base.meta.is_uniform) {
168168+ not_implemented("lower_memory_loads: non-uniform base pointer in Load not yet supported");
169169+ }
111170112112- assert(lhs.type == gir::Type::Ptr, "lower_memory_loads: invalid operand in load(x + y)");
171171+ if (pat->is_tid_scaled_by_4) {
172172+ inst.op = Op::BackendIntrinsic;
173173+ inst.intrinsic_id = (uint32_t)AmdIntrinsics::GlobalLoadDwordAddTID_Scale4;
174174+ inst.operands = {pat->base_ptr};
175175+ } else {
176176+ not_implemented("lower_memory_loads: simple loads not yet implemented");
177177+ }
178178+ } else if (inst.op == Op::Store) {
179179+ auto& addr = cc.mod.deref(inst.operands[0]);
180180+ auto pat = match_address_pattern(cc, addr);
181181+ if (!pat) {
182182+ not_implemented("lower_memory_loads: unsupported Store address pattern");
183183+ }
113184114114- if (rhs.op == gir::Op::Mul) {
115115- auto lhs2 = cc.mod.deref(rhs.operands[0]);
116116- auto rhs2 = cc.mod.deref(rhs.operands[1]);
185185+ auto& base = cc.mod.deref(pat->base_ptr);
117186118118- if (lhs2.op == gir::Op::GetLocalInvocationId && rhs2.op == gir::Op::Const && rhs2.data.imm_i64 == 4) {
119119- // replace instruction
120120- auto args = std::vector<Value>{lhs, rhs2};
121121- inst = gir::Inst{
122122- .op = gir::Op::BackendIntrinsic,
123123- .type = gir::Type::I32,
124124- .operands = args,
125125- .intrinsic_id = AmdIntrinsics::GlobalLoadDwordAddTI
126126- }
127127- }
128128- }
187187+ if (!base.meta.is_uniform) {
188188+ not_implemented("lower_memory_loads: non-uniform base pointer in Store not yet supported");
189189+ }
129190191191+ if (pat->is_tid_scaled_by_4) {
192192+ inst.op = Op::BackendIntrinsic;
193193+ inst.intrinsic_id = (uint32_t)AmdIntrinsics::GlobalStoreDwordAddTID_Scale4;
194194+ inst.operands = {pat->base_ptr, inst.operands[1]};
130195 } else {
131131-196196+ not_implemented("lower_memory_loads: simple stores not yet implemented");
132197 }
133198 }
134199 }
···211276}
212277213278void codegen(Compiler &cc) {
214214-215279 for (auto &inst : cc.mod.insts) {
216280 switch (inst.op) {
217281 case gir::Op::BackendIntrinsic: {
218282 switch(inst.intrinsic_id) {
219219- case (uint32_t)AmdIntrinsics::GlobalLoadDwordAddTI: {
220220- // @todo: support offset constants
221221- //assert(cc.mod.deref(inst.operands[0]).op == gir::Op::Const, "offset must be const");
222222- //auto offset = mod.deref(inst.operands[0]).data.imm_i64;
223223- auto offset = 0;
283283+ case (uint32_t)AmdIntrinsics::GlobalLoadDwordAddTID_Scale4: {
284284+ auto saddr = get_ssrc(cc, inst.operands[0]);
285285+286286+ // @todo: how do we know what to do about the cache flags?
287287+ cc.as.global(
288288+ RDNA2Assembler::global_opcode::global_load_dword_addtid,
289289+ false, false, false, false,
290290+ 0, // 12-bit immediate offset (0 for now)
291291+ 0, // vaddr (0 = use addtid mode)
292292+ (uint8_t)saddr, // saddr base pointer
293293+ inst.meta.phys_reg, // vdst destination register
294294+ 0 // unused in addtid mode
295295+ );
296296+297297+ // @todo: wait for load to complete. this is very conservative
298298+ cc.as.sopp(RDNA2Assembler::sopp_opcode::s_waitcnt, 0x3F70);
299299+ } break;
300300+ case (uint32_t)AmdIntrinsics::GlobalStoreDwordAddTID_Scale4: {
301301+ auto saddr = get_ssrc(cc, inst.operands[0]);
302302+ auto& data = cc.mod.deref(inst.operands[1]);
224303225225- auto saddr = get_ssrc(cc, inst.operands[1]);
226226- auto addr = get_vsrc(cc, inst.operands[2]);
227227- cc.as.global(RDNA2Assembler::global_opcode::global_load_dword_addtid, false, false, false, false,
228228- offset, 0, (uint8_t)saddr, inst.meta.phys_reg, (uint8_t)addr
304304+ if (data.meta.is_uniform) {
305305+ not_implemented("codegen: GlobalStoreDwordAddTI with uniform data (need v_mov)");
306306+ }
307307+308308+ cc.as.global(
309309+ RDNA2Assembler::global_opcode::global_store_dword_addtid,
310310+ false, false, false, false,
311311+ 0, 0, (uint8_t)saddr, data.meta.phys_reg, 0
229312 );
230313 } break;
314314+ default:
315315+ not_implemented("codegen: unknown backend intrinsic: {}", inst.intrinsic_id);
231316 }
232317 } break;
233233- }
234234- }
235318236236- /*
237237- for (auto& inst : mod.insts) {
238238- switch (inst.op) {
239239- case ADD:
240240- if (mod.values[inst.dest.id].is_uniform)
241241- as.sop2(sop2_opcode::s_add_u32, mod.values[inst.dest.id].phys_reg,
242242- mod.values[inst.args[0].id].phys_reg, mod.values[inst.args[1].id].phys_reg);
243243- else
244244- as.vop2(vop2_opcode::v_add_nc_u32, mod.values[inst.dest.id].phys_reg,
245245- mod.values[inst.args[0].id].phys_reg, mod.values[inst.args[1].id].phys_reg);
246246- break;
247247- case LOAD_GLOBAL:
248248- as.global(global_opcode::global_load_dword, inst.imm,
249249- mod.values[inst.dest.id].phys_reg, mod.values[inst.args[0].id].phys_reg, 0);
250250- break;
251251- case STORE_GLOBAL:
252252- as.global(global_opcode::global_store_dword, inst.imm,
253253- 0, mod.values[inst.args[0].id].phys_reg, mod.values[inst.args[2].id].phys_reg);
254254- break;
255255- case V_MOV_S2V:
256256- as.vop2(vop2_opcode::v_mov_b32, mod.values[inst.dest.id].phys_reg,
257257- mod.values[inst.args[0].id].phys_reg, 0);
258258- break;
319319+ case gir::Op::Store: {
320320+ // @todo: we currently assume all stores are global, but this may not be the case.
321321+ // I am not sure how NIR handles this, nor how other backends have local caches (LDS & GDS).
322322+ auto& addr = cc.mod.deref(inst.operands[0]);
323323+ auto& data = cc.mod.deref(inst.operands[1]);
324324+325325+ if (!addr.meta.is_uniform) {
326326+ not_implemented("codegen: Store with non-uniform address not yet supported");
327327+ }
328328+329329+ if (data.meta.is_uniform) {
330330+ not_implemented("codegen: Store with uniform data not yet supported (need v_mov to copy sgpr to vgpr)");
331331+ }
332332+333333+ if (data.type != gir::Type::I32 && data.type != gir::Type::F32) {
334334+ not_implemented("codegen: Store only supports I32/F32 for now");
335335+ }
336336+337337+ // global_store_dword: saddr + vdata
338338+ auto saddr = get_ssrc(cc, inst.operands[0]);
339339+340340+ cc.as.global(
341341+ RDNA2Assembler::global_opcode::global_store_dword,
342342+ true, true, false, true,
343343+ 0, // 12-bit immediate offset
344344+ 0, // vdst (unused for stores)
345345+ (uint8_t)saddr, // saddr base pointer
346346+ data.meta.phys_reg, // vdata - data to store
347347+ 0 // vaddr (0 = use saddr only)
348348+ );
349349+ } break;
350350+351351+ case gir::Op::Add: {
352352+ if (inst.type == gir::Type::I32) {
353353+ if (inst.meta.is_uniform) {
354354+ // Scalar add: s_add_u32
355355+ auto src0 = get_ssrc(cc, inst.operands[0]);
356356+ auto src1 = get_ssrc(cc, inst.operands[1]);
357357+ cc.as.sop2(
358358+ RDNA2Assembler::sop2_opcode::s_add_u32,
359359+ (RDNA2Assembler::ssrc)inst.meta.phys_reg,
360360+ src0,
361361+ src1
362362+ );
363363+ } else {
364364+ // Vector add: v_add_nc_u32 (non-carry version)
365365+ // vsrc1 MUST be a VGPR, src0 can be anything (SGPR, VGPR, const)
366366+ auto& op0 = cc.mod.deref(inst.operands[0]);
367367+ auto& op1 = cc.mod.deref(inst.operands[1]);
368368+369369+ // Ensure VGPR is in vsrc1 position by swapping if needed
370370+ bool op0_is_vgpr = !op0.meta.is_uniform && op0.op != gir::Op::Const;
371371+ bool op1_is_vgpr = !op1.meta.is_uniform && op1.op != gir::Op::Const;
372372+373373+ if (!op0_is_vgpr && !op1_is_vgpr) {
374374+ not_implemented("codegen: v_add_nc_u32 requires at least one VGPR operand");
375375+ }
376376+377377+ // Swap so VGPR is always in vsrc1 position
378378+ if (op0_is_vgpr && !op1_is_vgpr) {
379379+ cc.as.vop2(
380380+ RDNA2Assembler::vop2_opcode::v_add_nc_u32,
381381+ inst.meta.phys_reg,
382382+ get_vsrc(cc, inst.operands[1]), // src0: can be const/sgpr
383383+ op0.meta.phys_reg // vsrc1: VGPR
384384+ );
385385+ } else {
386386+ cc.as.vop2(
387387+ RDNA2Assembler::vop2_opcode::v_add_nc_u32,
388388+ inst.meta.phys_reg,
389389+ get_vsrc(cc, inst.operands[0]), // src0: can be const/sgpr/vgpr
390390+ op1.meta.phys_reg // vsrc1: VGPR
391391+ );
392392+ }
393393+ }
394394+ } else if (inst.type == gir::Type::Ptr) {
395395+ not_implemented("codegen: pointer addition (64-bit) not yet implemented");
396396+ } else {
397397+ not_implemented("codegen: Add not implemented for type: {}", (int)inst.type);
398398+ }
399399+ } break;
400400+401401+ case gir::Op::Const:
402402+ case gir::Op::GetRootPtr:
403403+ case gir::Op::GetLocalInvocationId:
404404+ // Skip metadata operations and constants
405405+ break;
406406+ default:
407407+ not_implemented("codegen: operation not yet implemented: {}", (int)inst.op);
408408+ break;
259409 }
260410 }
261261- */
262262-263411264412 cc.as.sopp(RDNA2Assembler::sopp_opcode::s_endpgm, 0);
265413
···11+#include "gir.h"
22+33+namespace gir {
44+55+// Canonicalize address computations so pointer is always on the left side of Add
66+void normalize_address_computation(Module& mod) {
77+ for (auto& inst : mod.insts) {
88+ if (inst.op != Op::Add || inst.type != Type::Ptr) continue;
99+1010+ auto& lhs = mod.deref(inst.operands[0]);
1111+ auto& rhs = mod.deref(inst.operands[1]);
1212+1313+ // Swap if pointer is on the right
1414+ if (lhs.type != Type::Ptr && rhs.type == Type::Ptr) {
1515+ std::swap(inst.operands[0], inst.operands[1]);
1616+ }
1717+ }
1818+}
1919+2020+void pass_normalize(Module &mod) {
2121+ normalize_address_computation(mod);
2222+}
2323+2424+void pass_eliminate_dead_code(Module& mod) {
2525+ std::vector<bool> is_live(mod.insts.size(), false);
2626+2727+ // Mark all instructions with side effects as live (roots)
2828+ for (size_t i = 0; i < mod.insts.size(); ++i) {
2929+ auto& inst = mod.insts[i];
3030+ // Instructions with side effects are roots
3131+ if (inst.op == Op::Store || inst.op == Op::BackendIntrinsic) {
3232+ is_live[i] = true;
3333+ }
3434+ }
3535+3636+ // Propagate liveness backwards through dependencies
3737+ // Keep iterating until no new instructions are marked live
3838+ bool changed = true;
3939+ while (changed) {
4040+ changed = false;
4141+ for (size_t i = 0; i < mod.insts.size(); ++i) {
4242+ if (!is_live[i]) continue;
4343+4444+ // Mark all operands of live instructions as live
4545+ for (auto& op : mod.insts[i].operands) {
4646+ if (op.is_inst() && !is_live[op.id]) {
4747+ is_live[op.id] = true;
4848+ changed = true;
4949+ }
5050+ }
5151+ }
5252+ }
5353+5454+ // Build new instruction list with value remapping
5555+ std::vector<Inst> new_insts;
5656+ std::vector<uint32_t> value_map(mod.insts.size());
5757+5858+ for (size_t i = 0; i < mod.insts.size(); ++i) {
5959+ if (is_live[i]) {
6060+ auto inst = mod.insts[i];
6161+6262+ // Remap operands to new instruction indices
6363+ for (auto& op : inst.operands) {
6464+ if (op.is_inst()) {
6565+ op.id = value_map[op.id];
6666+ }
6767+ }
6868+6969+ value_map[i] = new_insts.size();
7070+ new_insts.push_back(inst);
7171+ }
7272+ }
7373+7474+ mod.insts = std::move(new_insts);
7575+}
7676+7777+}