diff --git a/CHANGELOG.md b/CHANGELOG.md
index c58f5894d7..5e6ed957ff 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -85,6 +85,10 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
 
 - Allow using [VK_GOOGLE_display_timing](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html) unsafely with the `VULKAN_GOOGLE_DISPLAY_TIMING` feature. By @DJMcNab in [#6149](https://github.com/gfx-rs/wgpu/pull/6149)
 
+#### Metal
+
+- Implement `atomicCompareExchangeWeak`. By @AsherJingkongChen in [#6265](https://github.com/gfx-rs/wgpu/pull/6265)
+
 ### Bug Fixes
 
 - Fix incorrect hlsl image output type conversion. By @atlv24 in [#6123](https://github.com/gfx-rs/wgpu/pull/6123)
diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs
index 96dd142a50..4da31060cb 100644
--- a/naga/src/back/msl/mod.rs
+++ b/naga/src/back/msl/mod.rs
@@ -136,6 +136,8 @@ pub enum Error {
     UnsupportedAttribute(String),
     #[error("function '{0}' is not supported for target MSL version")]
     UnsupportedFunction(String),
+    #[error("scalar {0:?} is not supported for target MSL version")]
+    UnsupportedScalar(crate::Scalar),
     #[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
     UnsupportedWriteableStorageBuffer,
     #[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs
index e0b3d31e84..ea95d298f5 100644
--- a/naga/src/back/msl/writer.rs
+++ b/naga/src/back/msl/writer.rs
@@ -33,6 +33,7 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
 const RAY_QUERY_FIELD_READY: &str = "ready";
 const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
 
+pub(crate) const ATOMIC_COMP_EXCH_FUNCTION_KEY: &str = "naga_atomic_compare_exchange_weak";
 pub(crate) const MODF_FUNCTION: &str = "naga_modf";
 pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
@@ -1151,42 +1152,6 @@ impl<W: Write> Writer<W> {
         Ok(())
     }
 
-    fn put_atomic_operation(
-        &mut self,
-        pointer: Handle<crate::Expression>,
-        key: &str,
-        value: Handle<crate::Expression>,
-        context: &ExpressionContext,
-    ) -> BackendResult {
-        // If the pointer we're passing to the atomic operation needs to be conditional
-        // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
-        // the pointer operand should be unchecked.
-        let policy = context.choose_bounds_check_policy(pointer);
-        let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
-            && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
-
-        // If requested and successfully put bounds checks, continue the ternary expression.
-        if checked {
-            write!(self.out, " ? ")?;
-        }
-
-        write!(
-            self.out,
-            "{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
-        )?;
-        self.put_access_chain(pointer, policy, context)?;
-        write!(self.out, ", ")?;
-        self.put_expression(value, context, true)?;
-        write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
-
-        // Finish the ternary expression.
-        if checked {
-            write!(self.out, " : DefaultConstructible()")?;
-        }
-
-        Ok(())
-    }
-
     /// Emit code for the arithmetic expression of the dot product.
     ///
     fn put_dot_product(
@@ -3045,24 +3010,61 @@ impl<W: Write> Writer<W> {
                     value,
                     result,
                 } => {
+                    let context = &context.expression;
+
                     // This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
                     // `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
                     // `Some`, we are not operating on a 64-bit value, and that if we are
                     // operating on a 64-bit value, `result` is `None`.
                     write!(self.out, "{level}")?;
-                    let fun_str = if let Some(result) = result {
+                    let fun_key = if let Some(result) = result {
                         let res_name = Baked(result).to_string();
-                        self.start_baking_expression(result, &context.expression, &res_name)?;
+                        self.start_baking_expression(result, context, &res_name)?;
                         self.named_expressions.insert(result, res_name);
-                        fun.to_msl()?
-                    } else if context.expression.resolve_type(value).scalar_width() == Some(8) {
+                        fun.to_msl()
+                    } else if context.resolve_type(value).scalar_width() == Some(8) {
                         fun.to_msl_64_bit()?
                     } else {
-                        fun.to_msl()?
+                        fun.to_msl()
                     };
-                    self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
-                    // done
+
+                    // If the pointer we're passing to the atomic operation needs to be conditional
+                    // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
+                    // the pointer operand should be unchecked.
+                    let policy = context.choose_bounds_check_policy(pointer);
+                    let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+                        && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
+
+                    // If requested and successfully put bounds checks, continue the ternary expression.
+                    if checked {
+                        write!(self.out, " ? ")?;
+                    }
+
+                    write!(
+                        self.out,
+                        "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
+                    )?;
+                    self.put_access_chain(pointer, policy, context)?;
+
+                    // Compare-exchange takes an extra comparison operand, and its emitted
+                    // helper function supplies the memory order itself; every other atomic
+                    // function takes a single value plus a relaxed memory order.
+                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+                        write!(self.out, ", ")?;
+                        self.put_expression(cmp, context, true)?;
+                        write!(self.out, ", ")?;
+                        self.put_expression(value, context, true)?;
+                        write!(self.out, ")")?;
+                    } else {
+                        write!(self.out, ", ")?;
+                        self.put_expression(value, context, true)?;
+                        write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
+                    }
+
+                    // Finish the ternary expression.
+                    if checked {
+                        write!(self.out, " : DefaultConstructible()")?;
+                    }
+
+                    // Done
                     writeln!(self.out, ";")?;
                 }
                 crate::Statement::WorkGroupUniformLoad { pointer, result } => {
@@ -3690,7 +3692,47 @@ impl<W: Write> Writer<W> {
                         struct_name, struct_name
                     )?;
                 }
-                &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
+                &crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
+                    let crate::Scalar { kind, width } = scalar;
+                    let arg_type_name = match width {
+                        1 => "bool",
+                        4 => match kind {
+                            crate::ScalarKind::Sint => "int",
+                            crate::ScalarKind::Uint => "uint",
+                            crate::ScalarKind::Float => "float",
+                            _ => return Err(Error::UnsupportedScalar(scalar)),
+                        },
+                        _ => return Err(Error::UnsupportedScalar(scalar)),
+                    };
+
+                    let called_func_name = "atomic_compare_exchange_weak_explicit";
+                    let defined_func_key = ATOMIC_COMP_EXCH_FUNCTION_KEY;
+                    let struct_name = &self.names[&NameKey::Type(*struct_ty)];
+
+                    writeln!(self.out)?;
+                    writeln!(self.out, "namespace {NAMESPACE} {{")?;
+
+                    for address_space_name in ["device", "threadgroup"] {
+                        writeln!(
+                            self.out,
+                            "    template <typename A>
+    {struct_name} atomic_{defined_func_key}_explicit(
+        volatile {address_space_name} A *atomic_ptr,
+        {arg_type_name} cmp,
+        {arg_type_name} v
+    ) {{
+        bool swapped = {NAMESPACE}::{called_func_name}(
+            atomic_ptr, &cmp, v,
+            metal::memory_order_relaxed, metal::memory_order_relaxed
+        );
+        return {struct_name}{{cmp, swapped}};
+    }}"
+                        )?;
+                    }
+
+                    writeln!(self.out, "}}")?;
+                }
             }
         }
@@ -5928,8 +5970,8 @@ fn test_stack_size() {
 }
 
 impl crate::AtomicFunction {
-    fn to_msl(self) -> Result<&'static str, Error> {
-        Ok(match self {
+    const fn to_msl(self) -> &'static str {
+        match self {
             Self::Add => "fetch_add",
             Self::Subtract => "fetch_sub",
             Self::And => "fetch_and",
@@ -5938,10 +5980,8 @@ impl crate::AtomicFunction {
             Self::Min => "fetch_min",
             Self::Max => "fetch_max",
             Self::Exchange { compare: None } => "exchange",
-            Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
-                "atomic CompareExchange".to_string(),
-            ))?,
-        })
+            Self::Exchange { compare: Some(_) } => ATOMIC_COMP_EXCH_FUNCTION_KEY,
+        }
     }
 
     fn to_msl_64_bit(self) -> Result<&'static str, Error> {
diff --git a/naga/tests/out/msl/atomicCompareExchange.msl b/naga/tests/out/msl/atomicCompareExchange.msl
new file mode 100644
index 0000000000..2548a06936
--- /dev/null
+++ b/naga/tests/out/msl/atomicCompareExchange.msl
@@ -0,0 +1,164 @@
+// language: metal1.0
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using metal::uint;
+
+struct type_2 {
+    metal::atomic_int inner[128];
+};
+struct type_4 {
+    metal::atomic_uint inner[128];
+};
+struct _atomic_compare_exchange_resultSint4_ {
+    int old_value;
+    bool exchanged;
+};
+struct _atomic_compare_exchange_resultUint4_ {
+    uint old_value;
+    bool exchanged;
+};
+
+namespace metal {
+    template <typename A>
+    _atomic_compare_exchange_resultSint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
+        volatile device A *atomic_ptr,
+        int cmp,
+        int v
+    ) {
+        bool swapped = metal::atomic_compare_exchange_weak_explicit(
+            atomic_ptr, &cmp, v,
+            metal::memory_order_relaxed, metal::memory_order_relaxed
+        );
+        return _atomic_compare_exchange_resultSint4_{cmp, swapped};
+    }
+    template <typename A>
+    _atomic_compare_exchange_resultSint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
+        volatile threadgroup A *atomic_ptr,
+        int cmp,
+        int v
+    ) {
+        bool swapped = metal::atomic_compare_exchange_weak_explicit(
+            atomic_ptr, &cmp, v,
+            metal::memory_order_relaxed, metal::memory_order_relaxed
+        );
+        return _atomic_compare_exchange_resultSint4_{cmp, swapped};
+    }
+}
+
+namespace metal {
+    template <typename A>
+    _atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
+        volatile device A *atomic_ptr,
+        uint cmp,
+        uint v
+    ) {
+        bool swapped = metal::atomic_compare_exchange_weak_explicit(
+            atomic_ptr, &cmp, v,
+            metal::memory_order_relaxed, metal::memory_order_relaxed
+        );
+        return _atomic_compare_exchange_resultUint4_{cmp, swapped};
+    }
+    template <typename A>
+    _atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
+        volatile threadgroup A *atomic_ptr,
+        uint cmp,
+        uint v
+    ) {
+        bool swapped = metal::atomic_compare_exchange_weak_explicit(
+            atomic_ptr, &cmp, v,
+            metal::memory_order_relaxed, metal::memory_order_relaxed
+        );
+        return _atomic_compare_exchange_resultUint4_{cmp, swapped};
+    }
+}
+constant uint SIZE = 128u;
+
+kernel void test_atomic_compare_exchange_i32_(
+  device type_2& arr_i32_ [[user(fake0)]]
+) {
+    uint i = 0u;
+    int old = {};
+    bool exchanged = {};
+    bool loop_init = true;
+    while(true) {
+        if (!loop_init) {
+            uint _e27 = i;
+            i = _e27 + 1u;
+        }
+        loop_init = false;
+        uint _e2 = i;
+        if (_e2 < SIZE) {
+        } else {
+            break;
+        }
+        {
+            uint _e6 = i;
+            int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
+            old = _e8;
+            exchanged = false;
+            while(true) {
+                bool _e12 = exchanged;
+                if (!(_e12)) {
+                } else {
+                    break;
+                }
+                {
+                    int _e14 = old;
+                    int new_ = as_type<int>(as_type<float>(_e14) + 1.0);
+                    uint _e20 = i;
+                    int _e22 = old;
+                    _atomic_compare_exchange_resultSint4_ _e23 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&arr_i32_.inner[_e20], _e22, new_);
+                    old = _e23.old_value;
+                    exchanged = _e23.exchanged;
+                }
+            }
+        }
+    }
+    return;
+}
+
+
+kernel void test_atomic_compare_exchange_u32_(
+  device type_4& arr_u32_ [[user(fake0)]]
+) {
+    uint i_1 = 0u;
+    uint old_1 = {};
+    bool exchanged_1 = {};
+    bool loop_init_1 = true;
+    while(true) {
+        if (!loop_init_1) {
+            uint _e27 = i_1;
+            i_1 = _e27 + 1u;
+        }
+        loop_init_1 = false;
+        uint _e2 = i_1;
+        if (_e2 < SIZE) {
+        } else {
+            break;
+        }
+        {
+            uint _e6 = i_1;
+            uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
+            old_1 = _e8;
+            exchanged_1 = false;
+            while(true) {
+                bool _e12 = exchanged_1;
+                if (!(_e12)) {
+                } else {
+                    break;
+                }
+                {
+                    uint _e14 = old_1;
+                    uint new_1 = as_type<uint>(as_type<float>(_e14) + 1.0);
+                    uint _e20 = i_1;
+                    uint _e22 = old_1;
+                    _atomic_compare_exchange_resultUint4_ _e23 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&arr_u32_.inner[_e20], _e22, new_1);
+                    old_1 = _e23.old_value;
+                    exchanged_1 = _e23.exchanged;
+                }
+            }
+        }
+    }
+    return;
+}
diff --git a/naga/tests/out/msl/overrides-atomicCompareExchangeWeak.msl b/naga/tests/out/msl/overrides-atomicCompareExchangeWeak.msl
new file mode 100644
index 0000000000..7e7a15fd96
--- /dev/null
+++ b/naga/tests/out/msl/overrides-atomicCompareExchangeWeak.msl
@@ -0,0 +1,50 @@
+// language: metal1.0
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using metal::uint;
+
+struct _atomic_compare_exchange_resultUint4_ {
+    uint old_value;
+    bool exchanged;
+};
+
+namespace metal {
+    template <typename A>
+    _atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
+        volatile device A *atomic_ptr,
+        uint cmp,
+        uint v
+    ) {
+        bool swapped = metal::atomic_compare_exchange_weak_explicit(
+            atomic_ptr, &cmp, v,
+            metal::memory_order_relaxed, metal::memory_order_relaxed
+        );
+        return _atomic_compare_exchange_resultUint4_{cmp, swapped};
+    }
+    template <typename A>
+    _atomic_compare_exchange_resultUint4_ atomic_naga_atomic_compare_exchange_weak_explicit(
+        volatile threadgroup A *atomic_ptr,
+        uint cmp,
+        uint v
+    ) {
+        bool swapped = metal::atomic_compare_exchange_weak_explicit(
+            atomic_ptr, &cmp, v,
+            metal::memory_order_relaxed, metal::memory_order_relaxed
+        );
+        return _atomic_compare_exchange_resultUint4_{cmp, swapped};
+    }
+}
+constant int o = 2;
+
+kernel void f(
+  metal::uint3 __local_invocation_id [[thread_position_in_threadgroup]]
+, threadgroup metal::atomic_uint& a
+) {
+    if (metal::all(__local_invocation_id == metal::uint3(0u))) {
+        metal::atomic_store_explicit(&a, 0, metal::memory_order_relaxed);
+    }
+    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+    _atomic_compare_exchange_resultUint4_ _e5 = metal::atomic_naga_atomic_compare_exchange_weak_explicit(&a, 2u, 1u);
+    return;
+}
diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs
index 0e285e7b07..fc17b9db20 100644
--- a/naga/tests/snapshots.rs
+++ b/naga/tests/snapshots.rs
@@ -769,7 +769,10 @@ fn convert_wgsl() {
             "atomicOps",
             Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
         ),
-        ("atomicCompareExchange", Targets::SPIRV | Targets::WGSL),
+        (
+            "atomicCompareExchange",
+            Targets::SPIRV | Targets::METAL | Targets::WGSL,
+        ),
         (
             "padding",
             Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
@@ -913,7 +916,7 @@ fn convert_wgsl() {
         ),
         (
             "overrides-atomicCompareExchangeWeak",
-            Targets::IR | Targets::SPIRV,
+            Targets::IR | Targets::SPIRV | Targets::METAL,
         ),
         (
             "overrides-ray-query",
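
For context, the `atomicCompareExchange` snapshot above is the MSL naga emits for the existing WGSL test input. Below is a rough sketch of that input, reconstructed from the emitted u32 kernel rather than copied from `naga/tests/in/atomicCompareExchange.wgsl` (the exact loop form and declarations are inferred, not verbatim):

```wgsl
const SIZE: u32 = 128u;

@group(0) @binding(0)
var<storage, read_write> arr_u32_: array<atomic<u32>, 128>;

@compute @workgroup_size(1)
fn test_atomic_compare_exchange_u32_() {
    for (var i = 0u; i < SIZE; i++) {
        // Reinterpret the current value as f32, add 1.0, and retry the
        // weak compare-exchange until it succeeds.
        var old = atomicLoad(&arr_u32_[i]);
        var exchanged = false;
        while !exchanged {
            let new_ = bitcast<u32>(bitcast<f32>(old) + 1.0);
            let result = atomicCompareExchangeWeak(&arr_u32_[i], old, new_);
            old = result.old_value;
            exchanged = result.exchanged;
        }
    }
}
```

Each `atomicCompareExchangeWeak` call returns the predeclared result struct (`old_value`, `exchanged`) that the new `PredeclaredType::AtomicCompareExchangeWeakResult` arm declares for MSL, and the call itself lowers to the generated `metal::atomic_naga_atomic_compare_exchange_weak_explicit` helper shown in the snapshots.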