Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow ptr to workgroup fn arguments #2507

Merged
merged 1 commit into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1001,12 +1001,7 @@ impl super::Validator {
#[cfg(feature = "validate")]
for (index, argument) in fun.arguments.iter().enumerate() {
match module.types[argument.ty].inner.pointer_space() {
Some(
crate::AddressSpace::Private
| crate::AddressSpace::Function
| crate::AddressSpace::WorkGroup,
)
| None => {}
Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
Some(other) => {
return Err(FunctionError::InvalidArgumentPointerSpace {
index,
Expand Down
14 changes: 8 additions & 6 deletions src/valid/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ fn check_member_layout(
/// `TypeFlags::empty()`.
///
/// Pointers passed as arguments to user-defined functions must be in the
/// `Function`, `Private`, or `Workgroup` storage space.
/// `Function` or `Private` address space.
const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags {
use crate::AddressSpace as As;
match space {
As::Function | As::Private | As::WorkGroup => TypeFlags::ARGUMENT,
As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant => TypeFlags::empty(),
As::Function | As::Private => TypeFlags::ARGUMENT,
As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant | As::WorkGroup => {
TypeFlags::empty()
}
}
}

Expand Down Expand Up @@ -316,7 +318,7 @@ impl super::Validator {
return Err(TypeError::InvalidPointerBase(base));
}

// Runtime-sized values can only live in the `Storage` storage
// Runtime-sized values can only live in the `Storage` address
// space, so it's useless to have a pointer to such a type in
// any other space.
//
Expand All @@ -336,7 +338,7 @@ impl super::Validator {
}
}

// `Validator::validate_function` actually checks the storage
// `Validator::validate_function` actually checks the address
// space of pointer arguments explicitly before checking the
// `ARGUMENT` flag, to give better error messages. But it seems
// best to set `ARGUMENT` accurately anyway.
Expand Down Expand Up @@ -364,7 +366,7 @@ impl super::Validator {
// `InvalidPointerBase` or `InvalidPointerToUnsized`.
self.check_width(kind, width)?;

// `Validator::validate_function` actually checks the storage
// `Validator::validate_function` actually checks the address
// space of pointer arguments explicitly before checking the
// `ARGUMENT` flag, to give better error messages. But it seems
// best to set `ARGUMENT` accurately anyway.
Expand Down
9 changes: 4 additions & 5 deletions tests/in/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ fn foo_frag() -> @location(0) vec4<f32> {
return vec4<f32>(0.0);
}

var<workgroup> val: u32;

fn assign_through_ptr_fn(p: ptr<workgroup, u32>) {
fn assign_through_ptr_fn(p: ptr<function, u32>) {
*p = 42u;
}

Expand All @@ -163,8 +161,9 @@ fn assign_array_through_ptr_fn(foo: ptr<function, array<vec4<f32>, 2>>) {

@compute @workgroup_size(1)
fn assign_through_ptr() {
var arr = array<vec4<f32>, 2>(vec4(6.0), vec4(7.0));

var val = 33u;
assign_through_ptr_fn(&val);

var arr = array<vec4<f32>, 2>(vec4(6.0), vec4(7.0));
assign_array_through_ptr_fn(&arr);
}
47 changes: 25 additions & 22 deletions tests/out/analysis/access.info.ron
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
("READ"),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -1144,7 +1143,6 @@
(""),
(""),
("READ"),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2414,7 +2412,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2454,7 +2451,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2503,7 +2499,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2546,7 +2541,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2638,7 +2632,6 @@
("READ"),
("READ"),
("READ"),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -3302,7 +3295,6 @@
(""),
("WRITE"),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -3736,9 +3728,32 @@
(""),
(""),
(""),
("READ"),
],
expressions: [
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
non_uniform_result: Some(2),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 1,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: None,
Expand Down Expand Up @@ -3800,7 +3815,7 @@
),
(
uniformity: (
non_uniform_result: Some(6),
non_uniform_result: Some(8),
requirements: (""),
),
ref_count: 1,
Expand All @@ -3810,18 +3825,6 @@
space: Function,
)),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: Some(6),
ty: Value(Pointer(
base: 1,
space: WorkGroup,
)),
),
],
sampling: [],
dual_source_blending: false,
Expand Down
8 changes: 1 addition & 7 deletions tests/out/glsl/access.assign_through_ptr.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ struct Baz {
struct MatCx2InArray {
mat4x2 am[2];
};
shared uint val;


float read_from_private(inout float foo_1) {
float _e1 = foo_1;
Expand All @@ -42,11 +40,7 @@ void assign_array_through_ptr_fn(inout vec4 foo_2[2]) {
}

void main() {
if (gl_LocalInvocationID == uvec3(0u)) {
val = 0u;
}
memoryBarrierShared();
barrier();
uint val = 33u;
vec4 arr[2] = vec4[2](vec4(6.0), vec4(7.0));
assign_through_ptr_fn(val);
assign_array_through_ptr_fn(arr);
Expand Down
8 changes: 2 additions & 6 deletions tests/out/hlsl/access.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ RWByteAddressBuffer bar : register(u0);
cbuffer baz : register(b1) { Baz baz; }
RWByteAddressBuffer qux : register(u2);
cbuffer nested_mat_cx2_ : register(b3) { MatCx2InArray nested_mat_cx2_; }
groupshared uint val;

Baz ConstructBaz(float3x2 arg0) {
Baz ret = (Baz)0;
Expand Down Expand Up @@ -288,12 +287,9 @@ float4 foo_frag() : SV_Target0
}

[numthreads(1, 1, 1)]
void assign_through_ptr(uint3 __local_invocation_id : SV_GroupThreadID)
void assign_through_ptr()
{
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
val = (uint)0;
}
GroupMemoryBarrierWithGroupSync();
uint val = 33u;
float4 arr[2] = Constructarray2_float4_((6.0).xxxx, (7.0).xxxx);

assign_through_ptr_fn(val);
Expand Down
47 changes: 23 additions & 24 deletions tests/out/ir/access.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@
name: None,
inner: Pointer(
base: 1,
space: WorkGroup,
space: Function,
),
),
(
Expand Down Expand Up @@ -356,13 +356,6 @@
ty: 20,
init: None,
),
(
name: Some("val"),
space: WorkGroup,
binding: None,
ty: 1,
init: None,
),
],
const_expressions: [
Literal(U32(0)),
Expand Down Expand Up @@ -2137,54 +2130,60 @@
arguments: [],
result: None,
local_variables: [
(
name: Some("val"),
ty: 1,
init: Some(1),
),
(
name: Some("arr"),
ty: 28,
init: Some(5),
init: Some(7),
),
],
expressions: [
Literal(U32(33)),
LocalVariable(1),
Literal(F32(6.0)),
Splat(
size: Quad,
value: 1,
value: 3,
),
Literal(F32(7.0)),
Splat(
size: Quad,
value: 3,
value: 5,
),
Compose(
ty: 28,
components: [
2,
4,
6,
],
),
LocalVariable(1),
GlobalVariable(6),
LocalVariable(2),
],
named_expressions: {},
body: [
Emit((
start: 1,
end: 2,
)),
Emit((
start: 3,
end: 5,
)),
Call(
function: 5,
arguments: [
7,
2,
],
result: None,
),
Emit((
start: 3,
end: 4,
)),
Emit((
start: 5,
end: 7,
)),
Call(
function: 6,
arguments: [
6,
8,
],
result: None,
),
Expand Down
Loading
Loading