Add compute shader infrastructure

This commit is contained in:
Jared Miller 2025-12-17 09:50:35 -05:00
parent 45c37bfcd2
commit 62d010bdc0
No known key found for this signature in database
3 changed files with 155 additions and 2 deletions

80
src/compute.zig Normal file
View file

@ -0,0 +1,80 @@
// compute shader module for GPU entity updates
// wraps raw GL calls that raylib doesn't expose directly
const std = @import("std");
const rl = @import("raylib");
const comp_source = @embedFile("shaders/entity_update.comp");
// GL constants not exposed by raylib-zig.
// Barrier bit passed to glMemoryBarrier after a compute dispatch so that
// shader-storage-buffer writes are visible to later shader reads.
const GL_SHADER_STORAGE_BARRIER_BIT: u32 = 0x00002000;
// Function pointer type for glMemoryBarrier: takes a bitfield of barrier bits,
// returns nothing, and uses the C calling convention. The entry point is
// looked up at runtime through rlGetProcAddress.
const GlMemoryBarrierFn = *const fn (barriers: u32) callconv(.c) void;
/// GPU compute program that advances entity positions stored in an SSBO.
///
/// Wraps the rlgl compute entry points together with `glMemoryBarrier`, which
/// is resolved at runtime because raylib does not expose it directly.
pub const ComputeShader = struct {
    /// GL program object returned by `rlLoadComputeShaderProgram`.
    program_id: u32,
    /// Location of the `entityCount` uniform; negative when it was not found.
    entity_count_loc: i32,
    /// Runtime-resolved `glMemoryBarrier` entry point (never null once `init` succeeds).
    glMemoryBarrier: GlMemoryBarrierFn,

    /// Invocations per workgroup; must match `local_size_x` in entity_update.comp.
    const workgroup_size: u32 = 256;

    /// Resolve `glMemoryBarrier`, then compile and link the embedded compute shader.
    ///
    /// Returns null (after printing a diagnostic) when the barrier entry point
    /// is unavailable, or when compilation or linking fails, so the caller can
    /// fall back to the CPU path.
    pub fn init() ?ComputeShader {
        // Load glMemoryBarrier dynamically. Cast to an *optional* function
        // pointer first so a missing entry point is detected here rather than
        // crashing on the first dispatch.
        const barrier_ptr = rl.gl.rlGetProcAddress("glMemoryBarrier");
        const maybe_barrier: ?GlMemoryBarrierFn = @ptrCast(@alignCast(barrier_ptr));
        const glMemoryBarrier = maybe_barrier orelse {
            std.debug.print("compute: glMemoryBarrier entry point not available\n", .{});
            return null;
        };

        // compile compute shader
        const shader_id = rl.gl.rlCompileShader(comp_source, rl.gl.rl_compute_shader);
        if (shader_id == 0) {
            std.debug.print("compute: failed to compile compute shader\n", .{});
            return null;
        }

        // link compute program
        const program_id = rl.gl.rlLoadComputeShaderProgram(shader_id);
        if (program_id == 0) {
            std.debug.print("compute: failed to link compute program\n", .{});
            return null;
        }

        // A missing uniform is only a warning: the program is still usable,
        // and the negative location is stored as-is.
        const entity_count_loc = rl.gl.rlGetLocationUniform(program_id, "entityCount");
        if (entity_count_loc < 0) {
            std.debug.print("compute: warning - entityCount uniform not found\n", .{});
        }

        std.debug.print("compute: shader loaded successfully (program_id={})\n", .{program_id});
        return .{
            .program_id = program_id,
            .entity_count_loc = entity_count_loc,
            .glMemoryBarrier = glMemoryBarrier,
        };
    }

    /// Release the GL program created by `init`.
    pub fn deinit(self: *ComputeShader) void {
        rl.gl.rlUnloadShaderProgram(self.program_id);
    }

    /// Run the compute shader over `entity_count` entities stored in `ssbo_id`.
    ///
    /// Binds the SSBO to binding point 0, dispatches enough workgroups to
    /// cover every entity, issues a shader-storage memory barrier, and then
    /// unbinds both the buffer and the program. Does nothing when
    /// `entity_count` is zero.
    pub fn dispatch(self: *ComputeShader, ssbo_id: u32, entity_count: u32) void {
        if (entity_count == 0) return;

        rl.gl.rlEnableShader(self.program_id);
        rl.gl.rlSetUniform(
            self.entity_count_loc,
            &entity_count,
            @intFromEnum(rl.gl.rlShaderUniformDataType.rl_shader_uniform_uint),
            1,
        );

        // bind SSBO to binding point 0
        rl.gl.rlBindShaderBuffer(ssbo_id, 0);

        // ceil(entity_count / workgroup_size). Written as (n - 1) / size + 1,
        // which cannot overflow u32; n > 0 is guaranteed by the early return.
        const groups = (entity_count - 1) / workgroup_size + 1;
        rl.gl.rlComputeShaderDispatch(groups, 1, 1);

        // memory barrier - ensure compute writes are visible to vertex shader
        self.glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);

        // Unbind the buffer and the program so the compute program does not
        // stay current for whatever draws next.
        rl.gl.rlBindShaderBuffer(0, 0);
        rl.gl.rlDisableShader();
    }
};

