diff --git a/Cargo.toml b/Cargo.toml index 7e6900e..5ecbbb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ block = "0.1.6" foreign-types = "0.5" dispatch = { version = "0.2", optional = true } paste = "1" +half = "2.3.1" [dependencies.objc] version = "0.2.4" @@ -76,7 +77,13 @@ name = "compute" path = "examples/compute/main.rs" [[example]] -name = "mps" +name = "mps-matrix-multiplication" +path = "examples/mps/matrix-multiplication/main.rs" +required-features = ["mps"] + +[[example]] +name = "mps-ray-intersection" +path = "examples/mps/ray-intersection/main.rs" required-features = ["mps"] [[example]] diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs new file mode 100644 index 0000000..07e839a --- /dev/null +++ b/examples/mps/matrix-multiplication/main.rs @@ -0,0 +1,208 @@ +use std::io; +use std::io::Write; +use std::ops::{AddAssign, Mul}; + +use rand::{thread_rng, Rng}; + +use metal::mps::*; +use metal::*; + +fn main() { + correctness(); + performance(); +} + +fn correctness() { + // First verify the correctness of the naive solution + let a = Matrix::new([1, 2, 6, 24, 120, 720], 3, 2); + let b = Matrix::new([1, 2, 3, 5, 8, 13], 2, 3); + let result = matrix_mul::(a, b); + assert_eq!( + result.entries(), + &[11, 18, 29, 126, 204, 330, 3720, 6000, 9720] + ); + + const M: u64 = 100; + const N: u64 = 100; + const K: u64 = 100; + const ITERATIONS: usize = 50; + + let device = Device::system_default().expect("No device found"); + let command_queue = device.new_command_queue(); + + println!("Correctness: "); + for i in 0..ITERATIONS { + progress_bar(i, ITERATIONS); + + let left = generate_matrix::(); + let right = generate_matrix::(); + + let command_buffer = command_queue.new_command_buffer(); + let result = encode_gemm( + &device, + command_buffer, + false, + false, + &left, + &right, + 1.0, + 0.0, + ); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let expected = matrix_mul(left, right); + approx_eq(result.contents(), expected.entries().to_vec()); + } + + println!(" ✅\n"); +} + +fn performance() { + const M: u64 = 4096; + const N: u64 = 4096; + const K: u64 = 4096; + + const ITERATIONS: usize = 50; + + println!("Performance: "); + println!("Generating input matrices: (f32 {M}x{K} and f16 {K}x{N})"); + // Generate random matrices + let left = generate_matrix::(); + let right = generate_matrix::(); + + // Setup + let device = Device::system_default().expect("No device found"); + let command_queue = device.new_command_queue(); + + let cases = [ + (false, false, 1.0, 0.0), + (true, false, 1.0, 0.0), + (false, true, 1.0, 0.0), + (false, false, 0.5, 0.0), + (false, false, 1.0, 0.5), + ]; + for (t_left, t_right, alpha, beta) in cases { + println!("Running with transpose left: {t_left}, transpose right: {t_right}, alpha: {alpha}, beta: {beta}"); + let mut flops: Vec = vec![]; + + let mut total_time = std::time::Duration::new(0, 0); + for i in 0..ITERATIONS { + progress_bar(i, ITERATIONS); + + let start = std::time::Instant::now(); + let command_buffer = command_queue.new_command_buffer(); + let _ = encode_gemm( + &device, + command_buffer, + t_left, + t_right, + &left, + &right, + alpha, + beta, + ); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let time = std::time::Instant::now() - start; + + total_time += time; + + // Calculate GFLOPS + // C <- alpha * AB + beta * C + // Operations = 2(M * N * K) + flops.push((M * N * (2 * K + 2)) as f64 / (time.as_secs_f64() * 1e+9f64)); + } + 
println!(" ✅"); + + let avg_gflops = flops.iter().sum::() / flops.len() as f64; + println!("Avg GFLOPS: {}", avg_gflops); + println!("Total time: {:#?}", total_time); + println!("Avg time: {:#?}", total_time / ITERATIONS as u32); + println!() + } +} + +fn generate_matrix() -> Matrix +where + T: MPSDataType, + GEMMInput: Valid, +{ + let mut rng = thread_rng(); + Matrix::new( + (0..ROWS * COLS).map(|_| T::from_f64(rng.gen())), + ROWS as NSUInteger, + COLS as NSUInteger, + ) +} + +// Naive matrix multiplication for testing +fn matrix_mul(a: Matrix, b: Matrix) -> Matrix +where + T::Type: AddAssign + Mul + Copy, +{ + assert_eq!(a.columns(), b.rows()); + let sum_count = a.columns() as usize; + let rows = a.rows() as usize; + let columns = b.columns() as usize; + let size = rows * columns; + + let mut entries = Vec::with_capacity(size); + + for idx in 0..size { + let i = idx / rows; + let j = idx % columns; + + let mut sum = T::from_f64(0.0); + for di in 0..sum_count { + sum += a.entry(i, di) * b.entry(di, j); + } + entries.push(sum); + } + + Matrix::new(entries, a.rows(), b.columns()) +} + +fn euclidean_distance(a: Vec, b: Vec) -> f64 +where + T: Into + Clone + Copy, +{ + assert_eq!(a.len(), b.len(), "Lengths not equal"); + + let mut sum = 0.0; + + for i in 0..a.len() { + sum += (a[i].into() - b[i].into()).powi(2); + } + + sum.sqrt() +} + +fn approx_eq(a: Vec, b: Vec) +where + T: Into + Clone + Copy, +{ + assert_eq!(a.len(), b.len(), "Lengths not equal"); + + let avg_magnitude = 0.004f64; + let avg_deviation = (a.len() as f64).sqrt(); + let tolerance = avg_magnitude.max(avg_deviation * 3e-7); + + let distance = euclidean_distance(a, b); + assert!( + distance < tolerance, + "Distance not less than tolerance: {} < {} ", + distance, + tolerance + ); +} + +fn progress_bar(i: usize, len: usize) { + print!("\r"); + print!("["); + print!("{}", "=".repeat(i)); + print!("{}", " ".repeat(len - i - 1)); + print!("]"); + io::stdout().flush().unwrap(); +} diff --git a/examples/mps/main.rs b/examples/mps/ray-intersection/main.rs similarity index 96% rename from examples/mps/main.rs rename to examples/mps/ray-intersection/main.rs index cc01b7a..ed79411 100644 --- a/examples/mps/main.rs +++ b/examples/mps/ray-intersection/main.rs @@ -14,8 +14,8 @@ type Intersection = mps::MPSIntersectionDistancePrimitiveIndexCoordinates; fn main() { let device = Device::system_default().expect("No device found"); - let library_path = - std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("examples/mps/shaders.metallib"); + let library_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("examples/mps/ray-intersection/shaders.metallib"); let library = device .new_library_with_file(library_path) .expect("Failed to load shader library"); @@ -67,7 +67,7 @@ fn main() { acceleration_structure.set_vertex_buffer(Some(&vertex_buffer)); acceleration_structure.set_vertex_stride(vertex_stride as u64); acceleration_structure.set_index_buffer(Some(&index_buffer)); - acceleration_structure.set_index_type(mps::MPSDataType::UInt32); + acceleration_structure.set_index_type(mps::UInt32); acceleration_structure.set_triangle_count(1); acceleration_structure.set_usage(mps::MPSAccelerationStructureUsage::None); acceleration_structure.rebuild(); diff --git a/examples/mps/shaders.metal b/examples/mps/ray-intersection/shaders.metal similarity index 100% rename from examples/mps/shaders.metal rename to examples/mps/ray-intersection/shaders.metal diff --git a/examples/mps/shaders.metallib 
b/examples/mps/ray-intersection/shaders.metallib similarity index 100% rename from examples/mps/shaders.metallib rename to examples/mps/ray-intersection/shaders.metallib diff --git a/src/buffer.rs b/src/buffer.rs index 8f3108a..6496e97 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -68,4 +68,14 @@ impl BufferRef { pub fn gpu_address(&self) -> u64 { unsafe { msg_send![self, gpuAddress] } } + + pub fn read_to_slice(&self, len: usize) -> &[T] { + let contents_ptr = self.contents() as *const T; + assert!(!contents_ptr.is_null()); + unsafe { std::slice::from_raw_parts(contents_ptr, len) } + } + + pub fn read_to_vec(&self, len: usize) -> Vec { + self.read_to_slice(len).to_vec() + } } diff --git a/src/lib.rs b/src/lib.rs index b79acf6..c4962eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,8 @@ pub extern crate foreign_types; #[macro_use] pub extern crate paste; +pub extern crate half; + use std::{ borrow::{Borrow, ToOwned}, marker::PhantomData, diff --git a/src/mps.rs b/src/mps.rs index edd4936..672bc1f 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -5,11 +5,18 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use super::*; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; +use half::{bf16, f16}; use objc::runtime::{BOOL, YES}; -#[cfg_attr(feature = "link", link(name = "MetalPerformanceShaders", kind = "framework"))] +use super::*; + +#[cfg_attr( + feature = "link", + link(name = "MetalPerformanceShaders", kind = "framework") +)] extern "C" { fn MPSSupportsMTLDevice(device: *const std::ffi::c_void) -> BOOL; } @@ -129,33 +136,6 @@ bitflags! { } } -/// A common bit for all floating point data types. -const MPSDataTypeFloatBit: isize = 0x10000000; -const MPSDataTypeSignedBit: isize = 0x20000000; -const MPSDataTypeNormalizedBit: isize = 0x40000000; - -/// See -pub enum MPSDataType { - Invalid = 0, - - Float32 = MPSDataTypeFloatBit | 32, - Float16 = MPSDataTypeFloatBit | 16, - - // Signed integers. - Int8 = MPSDataTypeSignedBit | 8, - Int16 = MPSDataTypeSignedBit | 16, - Int32 = MPSDataTypeSignedBit | 32, - - // Unsigned integers. Range: [0, UTYPE_MAX] - UInt8 = 8, - UInt16 = 16, - UInt32 = 32, - - // Unsigned normalized. Range: [0, 1.0] - Unorm1 = MPSDataTypeNormalizedBit | 1, - Unorm8 = MPSDataTypeNormalizedBit | 8, -} - /// A kernel that performs intersection tests between rays and geometry. /// /// See @@ -202,7 +182,7 @@ impl RayIntersectorRef { unsafe { msg_send![self, setRayDataType: ty] } } - pub fn set_ray_index_data_type(&self, ty: MPSDataType) { + pub fn set_ray_index_data_type(&self, ty: T) { unsafe { msg_send![self, setRayIndexDataType: ty] } } @@ -345,8 +325,8 @@ impl PolygonAccelerationStructureRef { unsafe { msg_send![self, setIndexBufferOffset: offset] } } - pub fn set_index_type(&self, data_type: MPSDataType) { - unsafe { msg_send![self, setIndexType: data_type] } + pub fn set_index_type(&self, _data_type: T) { + unsafe { msg_send![self, setIndexType: T::TYPE_ID] } } pub fn set_mask_buffer(&self, buffer: Option<&BufferRef>) { @@ -570,3 +550,757 @@ pub struct MPSIntersectionDistancePrimitiveIndexCoordinates { /// if the intersection type is `MPSIntersectionTypeAny`. pub coordinates: [f32; 2], } + +/// A value to specify a type of data. +/// +/// See . +pub trait MPSDataType: Clone + Copy + PartialEq + Eq + Debug + Hash { + type Type: Default + Clone + Copy + PartialEq + Debug + Sized; + const TYPE_ID: NSUInteger; + + /// See . 
+ const SIZE: NSUInteger = ((Self::TYPE_ID & 0xFFFF) >> 3) as NSUInteger; + + fn from_f64(v: f64) -> Self::Type; + + fn to_f64(v: Self::Type) -> f64; +} + +/// A common bit for all floating point data types. Zero for integer types +const MPS_FLOATBIT_ENCODING: NSUInteger = 0x10000000; +/// A common bit for all complex point data types. Zero for integer types +const MPS_COMPLEXBIT_ENCODING: NSUInteger = MPS_FLOATBIT_ENCODING | 0x01000000; +/// A common bit for all signed data types +const MPS_SIGNEDBIT_ENCODING: NSUInteger = 0x20000000; +/// A common bit for all alternate encoding data types +const MPS_ALTERNATE_ENCODING: NSUInteger = 0x80000000; +/// A common bit for all normalized data types. +/// If set, the value of the shall be interpreted as value / UNORM_TYPE_MAX +/// Normalized values have range [0, 1.0] if unsigned and [-1,1] if signed. +/// SNORM_TYPE_MIN is interpreted as SNORM_TYPE_MIN+1 per standard Metal rules. +const MPS_NORMALIZEDBIT_ENCODING: NSUInteger = 0x40000000; + +macro_rules! mps_datatype_impl { + ($dt:ident, $dt_ty:ty, $type_id:expr, $from_f64:expr, $to_f64:expr) => { + impl MPSDataType for $dt { + type Type = $dt_ty; + const TYPE_ID: NSUInteger = $type_id; + + fn from_f64(v: f64) -> Self::Type { + $from_f64(v) + } + + fn to_f64(v: Self::Type) -> f64 { + $to_f64(v) + } + } + }; +} +macro_rules! mps_datatype { + ($dt:ident, $dt_ty:ty, $type_id:expr, $from_f64:expr, $to_f64:expr, $comment:expr) => { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + #[doc=$comment] + pub struct $dt; + + mps_datatype_impl!($dt, $dt_ty, $type_id, $from_f64, $to_f64); + }; + ($dt:ident, $dt_ty:ty, $type_id:expr, $from_f64:expr, $to_f64:expr) => { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + pub struct $dt; + + mps_datatype_impl!($dt, $dt_ty, $type_id, $from_f64, $to_f64); + }; +} +mps_datatype!(Invalid, (), 0, |_: f64| (), |_: ()| 0.0); +mps_datatype!( + Float32, + f32, + MPS_FLOATBIT_ENCODING | 32, + |v: f64| v as f32, + |v: f32| v as f64, + "32-bit floating point (single-precision)." +); +mps_datatype!( + Float16, + f16, + MPS_FLOATBIT_ENCODING | 16, + |v: f64| f16::from_f64(v), + |v: f16| v.to_f64(), + "16-bit floating point (half-precision). (IEEE-754-2008 float16 exchange format)" +); + +fn unpack_f32_tuple(packed: f64) -> (f32, f32) { + let packed_bits = packed.to_bits(); + let f1_bits = (packed_bits >> 32) as u32; + let f2_bits = (packed_bits & 0xFFFFFFFF) as u32; + (f32::from_bits(f1_bits), f32::from_bits(f2_bits)) +} + +fn pack_f32_tuple((f1, f2): (f32, f32)) -> f64 { + let f1_bits = f1.to_bits(); + let f2_bits = f2.to_bits(); + let packed = ((f1_bits as u64) << 32) | (f2_bits as u64); + f64::from_bits(packed) +} + +mps_datatype!( + ComplexFloat32, + (f32, f32), + MPS_COMPLEXBIT_ENCODING | 64, + unpack_f32_tuple, + pack_f32_tuple, + "Complex number composed of two 32-bit floating point numbers (single-precision)." +); + +fn unpack_f16_tuple(packed: f64) -> (f16, f16) { + let packed_bits = packed.to_bits(); + let f1_bits = (packed_bits >> 16) as u16; + let f2_bits = (packed_bits & 0xFFFF) as u16; + (f16::from_bits(f1_bits), f16::from_bits(f2_bits)) +} + +fn pack_f16_tuple((f1, f2): (f16, f16)) -> f64 { + let f1_bits = f1.to_bits(); + let f2_bits = f2.to_bits(); + let packed = ((f1_bits as u64) << 16) | (f2_bits as u64); + f64::from_bits(packed) +} + +mps_datatype!( + ComplexFloat16, + (f16, f16), + MPS_COMPLEXBIT_ENCODING | 32, + unpack_f16_tuple, + pack_f16_tuple, + "Complex number composed of two 16-bit floating point numbers (half-precision). 
(IEEE-754-2008 float16 exchange format)"
+);
+mps_datatype!(
+    Int8,
+    i8,
+    MPS_SIGNEDBIT_ENCODING | 8,
+    |v: f64| v as i8,
+    |v: i8| v as f64,
+    "Signed 8-bit integer."
+);
+mps_datatype!(
+    Int16,
+    i16,
+    MPS_SIGNEDBIT_ENCODING | 16,
+    |v: f64| v as i16,
+    |v: i16| v as f64,
+    "Signed 16-bit integer."
+);
+mps_datatype!(
+    Int32,
+    i32,
+    MPS_SIGNEDBIT_ENCODING | 32,
+    |v: f64| v as i32,
+    |v: i32| v as f64,
+    "Signed 32-bit integer."
+);
+mps_datatype!(
+    Int64,
+    i64,
+    MPS_SIGNEDBIT_ENCODING | 64,
+    |v: f64| v as i64,
+    |v: i64| v as f64,
+    "Signed 64-bit integer."
+);
+mps_datatype!(
+    UInt8,
+    u8,
+    8,
+    |v: f64| v as u8,
+    |v: u8| v as f64,
+    "Unsigned 8-bit integer. Not normalized."
+);
+mps_datatype!(
+    UInt16,
+    u16,
+    16,
+    |v: f64| v as u16,
+    |v: u16| v as f64,
+    "Unsigned 16-bit integer. Not normalized."
+);
+mps_datatype!(
+    UInt32,
+    u32,
+    32,
+    |v: f64| v as u32,
+    |v: u32| v as f64,
+    "Unsigned 32-bit integer. Not normalized."
+);
+mps_datatype!(
+    UInt64,
+    u64,
+    64,
+    |v: f64| v as u64,
+    |v: u64| v as f64,
+    "Unsigned 64-bit integer. Not normalized."
+);
+mps_datatype!(
+    Bool,
+    bool,
+    MPS_ALTERNATE_ENCODING | 8,
+    |v: f64| v != 0.0,
+    |v: bool| if v { 1.0 } else { 0.0 },
+    "Boolean as 8-bit integer. Not normalized."
+);
+mps_datatype!(
+    BF16,
+    bf16,
+    MPS_ALTERNATE_ENCODING | MPS_FLOATBIT_ENCODING | 16,
+    |v: f64| bf16::from_f64(v),
+    |v: bf16| v.to_f64(),
+    "16-bit brain floating point (bfloat16). Alternate 16-bit float encoding."
+);
+mps_datatype!(
+    UNorm1,
+    bool,
+    MPS_NORMALIZEDBIT_ENCODING | 1,
+    |v: f64| v != 0.0,
+    |v: bool| if v { 1.0 } else { 0.0 },
+    "Unsigned 1-bit normalized value."
+);
+mps_datatype!(
+    UNorm8,
+    u8,
+    MPS_NORMALIZEDBIT_ENCODING | 8,
+    |v: f64| v as u8,
+    |v: u8| v as f64,
+    "Unsigned 8-bit normalized value."
+);
+
+/// Helper trait used to indicate that a type constraint is valid.
+pub trait Valid {}
+
+/// Helper struct used to indicate a valid matrix multiplication input type.
+pub struct GEMMInput<T: MPSDataType> {
+    _marker: PhantomData<T>,
+}
+
+/// Input data type must be one of MPSDataTypeFloat32, MPSDataTypeFloat16, MPSDataTypeInt8,
+/// or MPSDataTypeInt16.
+impl Valid for GEMMInput<Float32> {}
+
+impl Valid for GEMMInput<Float16> {}
+
+impl Valid for GEMMInput<Int8> {}
+
+impl Valid for GEMMInput<Int16> {}
+
+/// Helper struct used to indicate a valid matrix multiplication result type.
+pub struct GEMMResult<T: MPSDataType> {
+    _marker: PhantomData<T>,
+}
+
+/// Only MPSDataTypeFloat16 and MPSDataTypeFloat32 are supported for the result matrix.
+impl Valid for GEMMResult<Float16> {}
+
+impl Valid for GEMMResult<Float32> {}
+
+/// Helper struct used to indicate valid matrix multiplication types.
+pub struct GEMMSpecification<A, B, C>
+where
+    A: MPSDataType,
+    B: MPSDataType,
+    C: MPSDataType,
+    GEMMInput<A>: Valid,
+    GEMMInput<B>: Valid,
+    GEMMResult<C>: Valid,
+{
+    _marker: PhantomData<(A, B, C)>,
+}
+
+/// Mixed input matrix multiplication is only for `Float32 x Float16` inputs with a `Float32` result.
+impl Valid for GEMMSpecification<Float32, Float16, Float32> {}
+
+/// All valid input types can produce a MPSDataTypeFloat32 result.
+impl<T> Valid for GEMMSpecification<T, T, Float32>
+where
+    T: MPSDataType,
+    GEMMInput<T>: Valid,
+{
+}
+
+/// These input types can produce a MPSDataTypeFloat16 result.
+impl Valid for GEMMSpecification<Float16, Float16, Float16> {}
+
+impl Valid for GEMMSpecification<Int8, Int8, Float16> {}
+
+impl Valid for GEMMSpecification<Int16, Int16, Float16> {}
+
+/// See <https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdescriptor>
+pub enum MPSMatrixDescriptor {}
+
+foreign_obj_type!
{ + type CType = MPSMatrixDescriptor; + pub struct MatrixDescriptor; + type ParentType = NsObject; +} + +impl MatrixDescriptor { + fn init_single( + rows: NSUInteger, + columns: NSUInteger, + row_bytes: NSUInteger, + data_type: NSUInteger, + ) -> Self { + unsafe { + msg_send![ + class!(MPSMatrixDescriptor), + matrixDescriptorWithRows : rows + columns : columns + rowBytes : row_bytes + dataType : data_type + ] + } + } + + fn init_multiple( + rows: NSUInteger, + columns: NSUInteger, + matrices: NSUInteger, + row_bytes: NSUInteger, + matrix_bytes: NSUInteger, + data_type: u32, + ) -> Self { + unsafe { + msg_send![ + class!(MPSMatrixDescriptor), + matrixDescriptorWithRows : rows + columns : columns + matrices : matrices + rowBytes : row_bytes + matrixBytes : matrix_bytes + dataType : data_type + ] + } + } + + fn row_bytes_for_columns(columns: NSUInteger, data_type: NSUInteger) -> NSUInteger { + unsafe { + msg_send![ + class!(MPSMatrixDescriptor), + rowBytesForColumns : columns + dataType : data_type + ] + } + } +} + +impl From<&Matrix> for MatrixDescriptor { + fn from(matrix: &Matrix) -> Self { + let data_type = T::TYPE_ID; + // The number of bytes between starting elements of consecutive rows. + let row_bytes = MatrixDescriptor::row_bytes_for_columns(matrix.columns, data_type); + Self::init_single(matrix.rows, matrix.columns, row_bytes, data_type) + } +} + +/// See +pub enum MPSMatrix {} + +foreign_obj_type! { + type CType = MPSMatrix; + pub struct MatrixObject; + type ParentType = NsObject; +} + +/// Generic matrix for MPSDataTypes. +#[derive(Debug)] +pub struct Matrix { + entries: Vec, + // row-major order + rows: NSUInteger, + columns: NSUInteger, +} + +impl Matrix { + pub fn new>( + entries: E, + rows: NSUInteger, + columns: NSUInteger, + ) -> Matrix { + let entries: Vec = entries.into_iter().collect(); + assert_eq!(entries.len(), rows as usize * columns as usize); + Self { + entries, + rows, + columns, + } + } + pub fn entries(&self) -> &[T::Type] { + &self.entries + } + + pub fn entry(&self, row: usize, column: usize) -> T::Type { + assert!(row < self.rows as usize); + assert!(column < self.columns as usize); + self.entries[row * self.columns as usize + column] + } + + pub fn rows(&self) -> NSUInteger { + self.rows + } + + pub fn columns(&self) -> NSUInteger { + self.columns + } +} + +impl From> for Matrix { + fn from(buffer: MatrixBuffer) -> Self { + Self::new(buffer.contents(), buffer.rows, buffer.columns) + } +} + +impl Display for Matrix { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + assert_eq!( + self.entries.len(), + self.rows as usize * self.columns as usize + ); + let mut col = 0; + for i in 0..(self.rows * self.columns) as usize { + if col == 0 { + write!(f, "|")?; + } + + write!(f, "{:?}", self.entries.get(i).ok_or(std::fmt::Error)?)?; + + if col < self.columns as usize - 1 { + write!(f, ", ")?; + col += 1; + } else { + writeln!(f, "|")?; + col = 0; + } + } + Ok(()) + } +} + +impl MatrixObject { + fn init_with_device_descriptor( + device: &DeviceRef, + descriptor: &MatrixDescriptorRef, + ) -> Option { + unsafe { + let matrix: MatrixObject = msg_send![class!(MPSMatrix), alloc]; + let ptr: *mut Object = msg_send![ + matrix.as_ref(), + initWithDevice : device + descriptor : descriptor + ]; + if ptr.is_null() { + None + } else { + Some(matrix) + } + } + } + + fn init_with_buffer_descriptor( + buffer: &BufferRef, + descriptor: &MatrixDescriptorRef, + ) -> Option { + unsafe { + let matrix: MatrixObject = msg_send![class!(MPSMatrix), alloc]; + let ptr: *mut Object = 
msg_send![ + matrix.as_ref(), + initWithBuffer : buffer + descriptor: descriptor + ]; + if ptr.is_null() { + None + } else { + Some(matrix) + } + } + } +} + +impl MatrixObjectRef { + pub fn device(&self) -> &DeviceRef { + unsafe { msg_send![self, device] } + } + + pub fn rows(&self) -> NSUInteger { + unsafe { msg_send![self, rows] } + } + + pub fn columns(&self) -> NSUInteger { + unsafe { msg_send![self, columns] } + } + + pub fn row_bytes(&self) -> NSUInteger { + unsafe { msg_send![self, rowBytes] } + } + + pub fn data_type(&self) -> u32 { + unsafe { msg_send![self, dataType] } + } + + pub fn data(&self) -> *mut std::ffi::c_void { + unsafe { msg_send![self, data] } + } + + pub fn resource_size(&self) -> NSUInteger { + unsafe { msg_send![self, resourceSize] } + } +} + +/// A kernel for matrix multiplication. +/// +/// Computes the following operation: +/// +/// `C = alpha * op(A) * op(B) + beta * C` +/// +/// Where A, B, and C are matrices represented by MPSMatrix objects, and alpha and beta are scalar values of the same data type as the values of C. A and B may each have an optional transposition operation applied. +/// +/// Matrices A, B, and C are also referred to as the left input matrix, the right input matrix, and the result matrix respectively. +/// +/// See . +pub enum MPSMatrixMultiplication {} + +foreign_obj_type! { + type CType = MPSMatrixMultiplication; + pub struct MatrixMultiplication; + type ParentType = Kernel; +} +impl MatrixMultiplication { + pub fn from_device(device: &DeviceRef) -> Option { + unsafe { + let kernel: MatrixMultiplication = msg_send![class!(MPSMatrixMultiplication), alloc]; + let ptr: *mut Object = msg_send![kernel.as_ref(), initWithDevice: device]; + if ptr.is_null() { + None + } else { + Some(kernel) + } + } + } + + pub fn init( + device: &DeviceRef, + transpose_left: bool, + transpose_right: bool, + result_rows: NSUInteger, + result_columns: NSUInteger, + interior_columns: NSUInteger, + alpha: f64, + beta: f64, + ) -> Option { + assert!(result_rows > 0); + assert!(result_columns > 0); + assert!(interior_columns > 0); + + unsafe { + let kernel: MatrixMultiplication = msg_send![class!(MPSMatrixMultiplication), alloc]; + let ptr: *mut Object = msg_send![ + kernel.as_ref(), + initWithDevice : device + transposeLeft : transpose_left + transposeRight : transpose_right + resultRows : result_rows + resultColumns : result_columns + interiorColumns : interior_columns + alpha : alpha + beta : beta + ]; + if ptr.is_null() { + None + } else { + Some(kernel) + } + } + } + + fn init_simple( + device: &DeviceRef, + result_rows: NSUInteger, + result_columns: NSUInteger, + interior_columns: NSUInteger, + ) -> Option { + unsafe { + let kernel: MatrixMultiplication = msg_send![class!(MPSMatrixMultiplication), alloc]; + let ptr: *mut Object = msg_send![ + kernel.as_ref(), + initWithDevice : device + resultRows : result_rows + resultColumns : result_columns + interiorColumns : interior_columns + ]; + if ptr.is_null() { + None + } else { + Some(kernel) + } + } + } +} + +impl MatrixMultiplicationRef { + /// Encode the kernel to the given command buffer. + /// * `command_buffer` - The command buffer to encode the kernel to. + /// * `left_matrix` - The left matrix to multiply. + /// * `right_matrix` - The right matrix to multiply. + /// * `result_matrix` - The matrix to store the result in. 
+ pub fn encode_to_command_buffer( + &self, + command_buffer: &CommandBufferRef, + left_matrix: &MatrixObjectRef, + right_matrix: &MatrixObjectRef, + result_matrix: &MatrixObjectRef, + ) { + unsafe { + let _: () = msg_send!( + *self, + encodeToCommandBuffer : command_buffer + leftMatrix : left_matrix + rightMatrix : right_matrix + resultMatrix : result_matrix + ); + } + } +} + +pub struct MatrixBuffer { + buffer: Buffer, + rows: NSUInteger, + columns: NSUInteger, + count: usize, + allocated_size: usize, + _marker: PhantomData, +} + +impl MatrixBuffer { + pub fn new( + device: &DeviceRef, + rows: NSUInteger, + columns: NSUInteger, + length: NSUInteger, + options: MTLResourceOptions, + ) -> Self { + let buffer = device.new_buffer(length, options); + MatrixBuffer { + buffer, + rows, + columns, + count: (rows * columns) as usize, + allocated_size: length as usize, + _marker: PhantomData, + } + } + + pub fn count(&self) -> usize { + self.count + } + + pub fn contents(&self) -> Vec { + self.buffer.read_to_vec(self.count) + } +} + +pub fn encode_gemm( + device: &DeviceRef, + command_buffer: &CommandBufferRef, + transpose_left: bool, + transpose_right: bool, + left: &Matrix, + right: &Matrix, + alpha: f64, + beta: f64, +) -> MatrixBuffer +where + A: MPSDataType, + B: MPSDataType, + C: MPSDataType, + GEMMInput: Valid, + GEMMInput: Valid, + GEMMResult: Valid, + GEMMSpecification: Valid, +{ + let (M, K) = if transpose_left { + (left.columns, left.rows) + } else { + (left.rows, left.columns) + }; + let (N, B_K) = if transpose_right { + (right.rows, right.columns) + } else { + (right.columns, right.rows) + }; + + validate_shapes(M, N, K, B_K); + + // Create descriptors for the matrices. + let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::TYPE_ID); + let right_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, B::TYPE_ID); + let result_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, C::TYPE_ID); + + // Create buffers + let options = MTLResourceOptions::StorageModeShared; + let left_buffer = + device.new_buffer_with_data(left.entries.as_ptr().cast(), M * left_row_bytes, options); + let right_buffer = + device.new_buffer_with_data(right.entries.as_ptr().cast(), K * right_row_bytes, options); + + let result_buffer = MatrixBuffer::new(device, M, N, M * result_row_bytes, options); + + // Create descriptors + let left_descriptor = MatrixDescriptor::init_single(M, K, K * A::SIZE, A::TYPE_ID); + let right_descriptor = MatrixDescriptor::init_single(K, N, N * B::SIZE, B::TYPE_ID); + let result_descriptor = MatrixDescriptor::init_single(M, N, N * C::SIZE, C::TYPE_ID); + + // Create matrix objects + let left_matrix = + MatrixObject::init_with_buffer_descriptor(&left_buffer, &left_descriptor).unwrap(); + let right_matrix = + MatrixObject::init_with_buffer_descriptor(&right_buffer, &right_descriptor).unwrap(); + let result_matrix = + MatrixObject::init_with_buffer_descriptor(&result_buffer.buffer, &result_descriptor) + .unwrap(); + + // Create kernel + let matrix_multiplication = MatrixMultiplication::init( + &device, + transpose_left, + transpose_right, + M, + N, + K, + alpha, + beta, + ) + .unwrap(); + + // Encode kernel to command buffer + matrix_multiplication.encode_to_command_buffer( + &command_buffer, + &left_matrix, + &right_matrix, + &result_matrix, + ); + + // Return result buffer + result_buffer +} + +fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger, B_K: NSUInteger) { + // Certain constraints apply to the sizes of the matrices depending on the transposition + // 
operations and sizes requested at initialization time as well as the origins at the time
+    // this routine is called:
+    assert!(M > 0);
+    assert!(N > 0);
+    assert!(K > 0);
+    // Left column size must equal right row size.
+    assert_eq!(K, B_K);
+    // The element-count requirements (left >= result rows * interior columns,
+    // right >= interior columns * result columns) are already guaranteed by
+    // `Matrix::new`, which asserts that `entries.len() == rows * columns`.
+}
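
For reviewers who want to exercise the new GEMM API without running the full benchmark, here is a minimal sketch condensed from the `examples/mps/matrix-multiplication` example added by this patch (build with the `mps` feature enabled). The 2x2 `Float32` inputs and the explicit `MatrixBuffer<Float32>` annotation are illustrative choices, not part of the patch itself:

```rust
use metal::mps::*;
use metal::*;

fn main() {
    // A Metal device plus a queue/buffer to encode the GEMM kernel into.
    let device = Device::system_default().expect("No device found");
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    // Row-major 2x2 input matrices.
    let a = Matrix::<Float32>::new([1.0, 2.0, 3.0, 4.0], 2, 2);
    let b = Matrix::<Float32>::new([5.0, 6.0, 7.0, 8.0], 2, 2);

    // C = 1.0 * A * B + 0.0 * C, with no transposition of either input.
    let c: MatrixBuffer<Float32> =
        encode_gemm(&device, command_buffer, false, false, &a, &b, 1.0, 0.0);
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Read the result back from the shared buffer: [19.0, 22.0, 43.0, 50.0].
    println!("{:?}", c.contents());
}
```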