Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ConstantEvaluator::swizzle: Handle vector concatenation and indexing #2485

Merged
merged 4 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,21 +243,24 @@ impl<'w> BlockContext<'w> {
self.writer.constant_ids[init.index()]
}
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
crate::Expression::Compose {
ty: _,
ref components,
} => {
crate::Expression::Compose { ty, ref components } => {
self.temp_list.clear();
for &component in components {
self.temp_list.push(self.cached[component]);
}

if self.ir_function.expressions.is_const(expr_handle) {
let ty = self
.writer
.get_expression_lookup_type(&self.fun_info[expr_handle].ty);
self.writer.get_constant_composite(ty, &self.temp_list)
self.temp_list.extend(
crate::proc::flatten_compose(
ty,
components,
&self.ir_function.expressions,
&self.ir_module.types,
)
.map(|component| self.cached[component]),
);
self.writer
.get_constant_composite(LookupType::Handle(ty), &self.temp_list)
} else {
self.temp_list
.extend(components.iter().map(|&component| self.cached[component]));

let id = self.gen_id();
block.body.push(Instruction::composite_construct(
result_type_id,
Expand Down
12 changes: 8 additions & 4 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1269,10 +1269,14 @@ impl Writer {
self.get_constant_null(type_id)
}
crate::Expression::Compose { ty, ref components } => {
let component_ids: Vec<_> = components
.iter()
.map(|component| self.constant_ids[component.index()])
.collect();
let component_ids: Vec<_> = crate::proc::flatten_compose(
ty,
components,
&ir_module.const_expressions,
&ir_module.types,
)
.map(|component| self.constant_ids[component.index()])
.collect();
self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
}
crate::Expression::Splat { size, value } => {
Expand Down
36 changes: 24 additions & 12 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub enum ConstantEvaluatorError {
SplatScalarOnly,
#[error("Can only swizzle vector constants")]
SwizzleVectorOnly,
#[error("swizzle component not present in source expression")]
SwizzleOutOfBounds,
#[error("Type is not constructible")]
TypeNotConstructible,
#[error("Subexpression(s) are not constant")]
Expand Down Expand Up @@ -305,20 +307,31 @@ impl ConstantEvaluator<'_> {
let expr = Expression::Splat { size, value };
Ok(self.register_constant(expr, span))
}
Expression::Compose {
ty,
components: ref src_components,
} => {
Expression::Compose { ty, ref components } => {
let dst_ty = get_dst_ty(ty)?;

let components = pattern
let mut flattened = [src_constant; 4]; // dummy value
let len =
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
.zip(flattened.iter_mut())
.map(|(component, elt)| *elt = component)
.count();
let flattened = &flattened[..len];

let swizzled_components = pattern[..size as usize]
.iter()
.take(size as usize)
.map(|&sc| src_components[sc as usize])
.collect();
.map(|&sc| {
let sc = sc as usize;
if let Some(elt) = flattened.get(sc) {
Ok(*elt)
} else {
Err(ConstantEvaluatorError::SwizzleOutOfBounds)
}
})
.collect::<Result<Vec<Handle<Expression>>, _>>()?;
let expr = Expression::Compose {
ty: dst_ty,
components,
components: swizzled_components,
};
Ok(self.register_constant(expr, span))
}
Expand Down Expand Up @@ -454,9 +467,8 @@ impl ConstantEvaluator<'_> {
.components()
.ok_or(ConstantEvaluatorError::InvalidAccessBase)?;

components
.get(index)
.copied()
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
.nth(index)
.ok_or(ConstantEvaluatorError::InvalidAccessIndex)
}
_ => Err(ConstantEvaluatorError::InvalidAccessBase),
Expand Down
55 changes: 55 additions & 0 deletions src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,61 @@ impl GlobalCtx<'_> {
}
}

/// Return an iterator over the individual components assembled by a
/// `Compose` expression.
///
/// Given `ty` and `components` from an `Expression::Compose`, return an
/// iterator over the components of the resulting value.
///
/// Normally, this would just be an iterator over `components`. However,
/// `Compose` expressions can concatenate vectors, in which case the i'th
/// value being composed is not generally the i'th element of `components`.
/// This function consults `ty` to decide if this concatenation is occuring,
/// and returns an iterator that produces the components of the result of
/// the `Compose` expression in either case.
pub fn flatten_compose<'arenas>(
ty: crate::Handle<crate::Type>,
components: &'arenas [crate::Handle<crate::Expression>],
expressions: &'arenas crate::Arena<crate::Expression>,
types: &'arenas crate::UniqueArena<crate::Type>,
) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
// Returning `impl Iterator` is a bit tricky. We may or may not want to
// flatten the components, but we have to settle on a single concrete
// type to return. The below is a single iterator chain that handles
// both the flattening and non-flattening cases.
let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
(size as usize, true)
} else {
(components.len(), false)
};

fn flattener<'c>(
component: &'c crate::Handle<crate::Expression>,
is_vector: bool,
expressions: &'c crate::Arena<crate::Expression>,
) -> &'c [crate::Handle<crate::Expression>] {
if is_vector {
if let crate::Expression::Compose {
ty: _,
components: ref subcomponents,
} = expressions[*component]
{
return subcomponents;
}
}
std::slice::from_ref(component)
}

// Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to flatten
// two levels.
components
.iter()
.flat_map(move |component| flattener(component, is_vector, expressions))
.flat_map(move |component| flattener(component, is_vector, expressions))
.take(size)
.cloned()
}

