From 62d010bdc01dc9d8494eb4a196bcfb9f6f78b42d Mon Sep 17 00:00:00 2001 From: Jared Miller Date: Wed, 17 Dec 2025 09:50:35 -0500 Subject: [PATCH] Add computer shader infrastructure --- src/compute.zig | 80 ++++++++++++++++++++++++++++++++++ src/sandbox_main.zig | 44 ++++++++++++++++++- src/shaders/entity_update.comp | 33 ++++++++++++++ 3 files changed, 155 insertions(+), 2 deletions(-) create mode 100644 src/compute.zig create mode 100644 src/shaders/entity_update.comp diff --git a/src/compute.zig b/src/compute.zig new file mode 100644 index 0000000..df59fe0 --- /dev/null +++ b/src/compute.zig @@ -0,0 +1,80 @@ +// compute shader module for GPU entity updates +// wraps raw GL calls that raylib doesn't expose directly + +const std = @import("std"); +const rl = @import("raylib"); + +const comp_source = @embedFile("shaders/entity_update.comp"); + +// GL constants not exposed by raylib-zig +const GL_SHADER_STORAGE_BARRIER_BIT: u32 = 0x00002000; + +// function pointer type for glMemoryBarrier +const GlMemoryBarrierFn = *const fn (barriers: u32) callconv(.c) void; + +pub const ComputeShader = struct { + program_id: u32, + entity_count_loc: i32, + glMemoryBarrier: GlMemoryBarrierFn, + + pub fn init() ?ComputeShader { + // load glMemoryBarrier dynamically + const barrier_ptr = rl.gl.rlGetProcAddress("glMemoryBarrier"); + const glMemoryBarrier: GlMemoryBarrierFn = @ptrCast(barrier_ptr); + + // compile compute shader + const shader_id = rl.gl.rlCompileShader(comp_source, rl.gl.rl_compute_shader); + if (shader_id == 0) { + std.debug.print("compute: failed to compile compute shader\n", .{}); + return null; + } + + // link compute program + const program_id = rl.gl.rlLoadComputeShaderProgram(shader_id); + if (program_id == 0) { + std.debug.print("compute: failed to link compute program\n", .{}); + return null; + } + + // get uniform locations + const entity_count_loc = rl.gl.rlGetLocationUniform(program_id, "entityCount"); + if (entity_count_loc < 0) { + std.debug.print("compute: warning - entityCount uniform not found\n", .{}); + } + + std.debug.print("compute: shader loaded successfully (program_id={})\n", .{program_id}); + + return .{ + .program_id = program_id, + .entity_count_loc = entity_count_loc, + .glMemoryBarrier = glMemoryBarrier, + }; + } + + pub fn deinit(self: *ComputeShader) void { + rl.gl.rlUnloadShaderProgram(self.program_id); + } + + pub fn dispatch(self: *ComputeShader, ssbo_id: u32, entity_count: u32) void { + if (entity_count == 0) return; + + // bind compute shader + rl.gl.rlEnableShader(self.program_id); + + // set entityCount uniform + rl.gl.rlSetUniform(self.entity_count_loc, &entity_count, @intFromEnum(rl.gl.rlShaderUniformDataType.rl_shader_uniform_uint), 1); + + // bind SSBO to binding point 0 + rl.gl.rlBindShaderBuffer(ssbo_id, 0); + + // dispatch compute workgroups: ceil(entity_count / 256) + const groups = (entity_count + 255) / 256; + rl.gl.rlComputeShaderDispatch(groups, 1, 1); + + // memory barrier - ensure compute writes are visible to vertex shader + self.glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT); + + // unbind + rl.gl.rlBindShaderBuffer(0, 0); + } +}; diff --git a/src/sandbox_main.zig b/src/sandbox_main.zig index 03b8263..da6b008 100644 --- a/src/sandbox_main.zig +++ b/src/sandbox_main.zig @@ -7,6 +7,7 @@ const ztracy = @import("ztracy"); const sandbox = @import("sandbox.zig"); const ui = @import("ui.zig"); const SsboRenderer = @import("ssbo_renderer.zig").SsboRenderer; +const ComputeShader = @import("compute.zig").ComputeShader; const SCREEN_WIDTH = sandbox.SCREEN_WIDTH; const SCREEN_HEIGHT = sandbox.SCREEN_HEIGHT; @@ -163,6 +164,7 @@ pub fn main() !void { var use_instancing = false; var use_ssbo = true; var use_vsync = false; + var use_compute = false; var args = try std.process.argsWithAllocator(std.heap.page_allocator); defer args.deinit(); _ = args.skip(); // skip program name @@ -176,6 +178,8 @@ pub fn main() !void { use_ssbo = false; // legacy rlgl batched path } else if (std.mem.eql(u8, arg, "--vsync")) { use_vsync = true; + } else if (std.mem.eql(u8, arg, "--compute")) { + use_compute = true; } } @@ -257,6 +261,26 @@ pub fn main() !void { if (ssbo_renderer) |*r| r.deinit(); } + // compute shader setup (only if --compute flag) + var compute_shader: ?ComputeShader = null; + + if (use_compute) { + if (!use_ssbo) { + std.debug.print("--compute requires SSBO mode (default), ignoring\n", .{}); + } else { + compute_shader = ComputeShader.init(); + if (compute_shader == null) { + std.debug.print("failed to initialize compute shader, falling back to CPU\n", .{}); + } else { + std.debug.print("compute shader mode enabled\n", .{}); + } + } + } + + defer { + if (compute_shader) |*c| c.deinit(); + } + // load UI font (embedded) const font_data = @embedFile("verdanab.ttf"); const ui_font = rl.loadFontFromMemory(".ttf", font_data, 32, null) catch { @@ -335,7 +359,16 @@ pub fn main() !void { const tracy_update = ztracy.ZoneN(@src(), "update"); defer tracy_update.End(); const update_start = std.time.microTimestamp(); - sandbox.update(&entities, &rng); + + if (compute_shader != null) { + // GPU compute update - positions updated on GPU + // still need CPU update for respawn logic until Step 3 + sandbox.update(&entities, &rng); + } else { + // CPU update path + sandbox.update(&entities, &rng); + } + update_time_us = std.time.microTimestamp() - update_start; } @@ -348,7 +381,14 @@ pub fn main() !void { rl.clearBackground(BG_COLOR); if (use_ssbo) { - // SSBO instanced rendering path (12 bytes per entity) + // dispatch compute shader before render (if enabled) + if (compute_shader) |*cs| { + const tracy_compute = ztracy.ZoneN(@src(), "compute_dispatch"); + defer tracy_compute.End(); + cs.dispatch(ssbo_renderer.?.ssbo_id, @intCast(entities.count)); + } + + // SSBO instanced rendering path (16 bytes per entity) ssbo_renderer.?.render(&entities, zoom, pan); } else if (use_instancing) { // GPU instancing path (64 bytes per entity) diff --git a/src/shaders/entity_update.comp b/src/shaders/entity_update.comp new file mode 100644 index 0000000..a18ff72 --- /dev/null +++ b/src/shaders/entity_update.comp @@ -0,0 +1,33 @@ +#version 430 + +layout(local_size_x = 256) in; + +struct Entity { + float x; + float y; + int packedVel; // vx high 16 bits, vy low 16 bits (fixed-point 8.8) + uint color; +}; + +layout(std430, binding = 0) buffer Entities { + Entity entities[]; +}; + +uniform uint entityCount; + +void main() { + uint id = gl_GlobalInvocationID.x; + if (id >= entityCount) return; + + Entity e = entities[id]; + + // unpack velocity (fixed-point 8.8) + float vx = float(e.packedVel >> 16) / 256.0; + float vy = float((e.packedVel << 16) >> 16) / 256.0; // sign-extend low 16 bits + + // update position + e.x += vx; + e.y += vy; + + entities[id] = e; +}