Skip to content

Commit

Permalink
disallow ptr to workgroup fn arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Sep 25, 2023
1 parent 3bcb114 commit b3c3644
Show file tree
Hide file tree
Showing 13 changed files with 386 additions and 653 deletions.
7 changes: 1 addition & 6 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,12 +993,7 @@ impl super::Validator {
#[cfg(feature = "validate")]
for (index, argument) in fun.arguments.iter().enumerate() {
match module.types[argument.ty].inner.pointer_space() {
Some(
crate::AddressSpace::Private
| crate::AddressSpace::Function
| crate::AddressSpace::WorkGroup,
)
| None => {}
Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
Some(other) => {
return Err(FunctionError::InvalidArgumentPointerSpace {
index,
Expand Down
14 changes: 8 additions & 6 deletions src/valid/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ fn check_member_layout(
/// `TypeFlags::empty()`.
///
/// Pointers passed as arguments to user-defined functions must be in the
/// `Function`, `Private`, or `Workgroup` storage space.
/// `Function` or `Private` address space.
const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags {
use crate::AddressSpace as As;
match space {
As::Function | As::Private | As::WorkGroup => TypeFlags::ARGUMENT,
As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant => TypeFlags::empty(),
As::Function | As::Private => TypeFlags::ARGUMENT,
As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant | As::WorkGroup => {
TypeFlags::empty()
}
}
}

Expand Down Expand Up @@ -316,7 +318,7 @@ impl super::Validator {
return Err(TypeError::InvalidPointerBase(base));
}

// Runtime-sized values can only live in the `Storage` storage
// Runtime-sized values can only live in the `Storage` address
// space, so it's useless to have a pointer to such a type in
// any other space.
//
Expand All @@ -336,7 +338,7 @@ impl super::Validator {
}
}

// `Validator::validate_function` actually checks the storage
// `Validator::validate_function` actually checks the address
// space of pointer arguments explicitly before checking the
// `ARGUMENT` flag, to give better error messages. But it seems
// best to set `ARGUMENT` accurately anyway.
Expand Down Expand Up @@ -364,7 +366,7 @@ impl super::Validator {
// `InvalidPointerBase` or `InvalidPointerToUnsized`.
self.check_width(kind, width)?;

// `Validator::validate_function` actually checks the storage
// `Validator::validate_function` actually checks the address
// space of pointer arguments explicitly before checking the
// `ARGUMENT` flag, to give better error messages. But it seems
// best to set `ARGUMENT` accurately anyway.
Expand Down
8 changes: 0 additions & 8 deletions tests/in/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -151,20 +151,12 @@ fn foo_frag() -> @location(0) vec4<f32> {
return vec4<f32>(0.0);
}

var<workgroup> val: u32;

fn assign_through_ptr_fn(p: ptr<workgroup, u32>) {
*p = 42u;
}

fn assign_array_through_ptr_fn(foo: ptr<function, array<vec4<f32>, 2>>) {
*foo = array<vec4<f32>, 2>(vec4(1.0), vec4(2.0));
}

@compute @workgroup_size(1)
fn assign_through_ptr() {
var arr = array<vec4<f32>, 2>(vec4(6.0), vec4(7.0));

assign_through_ptr_fn(&val);
assign_array_through_ptr_fn(&arr);
}
72 changes: 4 additions & 68 deletions tests/out/analysis/access.info.ron
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
("SIZED | COPY | ARGUMENT"),
("DATA | SIZED | COPY | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
("SIZED | COPY | ARGUMENT"),
],
Expand All @@ -46,7 +45,6 @@
("READ"),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -1144,7 +1142,6 @@
(""),
(""),
("READ"),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2414,7 +2411,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2454,7 +2450,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2503,50 +2498,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
uniformity: (
non_uniform_result: Some(1),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(27),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
],
sampling: [],
dual_source_blending: false,
),
(
flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"),
available_stages: ("VERTEX | FRAGMENT | COMPUTE"),
uniformity: (
non_uniform_result: None,
requirements: (""),
),
may_kill: false,
sampling_set: [],
global_uses: [
(""),
(""),
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand All @@ -2556,7 +2507,7 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(29),
ty: Handle(28),
),
(
uniformity: (
Expand Down Expand Up @@ -2615,7 +2566,7 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(28),
ty: Handle(27),
),
],
sampling: [],
Expand All @@ -2638,7 +2589,6 @@
("READ"),
("READ"),
("READ"),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -3302,7 +3252,6 @@
(""),
("WRITE"),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -3736,7 +3685,6 @@
(""),
(""),
(""),
("READ"),
],
expressions: [
(
Expand Down Expand Up @@ -3796,7 +3744,7 @@
),
ref_count: 1,
assignable_global: None,
ty: Handle(28),
ty: Handle(27),
),
(
uniformity: (
Expand All @@ -3806,22 +3754,10 @@
ref_count: 2,
assignable_global: None,
ty: Value(Pointer(
base: 28,
base: 27,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: Some(6),
ty: Value(Pointer(
base: 1,
space: WorkGroup,
)),
),
],
sampling: [],
dual_source_blending: false,
Expand Down
13 changes: 0 additions & 13 deletions tests/out/glsl/access.assign_through_ptr.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ struct Baz {
struct MatCx2InArray {
mat4x2 am[2];
};
shared uint val;


float read_from_private(inout float foo_1) {
float _e1 = foo_1;
Expand All @@ -31,25 +29,14 @@ float test_arr_as_arg(float a[5][10]) {
return a[4][9];
}

void assign_through_ptr_fn(inout uint p) {
p = 42u;
return;
}

void assign_array_through_ptr_fn(inout vec4 foo_2[2]) {
foo_2 = vec4[2](vec4(1.0), vec4(2.0));
return;
}

void main() {
if (gl_LocalInvocationID == uvec3(0u)) {
val = 0u;
}
memoryBarrierShared();
barrier();
vec4 arr[2] = vec4[2](vec4(0.0), vec4(0.0));
arr = vec4[2](vec4(6.0), vec4(7.0));
assign_through_ptr_fn(val);
assign_array_through_ptr_fn(arr);
return;
}
Expand Down
5 changes: 0 additions & 5 deletions tests/out/glsl/access.foo_frag.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ float test_arr_as_arg(float a[5][10]) {
return a[4][9];
}

void assign_through_ptr_fn(inout uint p) {
p = 42u;
return;
}

void assign_array_through_ptr_fn(inout vec4 foo_2[2]) {
foo_2 = vec4[2](vec4(1.0), vec4(2.0));
return;
Expand Down
5 changes: 0 additions & 5 deletions tests/out/glsl/access.foo_vert.Vertex.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,6 @@ float test_arr_as_arg(float a[5][10]) {
return a[4][9];
}

void assign_through_ptr_fn(inout uint p) {
p = 42u;
return;
}

void assign_array_through_ptr_fn(inout vec4 foo_2[2]) {
foo_2 = vec4[2](vec4(1.0), vec4(2.0));
return;
Expand Down
14 changes: 1 addition & 13 deletions tests/out/hlsl/access.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ RWByteAddressBuffer bar : register(u0);
cbuffer baz : register(b1) { Baz baz; }
RWByteAddressBuffer qux : register(u2);
cbuffer nested_mat_cx2_ : register(b3) { MatCx2InArray nested_mat_cx2_; }
groupshared uint val;

Baz ConstructBaz(float3x2 arg0) {
Baz ret = (Baz)0;
Expand Down Expand Up @@ -212,12 +211,6 @@ float test_arr_as_arg(float a[5][10])
return a[4][9];
}

void assign_through_ptr_fn(inout uint p)
{
p = 42u;
return;
}

typedef float4 ret_Constructarray2_float4_[2];
ret_Constructarray2_float4_ Constructarray2_float4_(float4 arg0, float4 arg1) {
float4 ret[2] = { arg0, arg1 };
Expand Down Expand Up @@ -293,16 +286,11 @@ float4 foo_frag() : SV_Target0
}

[numthreads(1, 1, 1)]
void assign_through_ptr(uint3 __local_invocation_id : SV_GroupThreadID)
void assign_through_ptr()
{
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
val = (uint)0;
}
GroupMemoryBarrierWithGroupSync();
float4 arr[2] = (float4[2])0;

arr = Constructarray2_float4_((6.0).xxxx, (7.0).xxxx);
assign_through_ptr_fn(val);
assign_array_through_ptr_fn(arr);
return;
}
Loading

0 comments on commit b3c3644

Please sign in to comment.