A Modern GPGPU API & WIP Linux RDNA2+ Driver
rdna driver linux gpu
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

gir+amdgpu: improve compiler

+263 -57
+8 -4
drivers/amdgpu/cmds.cpp
··· 284 284 assert(dev, "amdgpu_create_shader: device handle invalid: {}", (void *)dev); 285 285 assert(module, "amdgpu_create_shader: module handle invalid: {}", (void *)module); 286 286 287 + // Fixed for the Root Pointer ABI 288 + auto num_user_sgprs = 2; 289 + 287 290 auto shader = new Shader; 288 291 292 + CompileShaderInfo shdrcompinfo { 293 + .num_user_sgprs = num_user_sgprs 294 + }; 295 + 289 296 // @todo: ultra temporary. 290 297 auto alloc = amdgpu_malloc(dev, 1024, 256, KesMemoryDefault); 291 - rdna2_compile(*module, alloc.cpu, alloc.gpu); 298 + rdna2_compile(*module, shdrcompinfo, alloc.cpu, alloc.gpu); 292 299 shader->allocation = alloc; 293 300 294 301 log("shader code: {} {}", (void *)alloc.cpu, (void *)alloc.gpu); ··· 298 305 auto waves_per_threadgroup = 1; 299 306 auto max_waves_per_sh = 0x3FF; 300 307 auto threadgroups_per_cu = 1; 301 - 302 - // Fixed for the Root Pointer ABI 303 - auto num_user_sgprs = 2; 304 308 305 309 auto num_vgprs = 8; 306 310 auto num_sgprs = 8;
+62 -10
drivers/amdgpu/compiler/compiler.cpp
··· 27 27 gir::Module& mod; 28 28 RDNA2Assembler as; 29 29 30 + CompileShaderInfo shdr; 31 + 30 32 uint32_t sgpr_allocator = 6; 31 33 uint32_t vgpr_allocator = 3; 32 34 }; ··· 62 64 } 63 65 } 64 66 65 - void rdna2_compile(gir::Module &mod, void *write_ptr, uint64_t base_addr) { 67 + void rdna2_compile(gir::Module &mod, CompileShaderInfo shdr, void *write_ptr, uint64_t base_addr) { 66 68 Compiler compiler(mod); 69 + compiler.shdr = shdr; 67 70 68 71 gir::pass_normalize(mod); 69 72 lower_simple(compiler); ··· 86 89 void lower_simple(Compiler &cc) { 87 90 for (uint32_t i = 0; i < cc.mod.insts.size(); ++i) { 88 91 auto &inst = cc.mod.insts[i]; 89 - if (inst.op == gir::Op::GetRootPtr) { 92 + if (inst.op == gir::Op::RootPtr) { 90 93 // root pointer is passed as the user sgprs. 91 94 // we don't actually have to do anything. 92 95 inst.meta.phys_reg = 0; 93 96 inst.meta.is_uniform = true; 94 97 } 95 98 99 + if (inst.op == gir::Op::LocalInvocationIdX) { 100 + inst.meta.phys_reg = 0; 101 + inst.meta.is_uniform = false; 102 + } 103 + if (inst.op == gir::Op::LocalInvocationIdY) { 104 + inst.meta.phys_reg = 1; 105 + inst.meta.is_uniform = false; 106 + } 107 + if (inst.op == gir::Op::LocalInvocationIdZ) { 108 + inst.meta.phys_reg = 2; 109 + inst.meta.is_uniform = false; 110 + } 111 + 112 + if (inst.op == gir::Op::WorkgroupIdX) { 113 + inst.meta.phys_reg = cc.shdr.num_user_sgprs + 0; 114 + inst.meta.is_uniform = true; 115 + } 116 + if (inst.op == gir::Op::WorkgroupIdY) { 117 + inst.meta.phys_reg = cc.shdr.num_user_sgprs + 1; 118 + inst.meta.is_uniform = true; 119 + } 120 + if (inst.op == gir::Op::WorkgroupIdZ) { 121 + inst.meta.phys_reg = cc.shdr.num_user_sgprs + 2; 122 + inst.meta.is_uniform = true; 123 + } 124 + 96 125 // @todo: handle local_invocation_id. 97 126 // There are many ways to do this, but I believe we need to lower it 98 127 // into a pack operation of vgpr0,1,2. But I'm not entirely sure. 
99 - if (inst.op == gir::Op::GetLocalInvocationId) { 100 - // @todo: stop assuming 1d dispatch. 101 - inst.meta.phys_reg = 0; // vgpr0 102 - inst.meta.is_uniform = false; 103 - } 128 + 129 + // @todo: handle global invocation ids. 104 130 } 105 131 } 106 132 ··· 136 162 auto& mul_lhs = cc.mod.deref(rhs.operands[0]); 137 163 auto& mul_rhs = cc.mod.deref(rhs.operands[1]); 138 164 139 - if (mul_lhs.op == Op::GetLocalInvocationId && 165 + if (mul_lhs.op == Op::LocalInvocationIndex && 140 166 mul_rhs.op == Op::Const && 141 167 mul_rhs.data.imm_i64 == 4) { 142 168 pat.is_tid_scaled_by_4 = true; ··· 404 430 not_implemented("codegen: Add not implemented for type: {}", (int)inst.type); 405 431 } 406 432 } break; 433 + case gir::Op::WorkgroupBarrier: { 434 + cc.as.sopp(RDNA2Assembler::sopp_opcode::s_barrier, 0); 435 + } break; 436 + case gir::Op::SubgroupBarrierInit: { 437 + // this initializes with the first active thread value in a vgpr. If we 438 + // have a scalar value first, we must move it. 439 + // 440 + // @todo: this MUST be an actual register, not a vsrc. 
441 + auto data = (uint8_t)((int)get_vsrc(cc, inst.operands[0]) - 256); 442 + auto offset0 = inst.data.barrier_data.resource_id; 443 + cc.as.ds(RDNA2Assembler::ds_opcode::ds_gws_init, true, offset0, 0, 0, data, 0, 0); 444 + } break; 445 + case gir::Op::SubgroupBarrierSignal: { 446 + auto offset0 = inst.data.barrier_data.resource_id; 447 + cc.as.ds(RDNA2Assembler::ds_opcode::ds_gws_sema_v, true, offset0, 0, 0, 0, 0, 0); 448 + } break; 449 + case gir::Op::SubgroupBarrierWait: { 450 + auto offset0 = inst.data.barrier_data.resource_id; 451 + cc.as.ds(RDNA2Assembler::ds_opcode::ds_gws_sema_p, true, offset0, 0, 0, 0, 0, 0); 452 + } break; 407 453 408 454 case gir::Op::Const: 409 - case gir::Op::GetRootPtr: 410 - case gir::Op::GetLocalInvocationId: 455 + case gir::Op::RootPtr: 456 + case gir::Op::LocalInvocationIdX: 457 + case gir::Op::LocalInvocationIdY: 458 + case gir::Op::LocalInvocationIdZ: 459 + case gir::Op::LocalInvocationIndex: 460 + case gir::Op::WorkgroupIdX: 461 + case gir::Op::WorkgroupIdY: 462 + case gir::Op::WorkgroupIdZ: 411 463 // Skip metadata operations and constants 412 464 break; 413 465 default:
+7 -1
drivers/amdgpu/compiler/compiler.h
··· 1 1 #pragma once 2 2 3 + #include <cstdint> 4 + 3 5 #include "kestrel/gir.h" 4 6 5 - void rdna2_compile(gir::Module &mod, void *write_ptr, uint64_t base_addr); 7 + struct CompileShaderInfo { 8 + uint32_t num_user_sgprs; 9 + }; 10 + 11 + void rdna2_compile(gir::Module &mod, CompileShaderInfo shdr, void *write_ptr, uint64_t base_addr);
+40
drivers/amdgpu/compiler/rdna2_asm.h
··· 392 392 emit(0b0 << 31 | (uint8_t)op << 25 | vdst << 17 | vsrc1 << 9 | (uint16_t)src0 & 0x1FF); 393 393 } 394 394 395 + enum class ds_opcode : uint8_t { 396 + ds_add_u32 = 0, 397 + ds_sub_u32 = 1, 398 + ds_rsub_u32 = 2, 399 + ds_inc_u32 = 3, 400 + ds_dec_u32 = 4, 401 + ds_min_i32 = 5, 402 + ds_max_i32 = 6, 403 + ds_min_u32 = 7, 404 + ds_max_u32 = 8, 405 + ds_and_b32 = 9, 406 + ds_or_b32 = 10, 407 + ds_xor_b32 = 11, 408 + ds_mskor_b32 = 12, 409 + ds_write_b32 = 13, 410 + ds_write2_b32 = 14, 411 + ds_write2st64_b32 = 15, 412 + ds_cmpst_b32 = 16, 413 + ds_cmpst_f32 = 17, 414 + ds_min_f32 = 18, 415 + ds_max_f32 = 19, 416 + ds_nop = 20, 417 + ds_add_f32 = 21, 418 + 419 + ds_gws_sema_release_all = 24, 420 + ds_gws_init = 25, 421 + ds_gws_sema_v = 26, 422 + ds_gws_sema_br = 27, 423 + ds_gws_sema_p = 28, 424 + ds_gws_barrier = 29, 425 + 426 + // @todo: add the rest 427 + }; 428 + 429 + inline void ds(ds_opcode op, bool gds, uint8_t offset0, uint8_t offset1, 430 + uint8_t addr, uint8_t data0, uint8_t data1, uint8_t vdst) { 431 + emit(0b110110 << 26 | (uint8_t)op << 17 | offset1 << 8 | offset0); 432 + emit((vdst & 0xFF) << 24 | (gds & 0b1) << 17 | (data1 & 0xFF) << 8 | addr & 0xFF); 433 + } 434 + 395 435 // @todo: i think flat & global are really the same.. 396 436 // may want to consolidate them, but what about scratch? 397 437 enum class flat_opcode : uint8_t {
+28 -8
drivers/common/gir/gir_dump.cpp
··· 37 37 "Store", 38 38 "StoreShared", 39 39 "Const", 40 - "GetRootPtr", 41 - "GetLocalInvocationId", 42 - "GetThreadIdX", 43 - "GetThreadIdY", 44 - "GetThreadIdZ", 45 - "GetWorkgroupIdX", 46 - "GetWorkgroupIdY", 47 - "GetWorkgroupIdZ", 40 + "RootPtr", 41 + "LocalInvocationIdX", 42 + "LocalInvocationIdY", 43 + "LocalInvocationIdZ", 44 + "LocalInvocationIndex", 45 + "WorkgroupIdX", 46 + "WorkgroupIdY", 47 + "WorkgroupIdZ", 48 + "SubgroupId", 49 + "SubgroupSize", 50 + "NumSubgroups", 51 + "GlobalInvocationIdX", 52 + "GlobalInvocationIdY", 53 + "GlobalInvocationIdZ", 54 + "GlobalInvocationIndex", 55 + "WorkgroupBarrier", 56 + "SubgroupBarrierInit", 57 + "SubgroupBarrierWait", 58 + "SubgroupBarrierSignal", 48 59 "BackendIntrinsic", 49 60 }; 50 61 ··· 69 80 ss << ","; 70 81 } 71 82 ss << " $" << operand.id; 83 + } 84 + 85 + switch (inst.op) { 86 + case Op::SubgroupBarrierWait: 87 + case Op::SubgroupBarrierSignal: 88 + ss << " resource_id=" << inst.data.barrier_data.resource_id; 89 + break; 90 + default: 91 + break; 72 92 } 73 93 74 94 ss << std::endl;
+116 -32
kestrel/include/kestrel/gir.h
··· 57 57 Store, 58 58 StoreShared, 59 59 Const, 60 - GetRootPtr, 61 - GetLocalInvocationId, 62 - GetThreadIdX, 63 - GetThreadIdY, 64 - GetThreadIdZ, 65 - GetWorkgroupIdX, 66 - GetWorkgroupIdY, 67 - GetWorkgroupIdZ, 60 + RootPtr, 61 + LocalInvocationIdX, 62 + LocalInvocationIdY, 63 + LocalInvocationIdZ, 64 + LocalInvocationIndex, 65 + WorkgroupIdX, 66 + WorkgroupIdY, 67 + WorkgroupIdZ, 68 + SubgroupId, 69 + SubgroupSize, 70 + NumSubgroups, 71 + GlobalInvocationIdX, 72 + GlobalInvocationIdY, 73 + GlobalInvocationIdZ, 74 + GlobalInvocationIndex, 75 + WorkgroupBarrier, 76 + SubgroupBarrierInit, 77 + SubgroupBarrierWait, 78 + SubgroupBarrierSignal, 68 79 BackendIntrinsic, 69 80 }; 70 81 ··· 78 89 79 90 union { 80 91 int64_t imm_i64; 92 + 93 + struct { 94 + uint32_t resource_id; 95 + } barrier_data; 81 96 } data; 82 97 83 98 struct { ··· 124 139 void store(Value addr, Value data); 125 140 void store_shared(Value addr, Value data); 126 141 127 - Value get_root_ptr(); 142 + Value root_ptr(); 128 143 129 - Value get_local_invocation_id(); 130 - Value get_thread_id_x(); 131 - Value get_thread_id_y(); 132 - Value get_thread_id_z(); 144 + Value local_invocation_id_x(); 145 + Value local_invocation_id_y(); 146 + Value local_invocation_id_z(); 147 + Value local_invocation_index(); 133 148 134 - Value get_workgroup_id_x(); 135 - Value get_workgroup_id_y(); 136 - Value get_workgroup_id_z(); 149 + Value workgroup_id_x(); 150 + Value workgroup_id_y(); 151 + Value workgroup_id_z(); 152 + 153 + Value subgroup_id(); 154 + Value subgroup_size(); 155 + Value num_subgroups(); 156 + 157 + void workgroup_barrier(); 158 + 159 + void subgroup_barrier_init(uint32_t resource_id, Value v); 160 + void subgroup_barrier_wait(uint32_t resource_id); 161 + void subgroup_barrier_signal(uint32_t resource_id); 137 162 138 163 protected: 139 164 Module& mod; ··· 253 278 }); 254 279 } 255 280 256 - inline Value Builder::get_root_ptr() { 281 + inline Value Builder::root_ptr() { 257 282 return 
mod.emit(Inst{ 258 - .op = Op::GetRootPtr, 283 + .op = Op::RootPtr, 259 284 .type = Type::Ptr, 260 285 .operands = {} 261 286 }); 262 287 } 263 288 264 - inline Value Builder::get_local_invocation_id() { 289 + inline Value Builder::local_invocation_id_x() { 265 290 return mod.emit(Inst{ 266 - .op = Op::GetLocalInvocationId, 291 + .op = Op::LocalInvocationIdX, 267 292 .type = Type::I32, 268 293 .operands = {} 269 294 }); 270 295 } 271 296 272 - inline Value Builder::get_thread_id_x() { 297 + inline Value Builder::local_invocation_id_y() { 273 298 return mod.emit(Inst{ 274 - .op = Op::GetThreadIdX, 299 + .op = Op::LocalInvocationIdY, 275 300 .type = Type::I32, 276 301 .operands = {} 277 302 }); 278 303 } 279 304 280 - inline Value Builder::get_thread_id_y() { 305 + inline Value Builder::local_invocation_id_z() { 281 306 return mod.emit(Inst{ 282 - .op = Op::GetThreadIdY, 307 + .op = Op::LocalInvocationIdZ, 283 308 .type = Type::I32, 284 309 .operands = {} 285 310 }); 286 311 } 287 312 288 - inline Value Builder::get_thread_id_z() { 313 + inline Value Builder::local_invocation_index() { 289 314 return mod.emit(Inst{ 290 - .op = Op::GetThreadIdZ, 315 + .op = Op::LocalInvocationIndex, 291 316 .type = Type::I32, 292 317 .operands = {} 293 318 }); 294 319 } 295 320 296 321 297 - inline Value Builder::get_workgroup_id_x() { 322 + inline Value Builder::workgroup_id_x() { 298 323 return mod.emit(Inst{ 299 - .op = Op::GetWorkgroupIdX, 324 + .op = Op::WorkgroupIdX, 300 325 .type = Type::I32, 301 326 .operands = {} 302 327 }); 303 328 } 304 329 305 - inline Value Builder::get_workgroup_id_y() { 330 + inline Value Builder::workgroup_id_y() { 306 331 return mod.emit(Inst{ 307 - .op = Op::GetWorkgroupIdY, 332 + .op = Op::WorkgroupIdY, 308 333 .type = Type::I32, 309 334 .operands = {} 310 335 }); 311 336 } 312 337 313 - inline Value Builder::get_workgroup_id_z() { 338 + inline Value Builder::workgroup_id_z() { 314 339 return mod.emit(Inst{ 315 - .op = Op::GetWorkgroupIdZ, 340 + .op 
= Op::WorkgroupIdZ, 316 341 .type = Type::I32, 317 342 .operands = {} 318 343 }); 344 + } 345 + 346 + inline Value Builder::subgroup_id() { 347 + Inst inst; 348 + inst.op = Op::SubgroupId; 349 + inst.type = Type::I32; 350 + inst.operands = {}; 351 + return mod.emit(inst); 352 + } 353 + 354 + inline Value Builder::subgroup_size() { 355 + Inst inst; 356 + inst.op = Op::SubgroupSize; 357 + inst.type = Type::I32; 358 + inst.operands = {}; 359 + return mod.emit(inst); 360 + } 361 + 362 + inline Value Builder::num_subgroups() { 363 + Inst inst; 364 + inst.op = Op::NumSubgroups; 365 + inst.type = Type::I32; 366 + inst.operands = {}; 367 + return mod.emit(inst); 368 + } 369 + 370 + inline void Builder::workgroup_barrier() { 371 + Inst inst; 372 + inst.op = Op::WorkgroupBarrier; 373 + inst.type = Type::Void; 374 + inst.operands = {}; 375 + mod.emit(inst); 376 + } 377 + 378 + inline void Builder::subgroup_barrier_init(uint32_t resource_id, Value v) { 379 + Inst inst; 380 + inst.op = Op::SubgroupBarrierInit; 381 + inst.type = Type::Void; 382 + inst.data.barrier_data.resource_id = resource_id; 383 + inst.operands = {v}; 384 + mod.emit(inst); 385 + } 386 + 387 + inline void Builder::subgroup_barrier_wait(uint32_t resource_id) { 388 + Inst inst; 389 + inst.op = Op::SubgroupBarrierWait; 390 + inst.type = Type::Void; 391 + inst.data.barrier_data.resource_id = resource_id; 392 + inst.operands = {}; 393 + mod.emit(inst); 394 + } 395 + 396 + inline void Builder::subgroup_barrier_signal(uint32_t resource_id) { 397 + Inst inst; 398 + inst.op = Op::SubgroupBarrierSignal; 399 + inst.type = Type::Void; 400 + inst.data.barrier_data.resource_id = resource_id; 401 + inst.operands = {}; 402 + mod.emit(inst); 319 403 } 320 404 321 405 }
+2 -2
test/examples/07_hello_dispatch/hello_dispatch.cpp
··· 21 21 gir::Module mod; 22 22 { 23 23 gir::Builder gb(mod); 24 - auto rp = gb.get_root_ptr(); 25 - auto p = gb.add(rp, gb.mul(gb.get_local_invocation_id(), gb.i32(4))); 24 + auto rp = gb.root_ptr(); 25 + auto p = gb.add(rp, gb.mul(gb.local_invocation_index(), gb.i32(4))); 26 26 auto x = gb.load(p); 27 27 auto sum = gb.add(x, gb.i32(15)); 28 28 gb.store(p, sum);