View file

@ -7,6 +7,7 @@ const ztracy = @import("ztracy");
const sandbox = @import("sandbox.zig");
const ui = @import("ui.zig");
const SsboRenderer = @import("ssbo_renderer.zig").SsboRenderer;
const ComputeShader = @import("compute.zig").ComputeShader;
const SCREEN_WIDTH = sandbox.SCREEN_WIDTH;
const SCREEN_HEIGHT = sandbox.SCREEN_HEIGHT;
@ -163,6 +164,7 @@ pub fn main() !void {
var use_instancing = false;
var use_ssbo = true;
var use_vsync = false;
var use_compute = false;
var args = try std.process.argsWithAllocator(std.heap.page_allocator);
defer args.deinit();
_ = args.skip(); // skip program name
@ -176,6 +178,8 @@ pub fn main() !void {
use_ssbo = false; // legacy rlgl batched path
} else if (std.mem.eql(u8, arg, "--vsync")) {
use_vsync = true;
} else if (std.mem.eql(u8, arg, "--compute")) {
use_compute = true;
}
}
@ -257,6 +261,26 @@ pub fn main() !void {
if (ssbo_renderer) |*r| r.deinit();
}
// compute shader setup (only if --compute flag)
var compute_shader: ?ComputeShader = null;
if (use_compute) {
if (!use_ssbo) {
std.debug.print("--compute requires SSBO mode (default), ignoring\n", .{});
} else {
compute_shader = ComputeShader.init();
if (compute_shader == null) {
std.debug.print("failed to initialize compute shader, falling back to CPU\n", .{});
} else {
std.debug.print("compute shader mode enabled\n", .{});
}
}
}
defer {
if (compute_shader) |*c| c.deinit();
}
// load UI font (embedded)
const font_data = @embedFile("verdanab.ttf");
const ui_font = rl.loadFontFromMemory(".ttf", font_data, 32, null) catch {
@ -335,7 +359,16 @@ pub fn main() !void {
const tracy_update = ztracy.ZoneN(@src(), "update");
defer tracy_update.End();
const update_start = std.time.microTimestamp();
sandbox.update(&entities, &rng);
if (compute_shader != null) {
// GPU compute update - positions updated on GPU
// still need CPU update for respawn logic until Step 3
sandbox.update(&entities, &rng);
} else {
// CPU update path
sandbox.update(&entities, &rng);
}
update_time_us = std.time.microTimestamp() - update_start;
}
@ -348,7 +381,14 @@ pub fn main() !void {
rl.clearBackground(BG_COLOR);
if (use_ssbo) {
// SSBO instanced rendering path (12 bytes per entity)
// dispatch compute shader before render (if enabled)
if (compute_shader) |*cs| {
const tracy_compute = ztracy.ZoneN(@src(), "compute_dispatch");
defer tracy_compute.End();
cs.dispatch(ssbo_renderer.?.ssbo_id, @intCast(entities.count));
}
// SSBO instanced rendering path (16 bytes per entity)
ssbo_renderer.?.render(&entities, zoom, pan);
} else if (use_instancing) {
// GPU instancing path (64 bytes per entity)

View file

@ -0,0 +1,33 @@
#version 430
// 256 invocations per workgroup; the host dispatches ceil(entityCount / 256) groups.
layout(local_size_x = 256) in;
// One entity record: four 4-byte scalars, 16 bytes per entity under std430.
struct Entity {
float x; // position, x component
float y; // position, y component
int packedVel; // vx in the high 16 bits, vy in the low 16 bits (signed fixed-point 8.8)
uint color; // carried through unchanged by this shader
};
// Entity storage buffer; the host binds it to binding point 0 before dispatch.
layout(std430, binding = 0) buffer Entities {
Entity entities[];
};
// Number of valid entries in entities[]; invocations at or beyond it do nothing.
uniform uint entityCount;
// Advance one entity per invocation by its packed fixed-point velocity.
void main() {
    uint index = gl_GlobalInvocationID.x;

    // The dispatch is rounded up to whole workgroups, so trailing invocations
    // may fall outside the entity array; only in-range ones do any work.
    if (index < entityCount) {
        Entity ent = entities[index];

        // High half: the arithmetic right shift keeps the sign of vx.
        int vx_fixed = ent.packedVel >> 16;
        // Low half: shift up, then back down, to sign-extend vy.
        int vy_fixed = (ent.packedVel << 16) >> 16;

        // Convert 8.8 fixed-point to float and step the position.
        ent.x += float(vx_fixed) / 256.0;
        ent.y += float(vy_fixed) / 256.0;

        entities[index] = ent;
    }
}