#[test]
fn test_matrix_size() {
let module = crate::Module::default();
Expand Down
14 changes: 14 additions & 0 deletions tests/in/const-exprs.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@group(0) @binding(0) var<storage, read_write> out: vec4<i32>;
@group(0) @binding(1) var<storage, read_write> out2: i32;
@group(0) @binding(2) var<storage, read_write> out3: i32;

@compute @workgroup_size(1)
fn main() {
let a = vec2(1, 2);
let b = vec2(3, 4);
out = vec4(a, b).wzyx;

out2 = vec4(a, b)[1];

out3 = vec4(vec3(vec2(6, 7), 8), 9)[0];
}
23 changes: 23 additions & 0 deletions tests/out/glsl/const-exprs.main.Compute.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#version 310 es

precision highp float;
precision highp int;

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; };

layout(std430) buffer type_1_block_1Compute { int _group_0_binding_1_cs; };

layout(std430) buffer type_1_block_2Compute { int _group_0_binding_2_cs; };


void main() {
ivec2 a = ivec2(1, 2);
ivec2 b = ivec2(3, 4);
_group_0_binding_0_cs = ivec4(4, 3, 2, 1);
_group_0_binding_1_cs = 2;
_group_0_binding_2_cs = 6;
return;
}

14 changes: 14 additions & 0 deletions tests/out/hlsl/const-exprs.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
RWByteAddressBuffer out_ : register(u0);
RWByteAddressBuffer out2_ : register(u1);
RWByteAddressBuffer out3_ : register(u2);

[numthreads(1, 1, 1)]
void main()
{
int2 a = int2(1, 2);
int2 b = int2(3, 4);
out_.Store4(0, asuint(int4(4, 3, 2, 1)));
out2_.Store(0, asuint(2));
out3_.Store(0, asuint(6));
return;
}
12 changes: 12 additions & 0 deletions tests/out/hlsl/const-exprs.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_5_1",
),
],
)
19 changes: 19 additions & 0 deletions tests/out/msl/const-exprs.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;


kernel void main_(
device metal::int4& out [[user(fake0)]]
, device int& out2_ [[user(fake0)]]
, device int& out3_ [[user(fake0)]]
) {
metal::int2 a = metal::int2(1, 2);
metal::int2 b = metal::int2(3, 4);
out = metal::int4(4, 3, 2, 1);
out2_ = 2;
out3_ = 6;
return;
}
67 changes: 67 additions & 0 deletions tests/out/spv/const-exprs.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 41
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %20 "main"
OpExecutionMode %20 LocalSize 1 1 1
OpDecorate %10 DescriptorSet 0
OpDecorate %10 Binding 0
OpDecorate %11 Block
OpMemberDecorate %11 0 Offset 0
OpDecorate %13 DescriptorSet 0
OpDecorate %13 Binding 1
OpDecorate %14 Block
OpMemberDecorate %14 0 Offset 0
OpDecorate %16 DescriptorSet 0
OpDecorate %16 Binding 2
OpDecorate %17 Block
OpMemberDecorate %17 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpTypeVector %4 4
%5 = OpTypeVector %4 2
%6 = OpTypeVector %4 3
%7 = OpConstant %4 0
%8 = OpConstant %4 1
%9 = OpConstant %4 2
%11 = OpTypeStruct %3
%12 = OpTypePointer StorageBuffer %11
%10 = OpVariable %12 StorageBuffer
%14 = OpTypeStruct %4
%15 = OpTypePointer StorageBuffer %14
%13 = OpVariable %15 StorageBuffer
%17 = OpTypeStruct %4
%18 = OpTypePointer StorageBuffer %17
%16 = OpVariable %18 StorageBuffer
%21 = OpTypeFunction %2
%22 = OpTypePointer StorageBuffer %3
%24 = OpTypeInt 32 0
%23 = OpConstant %24 0
%26 = OpTypePointer StorageBuffer %4
%29 = OpConstantComposite %5 %8 %9
%30 = OpConstant %4 3
%31 = OpConstant %4 4
%32 = OpConstantComposite %5 %30 %31
%33 = OpConstantComposite %3 %31 %30 %9 %8
%34 = OpConstant %4 6
%35 = OpConstant %4 7
%36 = OpConstantComposite %5 %34 %35
%37 = OpConstant %4 8
%38 = OpConstantComposite %6 %34 %35 %37
%39 = OpConstant %4 9
%20 = OpFunction %2 None %21
%19 = OpLabel
%25 = OpAccessChain %22 %10 %23
%27 = OpAccessChain %26 %13 %23
%28 = OpAccessChain %26 %16 %23
OpBranch %40
%40 = OpLabel
OpStore %25 %33
OpStore %27 %9
OpStore %28 %34
OpReturn
OpFunctionEnd
16 changes: 16 additions & 0 deletions tests/out/wgsl/const-exprs.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@group(0) @binding(0)
var<storage, read_write> out: vec4<i32>;
@group(0) @binding(1)
var<storage, read_write> out2_: i32;
@group(0) @binding(2)
var<storage, read_write> out3_: i32;

@compute @workgroup_size(1, 1, 1)
fn main() {
let a = vec2<i32>(1, 2);
let b = vec2<i32>(3, 4);
out = vec4<i32>(4, 3, 2, 1);
out2_ = 2;
out3_ = 6;
return;
}
4 changes: 4 additions & 0 deletions tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,10 @@ fn convert_wgsl() {
"constructors",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"const-exprs",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
];

for &(name, targets) in inputs.iter() {
Expand Down
Loading